File size: 1,867 Bytes
ef3d1e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# baseline/baseline_convnext.py
from pathlib import Path
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms
from transformers import ConvNextV2ForImageClassification

# Filesystem layout, all resolved relative to this module's own location:
# BASELINE_DIR is the directory containing this file, ROOT_DIR is its parent.
BASELINE_DIR = Path(__file__).resolve().parent
ROOT_DIR = BASELINE_DIR.parent
LIST_DIR = ROOT_DIR / "list"
MODEL_PATH = BASELINE_DIR / "herbarium_convnext_v2_base.pth"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Species list: parsed as header-less, ";"-separated rows with two columns,
# named "class_id" and "species_name" here.
species_df = pd.read_csv(
    LIST_DIR / "species_list.txt",
    sep=";",
    header=None,
    names=["class_id", "species_name"],
)
# Row order of the file defines the position of each name in class_names.
class_names = species_df["species_name"].tolist()
num_labels = len(class_names)

# Preprocessing applied to every image before it is fed to the model:
# resize, centre-crop to 224, convert to a tensor, then normalise each of the
# three channels with the given mean / standard deviation.
data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def _load_model():
    """Build the ConvNeXt V2 classifier and apply local weights if present.

    The model is created from the Hugging Face checkpoint
    ``facebook/convnextv2-base-22k-224`` with ``num_labels`` output classes;
    ``ignore_mismatched_sizes=True`` is passed so that checkpoint weights
    whose shape differs from the requested model do not abort loading.
    When ``MODEL_PATH`` exists, the state dict stored there is loaded into the
    model; otherwise a warning is printed and the Hugging Face weights are
    used unchanged.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    net = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )

    if not MODEL_PATH.is_file():
        print(f"[convnext] WARNING: {MODEL_PATH} not found, using HF weights only.")
    else:
        # map_location lets weights saved on another device load onto DEVICE.
        weights = torch.load(MODEL_PATH, map_location=DEVICE)
        net.load_state_dict(weights)

    net.to(DEVICE)
    net.eval()
    return net

# Load the model once at import time; predict_convnext reuses this single
# instance for every call instead of rebuilding it per request.
convnext_model = _load_model()

def predict_convnext(image: Image.Image):
    """Classify one image and return its most probable species.

    Args:
        image: PIL image to classify, or ``None`` when no image was supplied.

    Returns:
        The string ``"Please upload an image."`` when ``image`` is ``None``.
        Otherwise a dict mapping species name to its softmax probability for
        the highest-scoring classes: five entries, or fewer when the model
        has fewer than five classes.
    """
    if image is None:
        return "Please upload an image."

    # The module's preprocessing normalises with three-channel statistics, so
    # grayscale, palette and RGBA images must be converted to RGB first;
    # otherwise the channel count does not match and normalisation fails.
    if image.mode != "RGB":
        image = image.convert("RGB")

    x = data_transforms(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = convnext_model(x).logits
        prob = torch.softmax(logits, dim=1)[0]
        # torch.topk raises when k exceeds the number of classes, so clamp it.
        k = min(5, prob.numel())
        top_prob, top_idx = torch.topk(prob, k)

    # .tolist() yields plain Python ints/floats for indexing and for the result.
    return {
        class_names[i]: float(p)
        for i, p in zip(top_idx.tolist(), top_prob.tolist())
    }