import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import os
from pathlib import Path
from scipy import stats
from scipy.fftpack import dct
from sklearn.preprocessing import StandardScaler
import torchvision.transforms as transforms
import open_clip
import joblib
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
# Ensemble top-1 probability below which the ensemble distribution is averaged
# with the BioCLIP zero-shot distribution (see ModelManager.predict).
CONFIDENCE_THRESHOLD = 0.99
# The list directory remains in the root of the Space
LIST_DIR = Path("list")


# ==============================================================================
# 1. FEATURE EXTRACTOR
# ==============================================================================
class FeatureExtractor:
    """Hand-crafted image descriptors for the ensemble's handcrafted branch.

    ``extract_all_features`` returns a flat ``{name: float}`` dict with 40
    entries: 18 colour, 4 texture, 4 shape, 2 brightness, 10 frequency and
    2 statistical features. ModelManager.predict orders them by
    ``sorted(keys)``, so feature names must stay stable to remain compatible
    with the trained scaler and checkpoints.
    """

    @staticmethod
    def extract_color_features(img):
        """Return per-channel RGB statistics plus mean HSV components (18 features).

        Args:
            img: RGB ``PIL.Image`` (anything ``np.array`` turns into an
                ``(H, W, 3)`` uint8 array).

        Returns:
            dict mapping ``color_*`` names to floats.
        """
        img_np = np.array(img)
        features = {}
        for i, channel in enumerate(['R', 'G', 'B']):
            ch = img_np[:, :, i].flatten()
            prefix = f'color_{channel}'
            if ch.size > 0:
                skew = float(stats.skew(ch))
                # Recent SciPy returns NaN skew for a constant (zero-variance)
                # channel; NaN would propagate through the scaler and network,
                # so use 0.0 (the skew of a symmetric distribution) instead.
                if not np.isfinite(skew):
                    skew = 0.0
                features.update({
                    f'{prefix}_mean': float(np.mean(ch)),
                    f'{prefix}_std': float(np.std(ch)),
                    f'{prefix}_skew': skew,
                    f'{prefix}_min': float(np.min(ch)),
                    f'{prefix}_max': float(np.max(ch)),
                })
            else:
                features.update({
                    f'{prefix}_mean': 0.0,
                    f'{prefix}_std': 0.0,
                    f'{prefix}_skew': 0.0,
                    f'{prefix}_min': 0.0,
                    f'{prefix}_max': 0.0,
                })
            # Only five statistics per channel: the checkpoints expect exactly
            # 40 handcrafted features, so no per-channel histogram bins.
        try:
            hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
            features.update({
                'color_hue_mean': float(np.mean(hsv[:, :, 0])),
                'color_saturation_mean': float(np.mean(hsv[:, :, 1])),
                'color_value_mean': float(np.mean(hsv[:, :, 2])),
            })
        except cv2.error:
            # Best-effort: fall back to zeros if OpenCV rejects the array.
            features.update({'color_hue_mean': 0.0, 'color_saturation_mean': 0.0, 'color_value_mean': 0.0})
        return features

    @staticmethod
    def extract_texture_features(img):
        """Return edge density, gradient statistics and Laplacian variance (4 features).

        Canny/Sobel are slow on huge images; extract_all_features downsizes
        the image before calling this.
        """
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        # Gradient magnitude is used for two features; compute it once.
        magnitude = np.sqrt(gx ** 2 + gy ** 2)
        return {
            'texture_edge_density': float(np.sum(edges > 0) / edges.size) if edges.size > 0 else 0.0,
            'texture_gradient_mean': float(np.mean(magnitude)),
            'texture_gradient_std': float(np.std(magnitude)),
            'texture_laplacian_var': float(np.var(cv2.Laplacian(gray, cv2.CV_64F))),
        }

    @staticmethod
    def extract_shape_features(img):
        """Return height, width, aspect ratio (w/h, 0.0 if h == 0) and pixel area (4 features)."""
        w, h = img.size
        return {
            'shape_height': float(h),
            'shape_width': float(w),
            'shape_aspect_ratio': w / h if h > 0 else 0.0,
            'shape_area': float(w * h),
        }

    @staticmethod
    def extract_brightness_features(img):
        """Return mean and standard deviation of the grayscale image (2 features)."""
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        return {'brightness_mean': float(np.mean(gray)), 'brightness_std': float(np.std(gray))}

    @staticmethod
    def extract_frequency_features(img):
        """Return the first 10 coefficients of a 2-D orthonormal DCT (10 features).

        The grayscale image is resized to 64x64 first; the coefficients are
        the first ten entries of the row-major flattened DCT matrix, i.e. the
        first ten entries of its top row.
        """
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        gray_small = cv2.resize(gray, (64, 64))
        dct_coeffs = dct(dct(gray_small.T, norm='ortho').T, norm='ortho')
        return {f'freq_dct_{i}': float(v) for i, v in enumerate(dct_coeffs.flatten()[:10])}

    @staticmethod
    def extract_statistical_features(img):
        """Return Shannon entropy (bits) and uniformity of the grayscale histogram (2 features)."""
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        hist, _ = np.histogram(gray.flatten(), bins=256, range=(0, 256))
        hist = hist / (hist.sum() + 1e-8)
        hist_nonzero = hist[hist > 0]
        entropy = -np.sum(hist_nonzero * np.log2(hist_nonzero)) if hist_nonzero.size > 0 else 0.0
        return {'stat_entropy': float(entropy), 'stat_uniformity': float(np.sum(hist ** 2))}

    @staticmethod
    def extract_all_features(img):
        """Return all 40 hand-crafted features for ``img`` as ``{name: float}``.

        ``img`` is converted to an RGB copy. If that copy's longer side exceeds
        1024 px, it is downscaled in place with ``thumbnail``, which preserves
        the aspect ratio. The caller's image is never modified.

        NOTE(review): shape and texture features are therefore measured on
        the downscaled copy. Confirm that training used the same resize.
        """
        img = img.convert('RGB')
        # Resize before the handcrafted features to speed up Canny/Sobel.
        max_size = 1024
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        features = {}
        for extractor in (
            FeatureExtractor.extract_color_features,
            FeatureExtractor.extract_texture_features,
            FeatureExtractor.extract_shape_features,
            FeatureExtractor.extract_brightness_features,
            FeatureExtractor.extract_frequency_features,
            FeatureExtractor.extract_statistical_features,
        ):
            features.update(extractor(img))
        return features
# ==============================================================================
# 2. MODEL ARCHITECTURE
# ==============================================================================
class BioCLIP2ZeroShot:
    """Zero-shot BioCLIP scorer that yields one logit per class.

    Text prototypes are built once at construction: for each class, the text
    embeddings of several prompt templates are averaged, and the average is
    L2-normalised. Images are scored by scaled cosine similarity against
    these prototypes.
    """

    def __init__(self, device, class_to_idx, id_to_name):
        """Load BioCLIP-2 (falling back to base BioCLIP) and precompute prototypes.

        Args:
            device: ``torch.device`` for the model and prototypes.
            class_to_idx: mapping class id -> index; indices are expected to
                be exactly ``0 .. len(class_to_idx) - 1``.
            id_to_name: mapping class id (as ``str``) -> display name; ids
                without an entry are prompted with the raw id.
        """
        self.device = device
        self.num_classes = len(class_to_idx)
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.id_to_name = id_to_name
        print("Loading BioCLIP-2 model...")
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
            self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
        except Exception as e:
            # ``Exception`` rather than a bare except so Ctrl-C / SystemExit propagate.
            print("Warning: BioCLIP-2 load failed, trying base BioCLIP...")
            print(f"  (reason: {e})")
            # NOTE(review): the ensemble was presumably trained on BioCLIP-2
            # logits; base-BioCLIP logits may be distributed differently.
            self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
            self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
        self.model.to(self.device).eval()
        self.text_features_prototypes = self._precompute_text_features()

    def _precompute_text_features(self):
        """Return a ``(num_classes, dim)`` tensor of unit-norm text prototypes on ``self.device``."""
        templates = [
            "a photo of {}",
            "a herbarium specimen of {}",
            "a botanical photograph of {}",
            "{} plant species",
            "leaves and flowers of {}",
        ]
        n_templates = len(templates)
        class_ids = [self.idx_to_class[i] for i in range(self.num_classes)]
        class_names = [self.id_to_name.get(str(cid), str(cid)) for cid in class_ids]
        # Prompts are grouped by class: n_templates consecutive entries per class.
        text_inputs = [t.format(name) for name in class_names for t in templates]
        batch_size = 64
        all_emb = []
        with torch.no_grad():
            for start in range(0, len(text_inputs), batch_size):
                tokens = self.tokenizer(text_inputs[start:start + batch_size]).to(self.device)
                all_emb.append(self.model.encode_text(tokens))
        all_text_embs = torch.cat(all_emb, dim=0).cpu().numpy()
        prototypes = np.zeros((self.num_classes, all_text_embs.shape[1]), dtype=np.float32)
        for idx in range(self.num_classes):
            start = idx * n_templates
            # Template embeddings are averaged un-normalised; changing this
            # would shift the logits the ensemble consumes.
            avg = np.mean(all_text_embs[start:start + n_templates], axis=0)
            norm = np.linalg.norm(avg)
            prototypes[idx] = avg / norm if norm > 0 else avg
        return torch.from_numpy(prototypes).to(self.device)

    def predict_zero_shot_logits(self, img):
        """Return a 1-D ``np.ndarray`` of ``num_classes`` zero-shot logits for ``img``.

        Each logit is ``logit_scale * cos(image_embedding, prototype)``.
        ``logit_scale`` is the model's own (exponentiated) value, or 100.0 if
        the model has no ``logit_scale`` attribute.
        """
        processed = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(processed)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            try:
                logit_scale = self.model.logit_scale.exp()
            except AttributeError:
                logit_scale = torch.tensor(100.0).to(self.device)
            logits = logit_scale * image_features @ self.text_features_prototypes.T
        # Take the single batch row rather than ``squeeze()`` so the result
        # stays 1-D even when there is only one class.
        return logits.detach().cpu().numpy()[0]


class EnsembleClassifier(nn.Module):
    """Fusion MLP over DINOv2 features, handcrafted features and BioCLIP logits.

    Submodule names and layer order must match the saved checkpoints
    (``load_state_dict`` keys), so do not rename or reorder layers.
    """

    def __init__(self, num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=100, num_classes=100,
                 hidden_dim=512, dropout_rate=0.3, prototype_dim=768):
        super().__init__()
        self.dinov2_proj = nn.Sequential(nn.Linear(dinov2_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
        # Two linear stages (not three) to match the training checkpoint.
        self.handcraft_branch = nn.Sequential(
            nn.Linear(num_handcrafted_features, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate),
        )
        self.bioclip2_branch = nn.Sequential(
            nn.Linear(bioclip2_dim, hidden_dim // 4), nn.BatchNorm1d(hidden_dim // 4), nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
        )
        fusion_input_dim = hidden_dim + hidden_dim // 2 + hidden_dim // 4
        self.fusion = nn.Sequential(
            nn.Linear(fusion_input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate),
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.prototype_proj = nn.Linear(hidden_dim, prototype_dim)

    def forward(self, handcrafted_features, dinov2_features, bioclip2_logits):
        """Return ``(class_logits, projected_feature)`` for a batch.

        All three inputs are 2-D ``(batch, dim)`` tensors; they are
        concatenated along dim 1 after their per-branch projections.
        """
        dinov2_out = self.dinov2_proj(dinov2_features)
        handcraft_out = self.handcraft_branch(handcrafted_features)
        bioclip2_out = self.bioclip2_branch(bioclip2_logits)
        shared_features = self.fusion(torch.cat([dinov2_out, handcraft_out, bioclip2_out], dim=1))
        class_output = self.classifier(shared_features)
        projected_feature = self.prototype_proj(shared_features)
        return class_output, projected_feature


# ==============================================================================
# 3. MANAGER CLASS & EXPORTED FUNCTION
# ==============================================================================
# Width of the handcrafted feature vector the scaler and checkpoints expect.
_NUM_HANDCRAFTED_FEATURES = 40


def _stable_softmax(logits):
    """Return softmax over the last axis of ``logits`` without overflow.

    Subtracting the maximum leaves the result mathematically unchanged. It
    keeps ``np.exp`` from overflowing to ``inf`` (and the ratio to NaN) for
    large logits, such as cosine similarities multiplied by a logit scale
    of about 100.
    """
    logits = np.asarray(logits, dtype=np.float64)
    exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    return exp / np.sum(exp, axis=-1, keepdims=True)


class ModelManager:
    """Loads every SPA component once and serves top-5 predictions."""

    def __init__(self):
        """Load class lists, DINOv2, BioCLIP, the feature scaler and the ensemble.

        Downloads from the Hugging Face Hub. A missing scaler or ensemble
        member is logged and tolerated; other failures propagate.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Initializing SPA Ensemble on {self.device}...")

        # --- CONFIG: YOUR MODEL REPO ID ---
        self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"

        self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
        self.num_classes = len(self.class_to_idx)
        print(f"SPA Ensemble: Loaded {self.num_classes} classes.")

        # 1. DINOv2 backbone plus its ImageNet-normalised 224x224 preprocessing.
        print("SPA Ensemble: Loading DINOv2...")
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
        self.dinov2.to(self.device).eval()
        self.dinov2_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # 2. BioCLIP zero-shot scorer.
        self.bioclip = BioCLIP2ZeroShot(self.device, self.class_to_idx, self.id_to_name)

        # 3. Handcrafted-feature scaler.
        self.scaler = self._load_scaler()

        # 4. Ensemble members.
        self.models = self._load_ensemble()

    def _load_scaler(self):
        """Download the fitted scaler, or fall back to an identity StandardScaler."""
        print("SPA Ensemble: Downloading Scaler...")
        try:
            scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only point REPO_ID at a repository you trust.
            scaler = joblib.load(scaler_path)
            print("✓ Scaler downloaded and loaded.")
        except Exception as e:
            print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
            print("Using dummy scaler (predictions may be inaccurate).")
            # Fitting on one all-zero row gives mean 0 and (zero variance ->)
            # scale 1, i.e. an identity transform of the right width.
            scaler = StandardScaler()
            scaler.fit(np.zeros((1, _NUM_HANDCRAFTED_FEATURES)))
        return scaler

    def _load_ensemble(self):
        """Download and build the five ensemble members; failures are logged and skipped."""
        hidden_dims = [384, 448, 512, 576, 640]
        dropout_rates = [0.2, 0.25, 0.3, 0.35, 0.4]
        models = []
        print(f"SPA Ensemble: Downloading Models from {self.REPO_ID}...")
        for i, (hidden_dim, dropout_rate) in enumerate(zip(hidden_dims, dropout_rates)):
            filename = f"ensemble_model_{i}.pth"
            try:
                model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
                model = EnsembleClassifier(
                    num_handcrafted_features=_NUM_HANDCRAFTED_FEATURES,
                    dinov2_dim=1024,
                    bioclip2_dim=self.num_classes,
                    num_classes=self.num_classes,
                    hidden_dim=hidden_dim,
                    dropout_rate=dropout_rate,
                    prototype_dim=768,
                )
                # NOTE(security): torch.load is pickle-based; for untrusted
                # repositories prefer torch.load(..., weights_only=True).
                state_dict = torch.load(model_path, map_location=self.device)
                model.load_state_dict(state_dict)
                model.to(self.device).eval()
                models.append(model)
                print(f"✓ Loaded {filename}")
            except Exception as e:
                print(f"Failed to load {filename}: {e}")
        return models

    def load_class_info(self):
        """Build the class-index mapping and the id -> display-name lookup.

        Class ids are read from the first source that exists, in this order:
        ``LIST_DIR/train.txt`` (the second whitespace-separated column),
        ``LIST_DIR/species_list.txt`` (the part before the first ``;``), or
        the defaults ``"0"``..``"99"``. Ids are sorted as strings (so ``"10"``
        precedes ``"2"``) and numbered from 0. Names come from the
        ``<id>;<name>`` lines of ``species_list.txt`` when it exists.

        Returns:
            ``(class_to_idx, idx_to_class, id_to_name)``.
        """
        species_path = LIST_DIR / "species_list.txt"
        train_path = LIST_DIR / "train.txt"
        classes_set = set()
        if train_path.exists():
            with open(train_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.split()
                    if len(parts) >= 2:
                        classes_set.add(parts[1])
        elif species_path.exists():
            with open(species_path, 'r', encoding='utf-8') as f:
                for line in f:
                    class_id = line.split(";", 1)[0].strip()
                    # Skip blank lines; they would otherwise add an empty-string
                    # class and inflate num_classes beyond the checkpoints' size.
                    if class_id:
                        classes_set.add(class_id)
        else:
            classes_set = {str(i) for i in range(100)}

        class_to_idx = {cls: idx for idx, cls in enumerate(sorted(classes_set))}
        idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}

        id_to_name = {}
        if species_path.exists():
            with open(species_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if ";" in line:
                        class_id, name = line.strip().split(";", 1)
                        id_to_name[class_id.strip()] = name.strip()
        return class_to_idx, idx_to_class, id_to_name

    def predict(self, image):
        """Classify ``image`` and return its top-5 classes.

        Args:
            image: ``PIL.Image`` in any mode (converted to RGB), or ``None``.

        Returns:
            ``{"<name> (<class_id>)": probability}`` for the five highest
            scoring classes (fewer if there are fewer classes). Returns
            ``{}`` for ``None`` and ``{"Error": ...}`` if no ensemble member
            loaded.
        """
        if image is None:
            return {}
        # Fail fast: without ensemble members there is nothing to predict
        # with, so skip the expensive feature extraction.
        if not self.models:
            return {"Error": "SPA Models not loaded"}
        img_pil = image.convert("RGB")

        # 1. Handcrafted features, ordered by name as at fit time.
        hc_feats = FeatureExtractor.extract_all_features(img_pil)
        hc_vector = np.array([hc_feats[k] for k in sorted(hc_feats)]).reshape(1, -1)
        hc_tensor = torch.FloatTensor(self.scaler.transform(hc_vector)).to(self.device)

        # 2. DINOv2 embedding, L2-normalised.
        dino_input = self.dinov2_transform(img_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            dino_feats = self.dinov2(dino_input)
            dino_feats = dino_feats / (dino_feats.norm(dim=-1, keepdim=True) + 1e-8)

        # 3. BioCLIP zero-shot logits: an ensemble input and a routing fallback.
        bioclip_logits = self.bioclip.predict_zero_shot_logits(img_pil)
        bioclip_tensor = torch.FloatTensor(bioclip_logits).unsqueeze(0).to(self.device)

        # 4. Average of the members' softmax distributions.
        member_probs = []
        with torch.no_grad():
            for model in self.models:
                logits, _ = model(hc_tensor, dino_feats, bioclip_tensor)
                member_probs.append(F.softmax(logits, dim=1).cpu().numpy()[0])
        final_ens_probs = np.mean(member_probs, axis=0)

        # 5. Hybrid routing: blend in BioCLIP when the ensemble is unsure.
        bioclip_probs = _stable_softmax(bioclip_logits)
        ens_conf = np.max(final_ens_probs)
        if ens_conf < CONFIDENCE_THRESHOLD:
            final_probs = (final_ens_probs + bioclip_probs) / 2
        else:
            final_probs = final_ens_probs

        # 6. Top-k formatting.
        top_k = 5
        results = {}
        for idx in np.argsort(final_probs)[::-1][:top_k]:
            class_id = self.idx_to_class[int(idx)]
            name = self.id_to_name.get(class_id, class_id)
            results[f"{name} ({class_id})"] = float(final_probs[idx])
        return results


# Initialize Singleton (downloads and loads all models at import time).
try:
    spa_manager = ModelManager()
except Exception as e:
    print(f"CRITICAL ERROR initializing SPA: {e}")
    spa_manager = None


def predict_spa(image):
    """Classify ``image`` with the module-level ModelManager.

    Returns ``{"Error": ...}`` if the manager failed to initialise;
    otherwise returns whatever ``ModelManager.predict`` returns.
    """
    if spa_manager is None:
        return {"Error": "SPA System failed to initialize."}
    return spa_manager.predict(image)