Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import re | |
| import pickle | |
| import torch | |
| import requests | |
| from torchvision import transforms | |
| from huggingface_hub import list_repo_files, hf_hub_download | |
# --- CONFIGURATION ---
# Dataset repo: herbarium reference sheets plus the species list.
DATASET_ID = "FrAnKu34t23/Herbarium_Field"
DATASET_URL_BASE = "https://huggingface.co/datasets/{}/resolve/main/train/herbarium/".format(DATASET_ID)
SPECIES_LIST_URL = "https://huggingface.co/datasets/{}/resolve/main/list/species_list.txt".format(DATASET_ID)

# Model repo: hosts the pickled visual-search index.
MODEL_REPO_ID = "FrAnKu34t23/ensemble_models_plant"
INDEX_FILENAME = "herbarium_index.pkl"

# --- Mutable module state, filled in by load_resources() at startup ---
# class ID -> one reference image path (used when visual search is unavailable)
REFERENCE_IMAGE_MAP = {}
# lower-cased species name -> class ID
NAME_TO_ID_MAP = {}
# visual-search index, DINOv2 feature extractor, and its preprocessing pipeline
VECTOR_INDEX = None
FEATURE_EXTRACTOR = None
TRANSFORM = None
# --- SETUP: Load Resources ---
def load_resources():
    """Initialise all module-level lookup state; called once at startup.

    Steps:
      1. ``load_species_mapping()`` fills NAME_TO_ID_MAP.
      2. Smart search: downloads INDEX_FILENAME from MODEL_REPO_ID and
         unpickles it, loads the ``dinov2_vits14`` model via ``torch.hub``,
         and builds the image preprocessing transform. The three results are
         assigned to VECTOR_INDEX, FEATURE_EXTRACTOR and TRANSFORM together,
         only after every step succeeded; on any failure all three are set to
         None, so callers never see a half-initialised search engine.
      3. ``build_fallback_map()`` fills REFERENCE_IMAGE_MAP.

    Errors in step 2 are logged, not raised.
    """
    global VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM
    print("π App starting... Initializing resources.")

    # 1. Load Name-to-ID Map (needed when a model outputs only species names)
    load_species_mapping()

    # 2. Download and load the visual search index and retrieval model.
    try:
        print(f"β¬οΈ Downloading {INDEX_FILENAME} from {MODEL_REPO_ID}...")
        index_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=INDEX_FILENAME,
            repo_type="model",
        )
        print("β Downloaded index. Loading pickle...")
        # SECURITY: unpickling can execute arbitrary code. This is only safe
        # while MODEL_REPO_ID points at a repository you control and trust.
        with open(index_path, "rb") as f:
            index = pickle.load(f)

        print("β¬οΈ Loading DINOv2 (Retrieval Engine)...")
        extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        extractor.eval()
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    except Exception as e:
        print(f"β οΈ Smart Search initialization failed: {e}")
        VECTOR_INDEX = FEATURE_EXTRACTOR = TRANSFORM = None
    else:
        # Publish all three at once: the engine is either fully ready or off.
        VECTOR_INDEX, FEATURE_EXTRACTOR, TRANSFORM = index, extractor, transform
        print("π Smart Search Ready!")

    # 3. Build Fallback Map
    build_fallback_map()
def _parse_species_line(line):
    """Split one species-list line into ``(class_id, species_name)``, or return None.

    Fields are separated by ';', tab or ','. Whichever of the first two
    fields is purely numeric is taken as the class ID. Lines with fewer than
    two fields, or where neither field is numeric (e.g. a header row), yield
    None.
    """
    parts = re.split(r'[;\t,]', line)
    if len(parts) < 2:
        return None
    first, second = parts[0].strip(), parts[1].strip()
    if first.isdigit():
        return first, second
    if second.isdigit():
        return second, first
    return None


def load_species_mapping():
    """Populate NAME_TO_ID_MAP (lower-cased species name -> class ID).

    Downloads SPECIES_LIST_URL and parses each line with
    ``_parse_species_line``. Unparseable lines are skipped. Download or
    parse failures are logged rather than raised, so app startup continues
    without name lookup.
    """
    print("β¬οΈ Downloading species_list.txt for Name mapping...")
    try:
        # Without a timeout a stalled connection would block startup forever.
        response = requests.get(SPECIES_LIST_URL, timeout=30)
        if response.status_code == 200:
            # A leading UTF-8 BOM would make the first ID fail isdigit().
            lines = response.text.lstrip("\ufeff").splitlines()
            count = 0
            for line in lines:
                parsed = _parse_species_line(line)
                if parsed is None:
                    continue
                c_id, c_name = parsed
                # Lower-case the name so lookups are case-insensitive.
                NAME_TO_ID_MAP[c_name.lower()] = c_id
                count += 1
            print(f"β Loaded {count} species names into mapping.")
        else:
            print(f"β οΈ Failed to download species list. Status: {response.status_code}")
    except Exception as e:
        print(f"β οΈ Error loading species list: {e}")
def build_fallback_map():
    """Record one reference image per class ID in REFERENCE_IMAGE_MAP.

    Lists every file in the DATASET_ID dataset repo and, for ``.jpg``,
    ``.jpeg`` and ``.png`` files under ``train/herbarium/{class_id}/``, keeps
    the first path seen for each class. The stored value is the path
    relative to the class folder, so ``f"{DATASET_URL_BASE}{class_id}/{value}"``
    is a valid download URL even for files nested in sub-folders.
    Errors are logged, not raised.
    """
    prefix = "train/herbarium/"
    image_exts = ('.jpg', '.jpeg', '.png')
    try:
        print(f"π Scanning dataset {DATASET_ID} for fallback map...")
        all_files = list_repo_files(repo_id=DATASET_ID, repo_type="dataset")
        for file_path in all_files:
            if not (file_path.startswith(prefix) and file_path.lower().endswith(image_exts)):
                continue
            # Expected layout: train/herbarium/{class_id}/{relative path...}
            parts = file_path.split("/")
            if len(parts) < 4:
                continue
            class_id = parts[2]
            # Keep the whole remainder, not just parts[3], so deeper files
            # still map to a URL that exists.
            relative_path = "/".join(parts[3:])
            REFERENCE_IMAGE_MAP.setdefault(class_id, relative_path)
        print(f"β Fallback map built for {len(REFERENCE_IMAGE_MAP)} classes.")
    except Exception as e:
        print(f"β οΈ Error scanning dataset: {e}")
# Load resources on startup. This call runs at import time, so the species-name
# map, the visual-search globals and the fallback image map are populated
# (best effort; failures are logged inside the loaders) before the Gradio UI
# below is built.
load_resources()
# --- Logic: ID Extraction & Search ---
def get_class_id_from_prediction(class_prediction):
    """Resolve a classifier label to a class-ID string, or return None.

    Strategies are tried in order and the first hit wins:
      1. a number in parentheses anywhere, e.g. "Rosa canina (12345)";
      2. the whole (stripped) label being digits, e.g. "12345";
      3. a leading number followed by whitespace, e.g. "12345 - Rosa canina";
      4. a case-insensitive exact lookup of the label in NAME_TO_ID_MAP.
    Falsy input returns None immediately.
    """
    if not class_prediction:
        return None
    label = str(class_prediction).strip()

    in_parens = re.search(r'\((\d+)\)', label)
    if in_parens:
        return in_parens.group(1)

    if label.isdigit():
        return label

    leading_number = re.match(r'^(\d+)\s+', label)
    if leading_number:
        return leading_number.group(1)

    # No numeric ID anywhere: treat the label as a species name.
    return NAME_TO_ID_MAP.get(label.lower())
def find_most_similar_herbarium_sheet(class_prediction, input_pil_image):
    """Return the URL of a herbarium sheet for the predicted class, or None.

    The prediction is first resolved with ``get_class_id_from_prediction``.
    If that fails, None is returned.

    Strategy A (visual similarity) runs when all of these hold:
      - VECTOR_INDEX, FEATURE_EXTRACTOR and TRANSFORM are loaded;
      - an image was supplied;
      - VECTOR_INDEX has candidates for the class.
    The image is converted to RGB, embedded, L2-normalised, and scored
    against every candidate with a matrix product. The best-scoring
    candidate's file is returned.

    Strategy B (fallback) is used otherwise, or if A raises or finds
    nothing: the class's entry in REFERENCE_IMAGE_MAP.

    Args:
        class_prediction: Top label produced by a classifier.
        input_pil_image: Query image (PIL), or None to skip visual search.

    Returns:
        ``f"{DATASET_URL_BASE}{class_id}/{filename}"``, or None.
    """
    class_id = get_class_id_from_prediction(class_prediction)
    if not class_id:
        print(f"β οΈ Could not resolve Class ID for: '{class_prediction}'")
        return None

    # Strategy A: Visual Similarity (Vectors)
    engine_ready = bool(VECTOR_INDEX) and FEATURE_EXTRACTOR is not None and TRANSFORM is not None
    if engine_ready and input_pil_image is not None and class_id in VECTOR_INDEX:
        try:
            # TRANSFORM normalises exactly three channels. RGBA or greyscale
            # uploads must be converted first or Normalize raises.
            img_tensor = TRANSFORM(input_pil_image.convert("RGB")).unsqueeze(0)
            with torch.no_grad():
                input_vec = FEATURE_EXTRACTOR(img_tensor)
                input_vec = torch.nn.functional.normalize(input_vec, p=2, dim=1)
            best_score = -1.0
            best_filename = None
            # Assumes each item["vector"] is a (1, D) tensor, presumably
            # normalised when the index was built — TODO confirm.
            for item in VECTOR_INDEX[class_id]:
                score = torch.mm(input_vec, item["vector"].T).item()
                if score > best_score:
                    best_score = score
                    best_filename = item["filename"]
            if best_filename:
                return f"{DATASET_URL_BASE}{class_id}/{best_filename}"
        except Exception as e:
            print(f"β οΈ Search failed: {e}")

    # Strategy B: Fallback
    filename = REFERENCE_IMAGE_MAP.get(class_id)
    if filename:
        return f"{DATASET_URL_BASE}{class_id}/{filename}"
    return None
# --- Import User Models ---
def _missing_model_stub(label, error):
    """Build a stand-in predictor for a model module that failed to import.

    The import error is logged so a broken deployment is visible, not silent.
    The returned function has the same ``(image)`` signature as the real
    predictors and always returns ``{label: 0.0}``.
    """
    print(f"Warning: model import failed, using placeholder '{label}': {error}")

    def _stub(image):
        return {label: 0.0}

    return _stub


try:
    from baseline.baseline_convnext import predict_convnext
except ImportError as err:
    predict_convnext = _missing_model_stub("Error: ConvNeXt missing", err)

try:
    from baseline.baseline_infer import predict_baseline
except ImportError as err:
    predict_baseline = _missing_model_stub("Error: Baseline missing", err)

try:
    from new_approach.spa_ensemble import predict_spa
except ImportError as err:
    predict_spa = _missing_model_stub("Error: SPA missing", err)


def predict_placeholder_2(image):
    """Placeholder for a fourth model that has not been added yet."""
    return {"Model 4 Not Available": 0.0}
# --- Main App Logic ---
def predict(model_choice, image):
    """Classify *image* with the selected model and fetch a matching herbarium sheet.

    Args:
        model_choice: One of the dropdown labels shown in the UI. Unknown
            labels yield ``{"Invalid model": 0.0}``.
        image: The uploaded image, or None when nothing was uploaded.

    Returns:
        ``(predictions, reference_image_url)``. ``predictions`` is whatever
        the model returned (a label -> score dict or a bare string).
        ``reference_image_url`` is None when no sheet was found or retrieval
        was skipped. Returns ``(None, None)`` when *image* is None.
    """
    if image is None:
        return None, None

    # STEP 1: CLASSIFICATION
    if model_choice == "Herbarium Species Classifier (ConvNeXT)":
        predictions = predict_convnext(image)
    elif model_choice == "Baseline (DINOv2 + LogReg)":
        predictions = predict_baseline(image)
    elif model_choice == "SPA Ensemble (New Approach)":
        predictions = predict_spa(image)
    elif model_choice == "Future Model 2 (Placeholder)":
        predictions = predict_placeholder_2(image)
    else:
        predictions = {"Invalid model": 0.0}

    # Pick the top label. Models may return a dict or a bare string. str()
    # guards against non-string dict keys (e.g. integer class IDs), which
    # would otherwise make the substring checks below raise TypeError.
    top_class_str = None
    if isinstance(predictions, dict) and predictions:
        top_class_str = str(max(predictions, key=predictions.get))
    elif isinstance(predictions, str):
        top_class_str = predictions

    # STEP 2: RETRIEVAL. Skipped for status labels such as the import stubs'
    # "Error: ... missing" (and, presumably, "Please ..." prompts from models).
    reference_image_url = None
    if top_class_str and "Error" not in top_class_str and "Please" not in top_class_str:
        try:
            reference_image_url = find_most_similar_herbarium_sheet(top_class_str, image)
        except Exception as e:
            print(f"Error in retrieval: {e}")

    return predictions, reference_image_url
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    with gr.Column(elem_id="app-wrapper"):
        # elem_id already gives this component's wrapper id="app-header". The
        # inner <div> must not repeat it: duplicate DOM ids are invalid HTML,
        # and any #app-header CSS rule would match two nested elements.
        gr.Markdown(
            """
            <div>
            <h1>πΏ Plant Species Classification</h1>
            <h3>AML Group Project β Group 8</h3>
            </div>
            """,
            elem_id="app-header",
        )
        with gr.Row(elem_id="main-card"):
            # Left column: model choice and image upload.
            with gr.Column(scale=1):
                model_selector = gr.Dropdown(
                    label="Select model",
                    choices=[
                        "Herbarium Species Classifier (ConvNeXT)",
                        "Baseline (DINOv2 + LogReg)",
                        "SPA Ensemble (New Approach)",
                        "Future Model 2 (Placeholder)",
                    ],
                    value="SPA Ensemble (New Approach)",
                )
                # As with the header, the id comes from elem_id only.
                gr.Markdown(
                    """
                    <div>
                    <b>Herbarium Classifier</b> β ConvNeXtV2 CNN.<br>
                    <b>Baseline</b> β Simple DINOv2 + LogReg.<br>
                    <b>SPA Ensemble</b> β <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
                    </div>
                    """,
                    elem_id="model-help",
                )
                image_input = gr.Image(type="pil", label="Upload plant image")
                submit_button = gr.Button("Classify π±", variant="primary")
            # Right column: classifier output and the retrieved reference sheet.
            with gr.Column(scale=1):
                output_label = gr.Label(label="Top 5 predictions", num_top_classes=5)
                herbarium_output = gr.Image(
                    label="Matched Herbarium Specimen (Visual Reference)",
                    show_label=True,
                    interactive=False,
                    height=300,
                )
        # Event listeners must be registered inside the Blocks context.
        submit_button.click(
            fn=predict,
            inputs=[model_selector, image_input],
            outputs=[output_label, herbarium_output],
        )
        gr.Markdown("Built for the AML course β Group 8", elem_id="footer")

if __name__ == "__main__":
    demo.launch()