Spaces:
Sleeping
Sleeping
| # baseline/baseline_convnext.py | |
| from pathlib import Path | |
| import torch | |
| import pandas as pd | |
| from PIL import Image | |
| from torchvision import transforms | |
| from transformers import ConvNextV2ForImageClassification | |
# Filesystem layout, anchored on the directory that contains this file.
BASELINE_DIR = Path(__file__).resolve().parent
ROOT_DIR = BASELINE_DIR.parent
LIST_DIR = ROOT_DIR / "list"
MODEL_PATH = BASELINE_DIR / "herbarium_convnext_v2_base.pth"

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Species list: a headerless, semicolon-separated text file whose two columns
# are read as "class_id" and "species_name".
species_df = pd.read_csv(
    LIST_DIR / "species_list.txt",
    names=["class_id", "species_name"],
    header=None,
    sep=";",
)

# Names are kept in file order; predictions index into this list by position.
class_names = [*species_df["species_name"]]
num_labels = len(class_names)
# Inference-time preprocessing: resize, take a central 224x224 crop, convert
# to a tensor, then normalize each of the three channels.
data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def _load_model():
    """Build the ConvNeXt V2 classifier and load the local weights if present.

    The model is created from the ``facebook/convnextv2-base-22k-224``
    checkpoint with ``num_labels`` output classes.  When ``MODEL_PATH`` exists,
    its state dict is loaded into the model; otherwise a warning is printed
    and the model is returned with the Hugging Face weights only.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    model = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        # Let the classification head be re-created when its size does not
        # match the pretrained checkpoint instead of raising.
        ignore_mismatched_sizes=True,
    )

    if MODEL_PATH.is_file():
        # SECURITY: torch.load unpickles the file.  The result is only used as
        # a state dict, so restrict the unpickler to tensors and plain
        # containers rather than allowing arbitrary objects to be constructed.
        state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state)
    else:
        print(f"[convnext] WARNING: {MODEL_PATH} not found, using HF weights only.")

    model.to(DEVICE)
    model.eval()
    return model
# Build the model once, at import time, so that every call to
# predict_convnext() reuses the same instance.
convnext_model = _load_model()
def predict_convnext(image: Image.Image):
    """Classify ``image`` and return the highest-probability species.

    Args:
        image: The picture to classify, or ``None`` when nothing was supplied.

    Returns:
        A dict mapping species name to softmax probability for the five
        highest-scoring classes (fewer if the model has fewer than five
        classes), or the string ``"Please upload an image."`` when ``image``
        is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # The normalization step is configured for exactly three channels, so
    # grayscale, palette and RGBA inputs must be converted first.
    rgb_image = image.convert("RGB")
    batch = data_transforms(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = convnext_model(batch).logits
    prob = torch.softmax(logits, dim=1)[0]

    # torch.topk raises when k exceeds the number of classes.
    k = min(5, prob.numel())
    top_prob, top_idx = torch.topk(prob, k)

    # .tolist() copies to host memory and yields plain Python ints/floats.
    return {
        class_names[idx]: p
        for idx, p in zip(top_idx.tolist(), top_prob.tolist())
    }