# baseline/baseline_convnext.py
"""ConvNeXt V2 baseline classifier for herbarium species images.

At import time this module reads the species list, builds the model and
loads the fine-tuned weights (when the checkpoint file exists), so
``predict_convnext`` can be called directly afterwards.
"""
from pathlib import Path
from typing import Dict, Union

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from transformers import ConvNextV2ForImageClassification

ROOT_DIR = Path(__file__).resolve().parent.parent
BASELINE_DIR = Path(__file__).resolve().parent
LIST_DIR = ROOT_DIR / "list"
MODEL_PATH = BASELINE_DIR / "herbarium_convnext_v2_base.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default number of predictions returned by predict_convnext.
DEFAULT_TOP_K = 5

# Species list: ';'-separated "class_id;species_name" rows with no header.
# Row order is used as the classifier's label index.
# NOTE(review): assumes this order matches the one used when the checkpoint
# was trained — TODO confirm.
species_df = pd.read_csv(
    LIST_DIR / "species_list.txt",
    sep=";",
    header=None,
    names=["class_id", "species_name"],
)
class_names = list(species_df["species_name"])
num_labels = len(class_names)

# Resize the short side to 256, center-crop to 224x224, convert to a tensor,
# then normalize with the standard ImageNet per-channel mean/std.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _load_model():
    """Build the ConvNeXt V2 classifier on DEVICE in eval mode.

    The Hugging Face checkpoint is instantiated with a head sized for
    ``num_labels``. ``ignore_mismatched_sizes=True`` lets that head be
    re-initialized instead of raising on the shape mismatch. If the
    fine-tuned state dict at MODEL_PATH exists, it overrides those weights.
    Otherwise a warning is printed and only the HF weights are used.
    """
    model = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )
    if MODEL_PATH.is_file():
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the torch version allows.
        state = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state)
    else:
        print(f"[convnext] WARNING: {MODEL_PATH} not found, using HF weights only.")
    model.to(DEVICE)
    model.eval()
    return model


convnext_model = _load_model()


def predict_convnext(
    image: Image.Image, *, top_k: int = DEFAULT_TOP_K
) -> Union[str, Dict[str, float]]:
    """Classify *image* and return the most probable species.

    Args:
        image: Input PIL image, or None when nothing was uploaded. Images in
            any mode (e.g. L, P, RGBA) are converted to RGB first.
        top_k: Maximum number of species to return. It is clamped to the
            number of known classes.

    Returns:
        A dict mapping species name to softmax probability for the top
        ``top_k`` classes, inserted highest-probability first. If *image*
        is None, returns the string "Please upload an image." instead.
    """
    if image is None:
        return "Please upload an image."
    # Normalize expects exactly three channels. Grayscale, palette and RGBA
    # uploads would otherwise yield 1- or 4-channel tensors and fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    x = data_transforms(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = convnext_model(x).logits
    prob = torch.softmax(logits, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, prob.numel())
    top_prob, top_idx = torch.topk(prob, k)
    return {
        class_names[i]: float(p)
        for i, p in zip(top_idx.tolist(), top_prob.tolist())
    }