FrAnKu34t23 committed on
Commit
a9dfac0
·
verified ·
1 Parent(s): 5a15f82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -74
app.py CHANGED
@@ -1,83 +1,255 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- # --- 1. Import Existing Baselines ---
4
- # Wrapped in try-except so the app doesn't crash if files are temporarily missing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  try:
6
  from baseline.baseline_convnext import predict_convnext
7
  except ImportError:
8
- def predict_convnext(image): return {"Error": "ConvNeXt module missing"}
9
-
10
  try:
11
  from baseline.baseline_infer import predict_baseline
12
  except ImportError:
13
- def predict_baseline(image): return {"Error": "Baseline module missing"}
14
-
15
- # --- 2. Import NEW SPA Approach ---
16
- # This imports the function from: new_approach/spa_ensemble.py
17
  try:
18
  from new_approach.spa_ensemble import predict_spa
19
  except ImportError:
20
- def predict_spa(image): return {"Error": "SPA module missing. Check 'new_approach' folder."}
21
 
 
22
 
23
- # --- Placeholder models (for future extensions) ---
24
- def predict_placeholder_2(image):
25
- if image is None:
26
- return "Please upload an image."
27
- return "Model 4 is not available yet. Please check back later."
28
-
29
- # --- Main Prediction Logic ---
30
  def predict(model_choice, image):
31
- if image is None: return None
32
 
 
 
33
  if model_choice == "Herbarium Species Classifier (ConvNeXT)":
34
- # Luna's ConvNeXt mix-stream CNN baseline
35
- return predict_convnext(image)
36
-
37
  elif model_choice == "Baseline (DINOv2 + LogReg)":
38
- # Islam's baseline
39
- return predict_baseline(image)
40
-
41
  elif model_choice == "SPA Ensemble (New Approach)":
42
- # New approach: DINOv2 + BioCLIP + Handcrafted + SPA
43
- return predict_spa(image)
44
-
45
  elif model_choice == "Future Model 2 (Placeholder)":
46
- return predict_placeholder_2(image)
47
-
48
  else:
49
- return "Invalid model selected."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # --- Gradio Interface ---
52
  with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
53
  with gr.Column(elem_id="app-wrapper"):
54
- # Header
55
  gr.Markdown(
56
  """
57
  <div id="app-header">
58
  <h1>🌿 Plant Species Classification</h1>
59
  <h3>AML Group Project – Group 8</h3>
60
  </div>
61
- """,
62
- elem_id="app-header",
63
  )
64
 
65
- # Badges row
66
- gr.Markdown(
67
- """
68
- <div id="badge-row">
69
- <span class="badge">Herbarium + Field images (ConvNeXT)</span>
70
- <span class="badge">ConvNeXtV2</span>
71
- <span class="badge">SPA Ensemble</span>
72
- </div>
73
- """,
74
- elem_id="badge-row",
75
- )
76
-
77
- # Main card
78
  with gr.Row(elem_id="main-card"):
79
- # Left side: model + image
80
- with gr.Column(scale=1, elem_id="left-panel"):
81
  model_selector = gr.Dropdown(
82
  label="Select model",
83
  choices=[
@@ -86,7 +258,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
86
  "SPA Ensemble (New Approach)",
87
  "Future Model 2 (Placeholder)",
88
  ],
89
- value="SPA Ensemble (New Approach)", # Default to your new model
90
  )
91
 
92
  gr.Markdown(
@@ -96,43 +268,28 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
96
  <b>Baseline</b> – Simple DINOv2 + LogReg.<br>
97
  <b>SPA Ensemble</b> – <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
98
  </div>
99
- """,
100
- elem_id="model-help",
101
- )
102
-
103
- image_input = gr.Image(
104
- type="pil",
105
- label="Upload plant image",
106
  )
107
 
 
108
  submit_button = gr.Button("Classify 🌱", variant="primary")
109
 
110
- # Right side: predictions
111
- with gr.Column(scale=1, elem_id="right-panel"):
112
- output_label = gr.Label(
113
- label="Top 5 predictions",
114
- num_top_classes=5,
 
 
115
  )
116
 
117
  submit_button.click(
118
  fn=predict,
119
  inputs=[model_selector, image_input],
120
- outputs=output_label,
121
- )
122
-
123
- # Optional examples
124
- gr.Examples(
125
- examples=[],
126
- inputs=image_input,
127
- outputs=output_label,
128
- fn=lambda img: predict("SPA Ensemble (New Approach)", img),
129
- cache_examples=False,
130
  )
131
 
132
- gr.Markdown(
133
- "Built for the AML course – compare CNN vs. DINOv2 feature-extractor baselines with the new approaches to address cross-domain plant identification.",
134
- elem_id="footer",
135
- )
136
 
137
  if __name__ == "__main__":
138
  demo.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import re
4
+ import pickle
5
+ import torch
6
+ import requests
7
+ from torchvision import transforms
8
+ from huggingface_hub import list_repo_files, hf_hub_download
9
 
10
+ # --- CONFIGURATION ---
11
+
12
+ # 1. Dataset Config
13
+ DATASET_ID = "FrAnKu34t23/Herbarium_Field"
14
+ DATASET_URL_BASE = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/train/herbarium/"
15
+ SPECIES_LIST_URL = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/list/species_list.txt"
16
+
17
+ # 2. Model Repo Config
18
+ MODEL_REPO_ID = "FrAnKu34t23/ensemble_models_plant"
19
+ INDEX_FILENAME = "herbarium_index.pkl"
20
+
21
+ # Global Variables
22
+ REFERENCE_IMAGE_MAP = {} # Fallback (Class ID -> Image Filename)
23
+ NAME_TO_ID_MAP = {} # Lookup (Species Name -> Class ID)
24
+ VECTOR_INDEX = None # Smart Search Index
25
+ FEATURE_EXTRACTOR = None # DINOv2 model
26
+ TRANSFORM = None # Image transforms
27
+
28
+ # --- SETUP: Load Resources ---
29
def load_resources():
    """Initialise the module-level resources the app uses, once, at startup.

    Runs three steps in order:

    1. ``load_species_mapping()`` fills the species-name -> class-ID lookup.
    2. Downloads ``INDEX_FILENAME`` from ``MODEL_REPO_ID``, unpickles it, loads
       the ``dinov2_vits14`` model from ``torch.hub`` and builds the image
       transform used to embed query images.
    3. ``build_fallback_map()`` builds the per-class fallback image map.

    Step 2 is best-effort. ``VECTOR_INDEX``, ``FEATURE_EXTRACTOR`` and
    ``TRANSFORM`` are published together only when every part succeeded; on
    any failure all three are reset to ``None`` so the retrieval code never
    sees a half-initialised state. Steps 1 and 3 run regardless.
    """
    global VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM

    print("🚀 App starting... Initializing resources.")

    # 1. Load Name-to-ID Map (crucial if models output only names)
    load_species_mapping()

    # 2. Download and load the visual search index, model and transform
    try:
        print(f"⬇️ Downloading {INDEX_FILENAME} from {MODEL_REPO_ID}...")
        index_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=INDEX_FILENAME,
            repo_type="model",
        )
        print("✅ Downloaded index. Loading pickle...")
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. This is only acceptable while MODEL_REPO_ID is a repository
        # the app owner controls; prefer a non-executable format
        # (e.g. safetensors / npz) if that ever changes.
        with open(index_path, "rb") as f:
            index = pickle.load(f)

        print("⬇️ Loading DINOv2 (Retrieval Engine)...")
        extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        extractor.eval()  # inference mode: disable dropout etc.

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    except Exception as e:
        # Startup boundary: log and continue without visual search.
        print(f"⚠️ Smart Search initialization failed: {e}")
        VECTOR_INDEX = None
        FEATURE_EXTRACTOR = None
        TRANSFORM = None
    else:
        VECTOR_INDEX = index
        FEATURE_EXTRACTOR = extractor
        TRANSFORM = transform
        print("🚀 Smart Search Ready!")

    # 3. Build Fallback Map
    build_fallback_map()
67
+
68
def _parse_species_line(line):
    """Return ``(class_id, species_name)`` parsed from one list line, or None.

    Accepts ``;``, tab or ``,`` as the delimiter, with the ID in either
    position. When none of those delimiters is present, falls back to a
    whitespace split with the ID as the first or last token. Lines without
    a purely numeric ID field or without a name (e.g. a header row or a
    blank line) yield ``None``.
    """
    parts = re.split(r'[;\t,]', line)
    if len(parts) < 2:
        text = line.strip()
        parts = text.split(None, 1)          # "12345 Genus species"
        if len(parts) == 2 and not parts[0].isdigit():
            parts = text.rsplit(None, 1)     # "Genus species 12345"
    if len(parts) < 2:
        return None

    first, second = parts[0].strip(), parts[1].strip()
    if first.isdigit() and second:
        return first, second
    if second.isdigit() and first:
        return second, first
    return None


def load_species_mapping():
    """Download the species list and fill ``NAME_TO_ID_MAP`` (name -> class ID).

    Fetches ``SPECIES_LIST_URL`` and stores one entry per parseable line,
    keyed by the lower-cased species name. Best-effort: a non-200 response
    or any exception is logged and leaves the mapping as it was.
    """
    global NAME_TO_ID_MAP
    print("⬇️ Downloading species_list.txt for Name mapping...")
    try:
        # A timeout keeps a stalled connection from blocking app startup
        # indefinitely (requests waits forever by default).
        response = requests.get(SPECIES_LIST_URL, timeout=30)
        if response.status_code != 200:
            print(f"⚠️ Failed to download species list. Status: {response.status_code}")
            return

        count = 0
        for line in response.text.splitlines():
            parsed = _parse_species_line(line)
            if parsed is None:
                continue
            c_id, c_name = parsed
            # Normalize name (lowercase) for easier matching
            NAME_TO_ID_MAP[c_name.lower()] = c_id
            count += 1
        print(f"✅ Loaded {count} species names into mapping.")
    except Exception as e:
        # Startup boundary: the app can still run with ID-only predictions.
        print(f"⚠️ Error loading species list: {e}")
100
+
101
def build_fallback_map():
    """Fill ``REFERENCE_IMAGE_MAP`` with one reference image per class.

    Lists the files of dataset ``DATASET_ID`` and, for every image stored
    under ``train/herbarium/{class_id}/...``, records the first one seen for
    each class. The stored value is the path relative to the class folder, so
    ``f"{DATASET_URL_BASE}{class_id}/{value}"`` is valid even when images sit
    in sub-folders. Best-effort: a listing failure is logged and the map is
    left as it was.
    """
    global REFERENCE_IMAGE_MAP
    prefix = "train/herbarium/"
    image_suffixes = ('.jpg', '.jpeg', '.png')

    try:
        print(f"🔄 Scanning dataset {DATASET_ID} for fallback map...")
        all_files = list_repo_files(repo_id=DATASET_ID, repo_type="dataset")
    except Exception as e:
        print(f"⚠️ Error scanning dataset: {e}")
        return

    for file_path in all_files:
        if not file_path.startswith(prefix):
            continue
        if not file_path.lower().endswith(image_suffixes):
            continue
        # "<class_id>/<rest>" -> class folder and path inside it; files lying
        # directly in train/herbarium/ have no separator and are skipped.
        class_id, sep, relative_name = file_path[len(prefix):].partition("/")
        if sep and class_id and relative_name:
            # First image listed for a class wins.
            REFERENCE_IMAGE_MAP.setdefault(class_id, relative_name)

    print(f"✅ Fallback map built for {len(REFERENCE_IMAGE_MAP)} classes.")
120
+
121
+ # Load resources on startup
122
+ load_resources()
123
+
124
+ # --- Logic: ID Extraction & Search ---
125
def get_class_id_from_prediction(class_prediction):
    """Resolve a model's top prediction to a class-ID string.

    The prediction is converted with ``str()`` and tried, in order, as:

    1. text containing an ID in parentheses, e.g. ``"Name (12345)"``;
    2. a bare ID, e.g. ``"12345"`` (or the integer ``12345``);
    3. an ID followed by whitespace, e.g. ``"12345 - Name"``;
    4. a species name looked up case-insensitively in ``NAME_TO_ID_MAP``.

    Returns the ID as a string, or ``None`` when the prediction is ``None``,
    blank, or cannot be resolved.
    """
    # Test for None explicitly: a plain truthiness check would also reject
    # the valid integer class ID 0.
    if class_prediction is None:
        return None
    prediction_str = str(class_prediction).strip()
    if not prediction_str:
        return None

    # 1. Explicit ID in the string (e.g. "Name (12345)")
    match = re.search(r'\((\d+)\)', prediction_str)
    if match:
        return match.group(1)

    # 2. The string IS the ID (e.g. "12345")
    if prediction_str.isdigit():
        return prediction_str

    # 3. "ID - Name" format: leading digits followed by whitespace
    match_start = re.match(r'^(\d+)\s+', prediction_str)
    if match_start:
        return match_start.group(1)

    # 4. Name lookup: no usable number found, treat the text as a name
    clean_name = prediction_str.lower().strip()
    return NAME_TO_ID_MAP.get(clean_name)
151
+
152
def find_most_similar_herbarium_sheet(class_prediction, input_pil_image):
    """Return the URL of a herbarium image for the predicted class, or None.

    The class ID is resolved with ``get_class_id_from_prediction``. Two
    strategies are then tried:

    A. Visual similarity: when the index, model and transform are loaded and
       the class is present in ``VECTOR_INDEX``, embed the query image and
       pick the candidate whose stored vector has the highest dot product
       with the L2-normalised query embedding.
    B. Fallback: the image recorded for the class in ``REFERENCE_IMAGE_MAP``.

    Any exception raised in strategy A is logged and strategy B is used.
    Returns ``None`` when the class ID cannot be resolved or no image is
    known for it.
    """
    class_id = get_class_id_from_prediction(class_prediction)

    if not class_id:
        print(f"⚠️ Could not resolve Class ID for: '{class_prediction}'")
        return None

    # Strategy A: Visual Similarity (Vectors)
    # Explicit None checks: relying on the truthiness of model/image objects
    # is fragile, and TRANSFORM must be available as well.
    smart_search_ready = (
        VECTOR_INDEX is not None
        and FEATURE_EXTRACTOR is not None
        and TRANSFORM is not None
        and input_pil_image is not None
    )
    if smart_search_ready and class_id in VECTOR_INDEX:
        try:
            # Normalize() is configured for 3 channels; force RGB so RGBA,
            # grayscale or palette images do not fail and silently fall back.
            rgb_image = input_pil_image.convert("RGB")
            img_tensor = TRANSFORM(rgb_image).unsqueeze(0)
            with torch.no_grad():
                input_vec = FEATURE_EXTRACTOR(img_tensor)
                input_vec = torch.nn.functional.normalize(input_vec, p=2, dim=1)

            # NOTE(review): each candidate is assumed to be a dict with a
            # "vector" tensor of shape (1, dim) and a "filename" — confirm
            # against the script that builds the index.
            best_score = -1.0
            best_filename = None
            for item in VECTOR_INDEX[class_id]:
                score = torch.mm(input_vec, item["vector"].T).item()
                if score > best_score:
                    best_score = score
                    best_filename = item["filename"]

            if best_filename:
                return f"{DATASET_URL_BASE}{class_id}/{best_filename}"
        except Exception as e:
            print(f"⚠️ Search failed: {e}")

    # Strategy B: Fallback
    filename = REFERENCE_IMAGE_MAP.get(class_id)
    if filename:
        return f"{DATASET_URL_BASE}{class_id}/{filename}"

    return None
188
+
189
+ # --- Import User Models ---
190
  try:
191
  from baseline.baseline_convnext import predict_convnext
192
  except ImportError:
193
+ def predict_convnext(image): return {"Error: ConvNeXt missing": 0.0}
 
194
  try:
195
  from baseline.baseline_infer import predict_baseline
196
  except ImportError:
197
+ def predict_baseline(image): return {"Error: Baseline missing": 0.0}
 
 
 
198
  try:
199
  from new_approach.spa_ensemble import predict_spa
200
  except ImportError:
201
+ def predict_spa(image): return {"Error: SPA missing": 0.0}
202
 
203
def predict_placeholder_2(image):
    """Stand-in for the fourth model, which is not available yet.

    Ignores ``image`` and returns a single-entry label dict with score 0.0.
    """
    return {"Model 4 Not Available": 0.0}
204
 
205
# --- Main App Logic ---
def predict(model_choice, image):
    """Classify ``image`` with the selected model and fetch a reference image.

    Args:
        model_choice: label of the model picked in the dropdown.
        image: the uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(predictions, reference_image_url)`` pair. ``predictions`` is
        whatever the selected model returned (a single-entry dict for an
        unknown ``model_choice``). ``reference_image_url`` is the herbarium
        image URL for the top label, or ``None`` when retrieval was skipped
        or failed. Returns ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # STEP 1: CLASSIFICATION
    if model_choice == "Herbarium Species Classifier (ConvNeXT)":
        predictions = predict_convnext(image)
    elif model_choice == "Baseline (DINOv2 + LogReg)":
        predictions = predict_baseline(image)
    elif model_choice == "SPA Ensemble (New Approach)":
        predictions = predict_spa(image)
    elif model_choice == "Future Model 2 (Placeholder)":
        predictions = predict_placeholder_2(image)
    else:
        predictions = {"Invalid model": 0.0}

    # Pick the top label; models may return a dict of scores or a plain string.
    top_class_str = None
    if isinstance(predictions, dict) and predictions:
        # Coerce to str: a non-string key (e.g. an integer class index) would
        # make the substring checks below raise TypeError.
        top_class_str = str(max(predictions, key=predictions.get))
    elif isinstance(predictions, str):
        top_class_str = predictions

    # STEP 2: RETRIEVAL
    # Skip labels that are really error or prompt messages from a model.
    reference_image_url = None
    if top_class_str and "Error" not in top_class_str and "Please" not in top_class_str:
        try:
            reference_image_url = find_most_similar_herbarium_sheet(top_class_str, image)
        except Exception as e:
            # Retrieval is optional; never let it break the classification.
            print(f"Error in retrieval: {e}")

    return predictions, reference_image_url
238
 
239
  # --- Gradio Interface ---
240
  with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
241
  with gr.Column(elem_id="app-wrapper"):
 
242
  gr.Markdown(
243
  """
244
  <div id="app-header">
245
  <h1>🌿 Plant Species Classification</h1>
246
  <h3>AML Group Project – Group 8</h3>
247
  </div>
248
+ """, elem_id="app-header"
 
249
  )
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  with gr.Row(elem_id="main-card"):
252
+ with gr.Column(scale=1):
 
253
  model_selector = gr.Dropdown(
254
  label="Select model",
255
  choices=[
 
258
  "SPA Ensemble (New Approach)",
259
  "Future Model 2 (Placeholder)",
260
  ],
261
+ value="SPA Ensemble (New Approach)",
262
  )
263
 
264
  gr.Markdown(
 
268
  <b>Baseline</b> – Simple DINOv2 + LogReg.<br>
269
  <b>SPA Ensemble</b> – <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
270
  </div>
271
+ """, elem_id="model-help"
 
 
 
 
 
 
272
  )
273
 
274
+ image_input = gr.Image(type="pil", label="Upload plant image")
275
  submit_button = gr.Button("Classify 🌱", variant="primary")
276
 
277
+ with gr.Column(scale=1):
278
+ output_label = gr.Label(label="Top 5 predictions", num_top_classes=5)
279
+ herbarium_output = gr.Image(
280
+ label="Matched Herbarium Specimen (Visual Reference)",
281
+ show_label=True,
282
+ interactive=False,
283
+ height=300
284
  )
285
 
286
  submit_button.click(
287
  fn=predict,
288
  inputs=[model_selector, image_input],
289
+ outputs=[output_label, herbarium_output],
 
 
 
 
 
 
 
 
 
290
  )
291
 
292
+ gr.Markdown("Built for the AML course – Group 8", elem_id="footer")
 
 
 
293
 
294
  if __name__ == "__main__":
295
  demo.launch()