# Spaces:
# Sleeping
# Sleeping
# File size: 11,006 Bytes
# ef3d1e2 a9dfac0 ef3d1e2 a9dfac0 eae49fb a9dfac0 eae49fb a9dfac0 eae49fb a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 a9dfac0 eae49fb a9dfac0 eae49fb a9dfac0 ef3d1e2 a9dfac0 eae49fb a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 eae49fb ef3d1e2 a9dfac0 ef3d1e2 eae49fb ef3d1e2 a9dfac0 ef3d1e2 5a15f82 ef3d1e2 eae49fb ef3d1e2 a9dfac0 ef3d1e2 eae49fb ef3d1e2 eae49fb ef3d1e2 a9dfac0 ef3d1e2 eae49fb a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 a9dfac0 ef3d1e2 eae49fb |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
import gradio as gr
import os
import re
import pickle
import torch
import requests
from torchvision import transforms
from huggingface_hub import list_repo_files, hf_hub_download
# --- CONFIGURATION ---
# 1. Dataset Config: the Hugging Face dataset that holds the herbarium sheets.
DATASET_ID = "FrAnKu34t23/Herbarium_Field"
# Every dataset URL below is served from the repo's "resolve/main" endpoint.
_DATASET_RESOLVE_ROOT = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main"
DATASET_URL_BASE = f"{_DATASET_RESOLVE_ROOT}/train/herbarium/"
SPECIES_LIST_URL = f"{_DATASET_RESOLVE_ROOT}/list/species_list.txt"

# 2. Model Repo Config: where the pickled visual-search index lives.
MODEL_REPO_ID = "FrAnKu34t23/ensemble_models_plant"
INDEX_FILENAME = "herbarium_index.pkl"

# Global state, filled in by load_resources() at startup.
REFERENCE_IMAGE_MAP = {}  # fallback lookup: class ID -> first image filename found
NAME_TO_ID_MAP = {}       # lower-cased species name -> class ID
VECTOR_INDEX = None       # smart-search index loaded from INDEX_FILENAME
FEATURE_EXTRACTOR = None  # DINOv2 backbone used to embed query images
TRANSFORM = None          # preprocessing applied before FEATURE_EXTRACTOR
# --- SETUP: Load Resources ---
def load_resources():
    """Initialise every module-level resource the app needs at startup.

    Steps, in order:
      1. ``load_species_mapping()`` fills ``NAME_TO_ID_MAP``.
      2. Downloads ``INDEX_FILENAME`` from the model repo ``MODEL_REPO_ID``,
         unpickles it into ``VECTOR_INDEX``, loads the ``dinov2_vits14`` hub
         model into ``FEATURE_EXTRACTOR`` and builds the preprocessing
         pipeline ``TRANSFORM``.
      3. ``build_fallback_map()`` fills ``REFERENCE_IMAGE_MAP``.

    Step 2 is best-effort: any failure is printed and ``VECTOR_INDEX`` is reset
    to ``None``, which disables visual search so retrieval uses the fallback map.
    """
    global VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM
    print("π App starting... Initializing resources.")

    # 1. Load Name-to-ID Map (crucial if models output only names)
    load_species_mapping()

    # 2. Download and load the visual search index
    try:
        print(f"β¬οΈ Downloading {INDEX_FILENAME} from {MODEL_REPO_ID}...")
        index_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=INDEX_FILENAME,
            repo_type="model",
        )
        print("✅ Downloaded index. Loading pickle...")
        # SECURITY NOTE(review): pickle.load can execute arbitrary code stored in
        # the file. Acceptable only because the index comes from a repo this
        # project controls -- never point it at untrusted data.
        with open(index_path, "rb") as f:
            VECTOR_INDEX = pickle.load(f)

        # Load DINOv2
        print("β¬οΈ Loading DINOv2 (Retrieval Engine)...")
        FEATURE_EXTRACTOR = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        FEATURE_EXTRACTOR.eval()  # inference mode
        # 224x224 resize + ImageNet channel mean/std. Presumably this must match
        # the preprocessing used when the index was built -- TODO confirm.
        TRANSFORM = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        print("π Smart Search Ready!")
    except Exception as e:
        print(f"β οΈ Smart Search initialization failed: {e}")
        VECTOR_INDEX = None

    # 3. Build Fallback Map
    build_fallback_map()
def load_species_mapping():
    """Populate ``NAME_TO_ID_MAP`` from the dataset's ``species_list.txt``.

    Each line is split on ``;``, tab or ``,``. Of the first two fields, the
    all-digit one is taken as the class ID and the other as the species name.
    Names are lower-cased so lookups can be case-insensitive.

    Skipped lines:
      * lines with fewer than two fields;
      * lines where neither field is all digits (e.g. a header row);
      * lines whose name field is empty.

    Best-effort: download and parse errors are printed, never raised.
    """
    global NAME_TO_ID_MAP
    print("β¬οΈ Downloading species_list.txt for Name mapping...")
    try:
        # Without a timeout a stalled connection would hang app startup forever.
        response = requests.get(SPECIES_LIST_URL, timeout=30)
        if response.status_code != 200:
            print(f"β οΈ Failed to download species list. Status: {response.status_code}")
            return

        count = 0
        for line in response.text.splitlines():
            parts = re.split(r'[;\t,]', line)
            if len(parts) < 2:
                continue
            first, second = parts[0].strip(), parts[1].strip()
            # Accept either "ID;Name" or "Name;ID" ordering.
            if first.isdigit():
                c_id, c_name = first, second
            elif second.isdigit():
                c_id, c_name = second, first
            else:
                continue  # header or malformed line: no numeric class ID
            if not c_name:
                continue
            NAME_TO_ID_MAP[c_name.lower()] = c_id
            count += 1
        print(f"✅ Loaded {count} species names into mapping.")
    except Exception as e:
        print(f"β οΈ Error loading species list: {e}")
def build_fallback_map():
    """Populate ``REFERENCE_IMAGE_MAP`` with one reference image per class.

    Lists every file in the ``DATASET_ID`` dataset repo and keeps images
    (``.jpg``, ``.jpeg`` or ``.png``, case-insensitive) stored under
    ``train/herbarium/{class_id}/...``. For each class ID, the first such image
    in listing order is recorded. The path is stored relative to the class
    folder, so ``f"{DATASET_URL_BASE}{class_id}/{filename}"`` resolves to it.

    Best-effort: errors are printed, never raised.
    """
    global REFERENCE_IMAGE_MAP
    prefix = "train/herbarium/"
    image_suffixes = ('.jpg', '.jpeg', '.png')
    try:
        print(f"π Scanning dataset {DATASET_ID} for fallback map...")
        all_files = list_repo_files(repo_id=DATASET_ID, repo_type="dataset")
        for file_path in all_files:
            if not (file_path.startswith(prefix) and file_path.lower().endswith(image_suffixes)):
                continue
            parts = file_path.split("/")
            if len(parts) < 4:
                continue  # image sits directly in train/herbarium/, no class folder
            class_id = parts[2]
            # Keep everything below the class folder so an image inside a nested
            # sub-folder yields a valid relative path, not just the folder name.
            filename = "/".join(parts[3:])
            REFERENCE_IMAGE_MAP.setdefault(class_id, filename)
        print(f"✅ Fallback map built for {len(REFERENCE_IMAGE_MAP)} classes.")
    except Exception as e:
        print(f"β οΈ Error scanning dataset: {e}")
# Load resources on startup. This runs once, when the module is imported, and
# performs all of load_resources(): species-list download, index + DINOv2
# loading, and the dataset scan for the fallback map.
load_resources()
# --- Logic: ID Extraction & Search ---
def get_class_id_from_prediction(class_prediction):
    """
    Resolve a model's predicted label to a class-ID string.

    Strategies, tried in this order:
      1. an ID in parentheses anywhere, e.g. ``"Acer rubrum (12345)"``;
      2. the whole stripped label is digits, e.g. ``"12345"``;
      3. leading digits followed by whitespace, e.g. ``"12345 Acer rubrum"``;
      4. a case-insensitive lookup of the label in ``NAME_TO_ID_MAP``.

    Returns ``None`` for a falsy input or when no strategy matches.
    """
    if not class_prediction:
        return None
    label = str(class_prediction).strip()

    explicit = re.search(r'\((\d+)\)', label)
    if explicit is not None:
        return explicit.group(1)

    if label.isdigit():
        return label

    prefixed = re.match(r'^(\d+)\s+', label)
    if prefixed is not None:
        return prefixed.group(1)

    # No numeric ID anywhere: treat the label as a species name.
    return NAME_TO_ID_MAP.get(label.lower())
def find_most_similar_herbarium_sheet(class_prediction, input_pil_image):
    """Return the URL of a herbarium reference image for the predicted class.

    Args:
        class_prediction: the model's top label; resolved to a class ID with
            ``get_class_id_from_prediction``.
        input_pil_image: the user's query image (PIL), used for visual ranking.

    Returns:
        A URL of the form ``f"{DATASET_URL_BASE}{class_id}/{filename}"``:
          * Strategy A: when the search index is loaded, ``filename`` is the
            entry in ``VECTOR_INDEX[class_id]`` whose stored vector has the
            highest dot product with the L2-normalised DINOv2 embedding of the
            query image.
          * Strategy B: otherwise, or if Strategy A fails, ``filename`` is the
            class's entry in ``REFERENCE_IMAGE_MAP``.
        ``None`` if the class ID cannot be resolved or neither strategy yields
        a filename.
    """
    class_id = get_class_id_from_prediction(class_prediction)
    if not class_id:
        print(f"β οΈ Could not resolve Class ID for: '{class_prediction}'")
        return None

    # Strategy A: Visual Similarity (Vectors)
    search_ready = (
        VECTOR_INDEX is not None
        and FEATURE_EXTRACTOR is not None
        and TRANSFORM is not None
        and input_pil_image is not None
    )
    if search_ready and class_id in VECTOR_INDEX:
        try:
            # Uploaded PNGs can be RGBA and some images greyscale. TRANSFORM's
            # 3-channel Normalize would raise on those and silently disable
            # visual search, so force RGB first.
            rgb_image = input_pil_image.convert("RGB")
            img_tensor = TRANSFORM(rgb_image).unsqueeze(0)  # add batch dimension
            best_score = -float("inf")
            best_filename = None
            with torch.no_grad():
                input_vec = FEATURE_EXTRACTOR(img_tensor)
                input_vec = torch.nn.functional.normalize(input_vec, p=2, dim=1)
                for item in VECTOR_INDEX[class_id]:
                    # Assumes item["vector"] is a (1, D) tensor, presumably
                    # L2-normalised when the index was built -- TODO confirm.
                    score = torch.mm(input_vec, item["vector"].T).item()
                    if score > best_score:
                        best_score = score
                        best_filename = item["filename"]
            if best_filename:
                return f"{DATASET_URL_BASE}{class_id}/{best_filename}"
        except Exception as e:
            print(f"β οΈ Search failed: {e}")

    # Strategy B: Fallback
    filename = REFERENCE_IMAGE_MAP.get(class_id)
    if filename:
        return f"{DATASET_URL_BASE}{class_id}/{filename}"
    return None
# --- Import User Models ---
# Each predictor comes from a project package. If importing it fails for any
# reason (missing module, missing dependency, weights that fail to load at
# import time), the reason is printed and a stub with the same signature is
# defined instead, so the UI still starts. Stub labels start with "Error" so
# predict() skips reference-image retrieval for them.
try:
    from baseline.baseline_convnext import predict_convnext
except Exception as exc:  # top-level boundary: keep the app up, but say why
    print(f"Could not load ConvNeXt model: {exc!r}")

    def predict_convnext(image):
        return {"Error: ConvNeXt missing": 0.0}

try:
    from baseline.baseline_infer import predict_baseline
except Exception as exc:  # top-level boundary: keep the app up, but say why
    print(f"Could not load Baseline model: {exc!r}")

    def predict_baseline(image):
        return {"Error: Baseline missing": 0.0}

try:
    from new_approach.spa_ensemble import predict_spa
except Exception as exc:  # top-level boundary: keep the app up, but say why
    print(f"Could not load SPA model: {exc!r}")

    def predict_spa(image):
        return {"Error: SPA missing": 0.0}


def predict_placeholder_2(image):
    """Stand-in for a fourth model that has not been implemented yet."""
    return {"Model 4 Not Available": 0.0}
# --- Main App Logic ---
def predict(model_choice, image):
    """Classify ``image`` with the selected model and look up a reference sheet.

    Args:
        model_choice: one of the dropdown labels. Any other value yields
            ``{"Invalid model": 0.0}``.
        image: the uploaded PIL image, or ``None``.

    Returns:
        ``(predictions, reference_image_url)``:
          * ``predictions`` is the raw model output, shown in the Label
            component.
          * ``reference_image_url`` is a herbarium image URL, or ``None`` when
            none was found.
        Both are ``None`` when no image was supplied.
    """
    if image is None:
        return None, None

    # STEP 1: CLASSIFICATION -- route the dropdown label to its predictor.
    routes = (
        ("Herbarium Species Classifier (ConvNeXT)", predict_convnext),
        ("Baseline (DINOv2 + LogReg)", predict_baseline),
        ("SPA Ensemble (New Approach)", predict_spa),
        ("Future Model 2 (Placeholder)", predict_placeholder_2),
    )
    for label, classifier in routes:
        if model_choice == label:
            predictions = classifier(image)
            break
    else:
        predictions = {"Invalid model": 0.0}

    # A model may return a {label: score} dict or just a label string.
    if isinstance(predictions, dict):
        top_class_str = max(predictions, key=predictions.get) if predictions else None
    elif isinstance(predictions, str):
        top_class_str = predictions
    else:
        top_class_str = None

    # STEP 2: RETRIEVAL -- skipped for error / instruction messages.
    reference_image_url = None
    wants_reference = bool(top_class_str) and not any(
        marker in top_class_str for marker in ("Error", "Please")
    )
    if wants_reference:
        try:
            reference_image_url = find_most_similar_herbarium_sheet(top_class_str, image)
        except Exception as e:
            print(f"Error in retrieval: {e}")
    return predictions, reference_image_url
# --- Gradio Interface ---
# Layout: a header, then a two-column card (model picker + upload on the left,
# predictions + matched herbarium sheet on the right), then a footer.
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    with gr.Column(elem_id="app-wrapper"):
        # NOTE(review): the inner <div> reuses this component's elem_id, so the
        # page has two elements with id="app-header" (likewise "model-help"
        # below). Check style.css before dropping one of them.
        gr.Markdown(
            """
<div id="app-header">
<h1>🌿 Plant Species Classification</h1>
<h3>AML Group Project — Group 8</h3>
</div>
""",
            elem_id="app-header",
        )
        with gr.Row(elem_id="main-card"):
            with gr.Column(scale=1):
                # These labels must match the strings predict() dispatches on.
                model_selector = gr.Dropdown(
                    label="Select model",
                    choices=[
                        "Herbarium Species Classifier (ConvNeXT)",
                        "Baseline (DINOv2 + LogReg)",
                        "SPA Ensemble (New Approach)",
                        "Future Model 2 (Placeholder)",
                    ],
                    value="SPA Ensemble (New Approach)",
                )
                gr.Markdown(
                    """
<div id="model-help">
<b>Herbarium Classifier</b> — ConvNeXtV2 CNN.<br>
<b>Baseline</b> — Simple DINOv2 + LogReg.<br>
<b>SPA Ensemble</b> — <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
</div>
""",
                    elem_id="model-help",
                )
                image_input = gr.Image(type="pil", label="Upload plant image")
                submit_button = gr.Button("Classify 🌱", variant="primary")
            with gr.Column(scale=1):
                output_label = gr.Label(label="Top 5 predictions", num_top_classes=5)
                herbarium_output = gr.Image(
                    label="Matched Herbarium Specimen (Visual Reference)",
                    show_label=True,
                    interactive=False,
                    height=300,
                )
        submit_button.click(
            fn=predict,
            inputs=[model_selector, image_input],
            outputs=[output_label, herbarium_output],
        )
        gr.Markdown("Built for the AML course — Group 8", elem_id="footer")
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()