File size: 11,006 Bytes
ef3d1e2
a9dfac0
 
 
 
 
 
 
ef3d1e2
a9dfac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eae49fb
 
 
a9dfac0
eae49fb
 
 
a9dfac0
eae49fb
 
 
a9dfac0
ef3d1e2
a9dfac0
ef3d1e2
a9dfac0
ef3d1e2
a9dfac0
eae49fb
a9dfac0
 
eae49fb
a9dfac0
ef3d1e2
a9dfac0
eae49fb
a9dfac0
ef3d1e2
a9dfac0
ef3d1e2
a9dfac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef3d1e2
 
 
 
 
 
 
 
eae49fb
ef3d1e2
a9dfac0
ef3d1e2
eae49fb
ef3d1e2
a9dfac0
ef3d1e2
 
 
5a15f82
ef3d1e2
eae49fb
ef3d1e2
 
a9dfac0
ef3d1e2
eae49fb
ef3d1e2
 
 
eae49fb
 
 
ef3d1e2
a9dfac0
ef3d1e2
eae49fb
a9dfac0
ef3d1e2
 
a9dfac0
 
 
 
 
 
 
ef3d1e2
 
 
 
 
a9dfac0
ef3d1e2
 
a9dfac0
ef3d1e2
 
eae49fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import gradio as gr
import os
import re
import pickle
import torch
import requests
from torchvision import transforms
from huggingface_hub import list_repo_files, hf_hub_download

# --- CONFIGURATION ---

# Hugging Face dataset that hosts the herbarium images and the species list.
DATASET_ID = "FrAnKu34t23/Herbarium_Field"

# Both dataset URLs hang off the same "resolve/main" root, so build it once.
_DATASET_RESOLVE_ROOT = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main"
DATASET_URL_BASE = f"{_DATASET_RESOLVE_ROOT}/train/herbarium/"
SPECIES_LIST_URL = f"{_DATASET_RESOLVE_ROOT}/list/species_list.txt"

# Hugging Face model repo and the pickled search index stored in it.
MODEL_REPO_ID = "FrAnKu34t23/ensemble_models_plant"
INDEX_FILENAME = "herbarium_index.pkl"

# Module-level state, filled in by the loader functions at startup.
REFERENCE_IMAGE_MAP = dict()  # Fallback lookup: class ID -> image filename
NAME_TO_ID_MAP = dict()       # Name lookup: lower-cased species name -> class ID
VECTOR_INDEX = None           # Visual search index (None until loaded)
FEATURE_EXTRACTOR = None      # DINOv2 model (None until loaded)
TRANSFORM = None              # Image preprocessing pipeline (None until loaded)

# --- SETUP: Load Resources ---
def load_resources():
    """Initialise every startup resource used by the retrieval step.

    Runs three stages in order:

    1. ``load_species_mapping()`` fills the species-name lookup.
    2. The visual-search stack is prepared: ``INDEX_FILENAME`` is downloaded
       from ``MODEL_REPO_ID`` and unpickled, the ``dinov2_vits14`` model is
       loaded through ``torch.hub`` and switched to eval mode, and the image
       preprocessing pipeline is built.
    3. ``build_fallback_map()`` builds the fallback reference-image map.

    Stage 2 is all-or-nothing: ``VECTOR_INDEX``, ``FEATURE_EXTRACTOR`` and
    ``TRANSFORM`` are only published once every piece loaded successfully.
    On any failure the error is printed and all three are reset to ``None``,
    so a half-initialised search stack (and the memory it holds) is never
    left behind. Stages 1 and 3 run regardless.
    """
    global VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM

    print("πŸš€ App starting... Initializing resources.")

    # 1. Load Name-to-ID Map (crucial if models output only names)
    load_species_mapping()

    # 2. Download and load the visual search stack into locals first.
    try:
        print(f"⬇️ Downloading {INDEX_FILENAME} from {MODEL_REPO_ID}...")
        index_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=INDEX_FILENAME,
            repo_type="model",
        )
        print("βœ… Downloaded index. Loading pickle...")
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. This is only acceptable while MODEL_REPO_ID is a repository
        # we control/trust; prefer a safer format if that ever changes.
        with open(index_path, "rb") as index_file:
            loaded_index = pickle.load(index_file)

        print("⬇️ Loading DINOv2 (Retrieval Engine)...")
        extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        extractor.eval()

        # 224x224 input normalised with the standard ImageNet mean/std.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    except Exception as e:
        # Best-effort feature: report and fall back to the static map.
        print(f"⚠️ Smart Search initialization failed: {e}")
        VECTOR_INDEX = None
        FEATURE_EXTRACTOR = None
        TRANSFORM = None
    else:
        VECTOR_INDEX = loaded_index
        FEATURE_EXTRACTOR = extractor
        TRANSFORM = preprocess
        print("πŸš€ Smart Search Ready!")

    # 3. Build Fallback Map
    build_fallback_map()

def _parse_species_line(line):
    """Split one species-list line into ``(class_id, species_name)``.

    Accepts ``ID<sep>Name`` or ``Name<sep>ID`` where ``<sep>`` is ``;``, a
    tab or a comma, and also the whitespace-separated ``ID Name`` form.
    Returns ``None`` when the line has no all-digit ID or no name (blank
    lines, header rows, malformed entries).
    """
    fields = [field.strip() for field in re.split(r'[;\t,]', line)]
    if len(fields) >= 2:
        first, second = fields[0], fields[1]
        if first.isdigit() and second:
            return first, second
        if second.isdigit() and first:
            return second, first  # ID is in the second column

    # Fallback: "ClassID SpeciesName" separated by plain whitespace.
    spaced = re.match(r'\s*(\d+)\s+(\S.*)', line)
    if spaced is not None:
        return spaced.group(1), spaced.group(2).strip()
    return None


def load_species_mapping():
    """Download the species list and fill ``NAME_TO_ID_MAP``.

    Fetches ``SPECIES_LIST_URL`` and stores one ``lower-cased species name ->
    class ID`` entry per parseable line (see ``_parse_species_line``). Lines
    without a numeric class ID are skipped instead of being stored as bogus
    mappings.

    The function is best-effort: a non-200 response, a network error or a
    timeout is printed and leaves the map as it was; nothing is raised.
    """
    global NAME_TO_ID_MAP
    print("⬇️ Downloading species_list.txt for Name mapping...")
    try:
        # A timeout is required: without one, requests waits indefinitely
        # and a stalled connection would block app startup forever.
        response = requests.get(SPECIES_LIST_URL, timeout=30)
        if response.status_code != 200:
            print(f"⚠️ Failed to download species list. Status: {response.status_code}")
            return

        count = 0
        for line in response.text.splitlines():
            parsed = _parse_species_line(line)
            if parsed is None:
                continue
            c_id, c_name = parsed
            # Normalize name (lowercase) for easier matching
            NAME_TO_ID_MAP[c_name.lower()] = c_id
            count += 1
        print(f"βœ… Loaded {count} species names into mapping.")
    except Exception as e:
        print(f"⚠️ Error loading species list: {e}")

def build_fallback_map():
    """Fill ``REFERENCE_IMAGE_MAP`` with one image filename per class.

    Lists every file in the ``DATASET_ID`` dataset repo and keeps paths of
    the form ``train/herbarium/{class_id}/{filename}`` that end in ``.jpg``
    or ``.png`` (case-insensitive). The first file seen for a class wins;
    later files for the same class are ignored. Any error is printed and
    swallowed.
    """
    global REFERENCE_IMAGE_MAP
    try:
        print(f"πŸ”„ Scanning dataset {DATASET_ID} for fallback map...")
        repo_paths = list_repo_files(repo_id=DATASET_ID, repo_type="dataset")

        for repo_path in repo_paths:
            if not repo_path.startswith("train/herbarium/"):
                continue
            if not repo_path.lower().endswith(('.jpg', '.png')):
                continue

            segments = repo_path.split("/")
            if len(segments) < 4:
                continue

            # segments: ["train", "herbarium", class_id, filename, ...]
            REFERENCE_IMAGE_MAP.setdefault(segments[2], segments[3])

        print(f"βœ… Fallback map built for {len(REFERENCE_IMAGE_MAP)} classes.")
    except Exception as e:
        print(f"⚠️ Error scanning dataset: {e}")

# Load resources on startup.
# This is a module-level call, so it runs as soon as the file is imported or
# executed, before any of the definitions below it are reached.
load_resources()

# --- Logic: ID Extraction & Search ---
def get_class_id_from_prediction(class_prediction):
    """Resolve a model prediction to a class ID string.

    The prediction is converted to ``str`` and stripped, then the first
    matching rule wins:

    1. digits inside parentheses anywhere in the text, e.g. ``"Name (12345)"``
    2. the whole text is digits, e.g. ``"12345"``
    3. leading digits followed by whitespace, e.g. ``"12345 - Name"``
    4. the lower-cased text is a key of ``NAME_TO_ID_MAP``

    Returns the class ID, or ``None`` for a falsy prediction or when no rule
    matches.
    """
    if not class_prediction:
        return None

    label = str(class_prediction).strip()

    # Rule 1: explicit ID in parentheses.
    parenthesised = re.search(r'\((\d+)\)', label)
    if parenthesised is not None:
        return parenthesised.group(1)

    # Rule 2: the label itself is the ID.
    if label.isdigit():
        return label

    # Rule 3: ID at the start, followed by whitespace.
    leading = re.match(r'^(\d+)\s+', label)
    if leading is not None:
        return leading.group(1)

    # Rule 4: treat the label as a species name; a miss yields None.
    return NAME_TO_ID_MAP.get(label.lower().strip())

def find_most_similar_herbarium_sheet(class_prediction, input_pil_image):
    """Return the URL of a reference herbarium image for a prediction.

    The prediction is first resolved to a class ID with
    ``get_class_id_from_prediction``; if that fails a warning is printed and
    ``None`` is returned.

    Strategy A (visual similarity) runs only when the index, the feature
    extractor and the input image are all available and the class ID is a
    key of ``VECTOR_INDEX``. The image is embedded, L2-normalised, and scored
    against every stored entry of that class with a dot product; the entry
    with the highest score above -1.0 is chosen.
    NOTE(review): this is cosine similarity only if the stored vectors are
    unit length — confirm against the index builder.

    Strategy B (fallback) uses the filename recorded for the class in
    ``REFERENCE_IMAGE_MAP``. It is also used when Strategy A raises (the
    error is printed) or selects nothing.

    Returns the image URL, or ``None`` when neither strategy finds a file.
    """
    class_id = get_class_id_from_prediction(class_prediction)
    if not class_id:
        print(f"⚠️ Could not resolve Class ID for: '{class_prediction}'")
        return None

    # Strategy A: Visual Similarity (Vectors)
    visual_search_possible = (
        VECTOR_INDEX
        and FEATURE_EXTRACTOR
        and input_pil_image
        and class_id in VECTOR_INDEX
    )
    if visual_search_possible:
        try:
            batch = TRANSFORM(input_pil_image).unsqueeze(0)
            with torch.no_grad():
                query = FEATURE_EXTRACTOR(batch)
                query = torch.nn.functional.normalize(query, p=2, dim=1)

            top_score, top_filename = -1.0, None
            for entry in VECTOR_INDEX[class_id]:
                similarity = torch.mm(query, entry["vector"].T).item()
                if similarity > top_score:
                    top_score, top_filename = similarity, entry["filename"]

            if top_filename:
                return f"{DATASET_URL_BASE}{class_id}/{top_filename}"
        except Exception as e:
            print(f"⚠️ Search failed: {e}")

    # Strategy B: Fallback
    fallback_filename = REFERENCE_IMAGE_MAP.get(class_id)
    if not fallback_filename:
        return None
    return f"{DATASET_URL_BASE}{class_id}/{fallback_filename}"

# --- Import User Models ---
def _missing_model_stub(message):
    """Build a stand-in predictor that reports *message* with score 0.0."""
    def _stub(image):
        return {message: 0.0}
    return _stub


# Each model module is optional: when its import fails, a stub with the same
# call signature takes its place so the app can still start.
try:
    from baseline.baseline_convnext import predict_convnext
except ImportError:
    predict_convnext = _missing_model_stub("Error: ConvNeXt missing")

try:
    from baseline.baseline_infer import predict_baseline
except ImportError:
    predict_baseline = _missing_model_stub("Error: Baseline missing")

try:
    from new_approach.spa_ensemble import predict_spa
except ImportError:
    predict_spa = _missing_model_stub("Error: SPA missing")

def predict_placeholder_2(image):
    """Placeholder predictor: ignores *image* and reports no model."""
    result = {"Model 4 Not Available": 0.0}
    return result

# --- Main App Logic ---
def predict(model_choice, image):
    """Classify *image* with the chosen model and find a reference sheet.

    Args:
        model_choice: Label of the model to run. Unknown labels produce
            ``{"Invalid model": 0.0}`` as the prediction.
        image: The uploaded image, or ``None``.

    Returns:
        A ``(predictions, reference_image_url)`` tuple. It is
        ``(None, None)`` when *image* is ``None``. The URL is ``None`` when
        there is no top class, when the top class text contains ``"Error"``
        or ``"Please"``, or when retrieval fails (the error is printed).
    """
    if image is None:
        return None, None

    # STEP 1: CLASSIFICATION
    # Lambdas keep the lookup lazy: only the selected predictor is touched.
    runners = {
        "Herbarium Species Classifier (ConvNeXT)": lambda: predict_convnext(image),
        "Baseline (DINOv2 + LogReg)": lambda: predict_baseline(image),
        "SPA Ensemble (New Approach)": lambda: predict_spa(image),
        "Future Model 2 (Placeholder)": lambda: predict_placeholder_2(image),
    }
    runner = runners.get(model_choice)
    predictions = {"Invalid model": 0.0} if runner is None else runner()

    # A model may return a score dict or just a string label.
    if isinstance(predictions, dict) and predictions:
        top_class_str = max(predictions, key=predictions.get)
    elif isinstance(predictions, str):
        top_class_str = predictions
    else:
        top_class_str = None

    # STEP 2: RETRIEVAL
    reference_image_url = None
    skip_retrieval = (
        not top_class_str
        or "Error" in top_class_str
        or "Please" in top_class_str
    )
    if not skip_retrieval:
        try:
            reference_image_url = find_most_similar_herbarium_sheet(top_class_str, image)
        except Exception as e:
            print(f"Error in retrieval: {e}")

    return predictions, reference_image_url

# --- Gradio Interface ---
# Build the UI: a header, a two-column card (inputs on the left, results on
# the right) and a footer. The elem_id values are hooks for the stylesheet.
# NOTE(review): confirm that the installed Gradio version treats css= as a
# file path; if it expects literal CSS, "style.css" will not be loaded.
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    with gr.Column(elem_id="app-wrapper"):
        # Page header.
        gr.Markdown(
            """
            <div id="app-header">
              <h1>🌿 Plant Species Classification</h1>
              <h3>AML Group Project – Group 8</h3>
            </div>
            """, elem_id="app-header"
        )
        
        with gr.Row(elem_id="main-card"):
            # Left column: model choice, help text, image upload, submit.
            with gr.Column(scale=1):
                # These labels must match the strings compared in predict().
                model_selector = gr.Dropdown(
                    label="Select model",
                    choices=[
                        "Herbarium Species Classifier (ConvNeXT)",
                        "Baseline (DINOv2 + LogReg)",
                        "SPA Ensemble (New Approach)",
                        "Future Model 2 (Placeholder)",
                    ],
                    value="SPA Ensemble (New Approach)", 
                )
                
                # Short description of each selectable model.
                gr.Markdown(
                    """
                    <div id="model-help">
                      <b>Herbarium Classifier</b> – ConvNeXtV2 CNN.<br>
                      <b>Baseline</b> – Simple DINOv2 + LogReg.<br>
                      <b>SPA Ensemble</b> – <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
                    </div>
                    """, elem_id="model-help"
                )
                
                # type="pil" hands predict() a PIL image.
                image_input = gr.Image(type="pil", label="Upload plant image")
                submit_button = gr.Button("Classify 🌱", variant="primary")

            # Right column: class scores and the matched reference image.
            with gr.Column(scale=1):
                output_label = gr.Label(label="Top 5 predictions", num_top_classes=5)
                herbarium_output = gr.Image(
                    label="Matched Herbarium Specimen (Visual Reference)",
                    show_label=True,
                    interactive=False, 
                    height=300
                )

        # predict() returns (predictions, reference_image_url), which map
        # onto the two outputs in this order.
        submit_button.click(
            fn=predict,
            inputs=[model_selector, image_input],
            outputs=[output_label, herbarium_output],
        )

        # Page footer.
        gr.Markdown("Built for the AML course – Group 8", elem_id="footer")

# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()