import gradio as gr
import os
import re
import pickle
import torch
import requests
from torchvision import transforms
from huggingface_hub import list_repo_files, hf_hub_download

# --- CONFIGURATION ---
# 1. Dataset Config
DATASET_ID = "FrAnKu34t23/Herbarium_Field"
DATASET_URL_BASE = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/train/herbarium/"
SPECIES_LIST_URL = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/list/species_list.txt"

# 2. Model Repo Config
MODEL_REPO_ID = "FrAnKu34t23/ensemble_models_plant"
INDEX_FILENAME = "herbarium_index.pkl"

# Global Variables
REFERENCE_IMAGE_MAP = {}   # Fallback (Class ID -> Image Filename)
NAME_TO_ID_MAP = {}        # Lookup (Species Name -> Class ID)
VECTOR_INDEX = None        # Smart Search Index (Class ID -> list of {"filename", "vector"} entries)
FEATURE_EXTRACTOR = None   # DINOv2 model
TRANSFORM = None           # Image transforms


# --- SETUP: Load Resources ---
def load_resources():
    global VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM
    print("🚀 App starting... Initializing resources.")

    # 1. Load Name-to-ID Map (Crucial if models output only names)
    load_species_mapping()

    # 2. Download and Load Visual Search Index
    try:
        print(f"⬇️ Downloading {INDEX_FILENAME} from {MODEL_REPO_ID}...")
        index_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=INDEX_FILENAME,
            repo_type="model"
        )
        print("✅ Downloaded index. Loading pickle...")
        with open(index_path, "rb") as f:
            VECTOR_INDEX = pickle.load(f)

        # Load DINOv2
        print("⬇️ Loading DINOv2 (Retrieval Engine)...")
        FEATURE_EXTRACTOR = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        FEATURE_EXTRACTOR.eval()

        TRANSFORM = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        print("🚀 Smart Search Ready!")
    except Exception as e:
        print(f"⚠️ Smart Search initialization failed: {e}")
        VECTOR_INDEX = None

    # 3. Build Fallback Map
    build_fallback_map()


def load_species_mapping():
    global NAME_TO_ID_MAP
    print("⬇️ Downloading species_list.txt for Name mapping...")
    try:
        # Fetch the text file from the dataset
        response = requests.get(SPECIES_LIST_URL)
        if response.status_code == 200:
            lines = response.text.splitlines()
            count = 0
            for line in lines:
                # Assuming format: "ClassID;SpeciesName" (semicolon, tab, or comma separated)
                # Adjust splitting based on your actual file format
                parts = re.split(r'[;\t,]', line)
                if len(parts) >= 2:
                    # Try to identify which part is the ID (digits) and which is the Name
                    part1 = parts[0].strip()
                    part2 = parts[1].strip()
                    if part1.isdigit():
                        c_id, c_name = part1, part2
                    else:
                        c_id, c_name = part2, part1  # Swap if ID is second

                    # Store mapping: Name -> ID
                    # Normalize name (lowercase) for easier matching
                    NAME_TO_ID_MAP[c_name.lower()] = c_id
                    count += 1
            print(f"✅ Loaded {count} species names into mapping.")
        else:
            print(f"⚠️ Failed to download species list. Status: {response.status_code}")
    except Exception as e:
        print(f"⚠️ Error loading species list: {e}")


def build_fallback_map():
    global REFERENCE_IMAGE_MAP
    try:
        print(f"🔄 Scanning dataset {DATASET_ID} for fallback map...")
        all_files = list_repo_files(repo_id=DATASET_ID, repo_type="dataset")

        # Look for images in: train/herbarium/{class_id}/{filename}
        image_files = [
            f for f in all_files
            if f.startswith("train/herbarium/") and f.lower().endswith(('.jpg', '.png'))
        ]

        for file_path in image_files:
            parts = file_path.split("/")
            if len(parts) >= 4:
                class_id = parts[2]
                filename = parts[3]
                if class_id not in REFERENCE_IMAGE_MAP:
                    REFERENCE_IMAGE_MAP[class_id] = filename

        print(f"✅ Fallback map built for {len(REFERENCE_IMAGE_MAP)} classes.")
    except Exception as e:
        print(f"⚠️ Error scanning dataset: {e}")


# Load resources on startup
load_resources()


# --- Logic: ID Extraction & Search ---
def get_class_id_from_prediction(class_prediction):
    """
    Extracts Class ID from various formats, including pure Name lookups.
    """
    if not class_prediction:
        return None

    prediction_str = str(class_prediction).strip()

    # 1. Check for explicit ID in string (e.g. "Name (12345)")
    match = re.search(r'\((\d+)\)', prediction_str)
    if match:
        return match.group(1)

    # 2. Check if the string IS the ID (e.g. "12345")
    if prediction_str.isdigit():
        return prediction_str

    # 3. Check for "ID - Name" format
    match_start = re.match(r'^(\d+)\s+', prediction_str)
    if match_start:
        return match_start.group(1)

    # 4. NAME LOOKUP (New Feature)
    # If no numbers found, assume it's a name and look it up
    clean_name = prediction_str.lower().strip()
    if clean_name in NAME_TO_ID_MAP:
        return NAME_TO_ID_MAP[clean_name]

    return None


def find_most_similar_herbarium_sheet(class_prediction, input_pil_image):
    class_id = get_class_id_from_prediction(class_prediction)
    if not class_id:
        print(f"⚠️ Could not resolve Class ID for: '{class_prediction}'")
        return None

    # Strategy A: Visual Similarity (Vectors)
    if VECTOR_INDEX and FEATURE_EXTRACTOR and input_pil_image and class_id in VECTOR_INDEX:
        try:
            img_tensor = TRANSFORM(input_pil_image).unsqueeze(0)
            with torch.no_grad():
                input_vec = FEATURE_EXTRACTOR(img_tensor)
                input_vec = torch.nn.functional.normalize(input_vec, p=2, dim=1)

            candidates = VECTOR_INDEX[class_id]
            best_score = -1.0
            best_filename = None

            for item in candidates:
                score = torch.mm(input_vec, item["vector"].T).item()
                if score > best_score:
                    best_score = score
                    best_filename = item["filename"]

            if best_filename:
                return f"{DATASET_URL_BASE}{class_id}/{best_filename}"
        except Exception as e:
            print(f"⚠️ Search failed: {e}")

    # Strategy B: Fallback
    filename = REFERENCE_IMAGE_MAP.get(class_id)
    if filename:
        return f"{DATASET_URL_BASE}{class_id}/{filename}"

    return None


# --- Import User Models ---
try:
    from baseline.baseline_convnext import predict_convnext
except ImportError:
    def predict_convnext(image):
        return {"Error: ConvNeXt missing": 0.0}

try:
    from baseline.baseline_infer import predict_baseline
except ImportError:
    def predict_baseline(image):
        return {"Error: Baseline missing": 0.0}

try:
    from new_approach.spa_ensemble import predict_spa
except ImportError:
    def predict_spa(image):
        return {"Error: SPA missing": 0.0}


def predict_placeholder_2(image):
    return {"Model 4 Not Available": 0.0}


# --- Main App Logic ---
def predict(model_choice, image):
    if image is None:
        return None, None

    # STEP 1: CLASSIFICATION
    predictions = {}
    if model_choice == "Herbarium Species Classifier (ConvNeXT)":
        predictions = predict_convnext(image)
    elif model_choice == "Baseline (DINOv2 + LogReg)":
        predictions = predict_baseline(image)
    elif model_choice == "SPA Ensemble (New Approach)":
        predictions = predict_spa(image)
    elif model_choice == "Future Model 2 (Placeholder)":
        predictions = predict_placeholder_2(image)
    else:
        predictions = {"Invalid model": 0.0}

    # Handle case where model returns a String instead of Dict
    top_class_str = None
    if isinstance(predictions, dict) and predictions:
        top_class_str = max(predictions, key=predictions.get)
    elif isinstance(predictions, str):
        top_class_str = predictions

    # STEP 2: RETRIEVAL
    reference_image_url = None
    if top_class_str and "Error" not in top_class_str and "Please" not in top_class_str:
        try:
            reference_image_url = find_most_similar_herbarium_sheet(top_class_str, image)
        except Exception as e:
            print(f"Error in retrieval: {e}")

    return predictions, reference_image_url


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    with gr.Column(elem_id="app-wrapper"):
        gr.Markdown(
            """