import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import os
from pathlib import Path
from scipy import stats
from scipy.fftpack import dct
from sklearn.preprocessing import StandardScaler
import torchvision.transforms as transforms
import open_clip
import joblib
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
CONFIDENCE_THRESHOLD = 0.99
# The list directory remains in the root of the Space
LIST_DIR = Path("list")

# ==============================================================================
# 1. FEATURE EXTRACTOR
# ==============================================================================
class FeatureExtractor:
    @staticmethod
    def extract_color_features(img):
        img_np = np.array(img)
        features = {}
        for i, channel in enumerate(['R', 'G', 'B']):
            ch = img_np[:, :, i].flatten()
            if len(ch) > 0:
                features.update({
                    f'color_{channel}_mean': float(np.mean(ch)),
                    f'color_{channel}_std': float(np.std(ch)),
                    f'color_{channel}_skew': float(stats.skew(ch)),
                    f'color_{channel}_min': float(np.min(ch)),
                    f'color_{channel}_max': float(np.max(ch)),
                })
            else:
                features.update({f'color_{channel}_{s}': 0.0 for s in ('mean', 'std', 'skew', 'min', 'max')})
            # --- FIX: Removed histogram extraction (3 bins x 3 channels = 9 features)
            # to match the 40 features expected by the .pth checkpoints ---
            # hist, _ = np.histogram(ch, bins=3, range=(0, 256)); hist = hist / (hist.sum() + 1e-8)
            # for j, v in enumerate(hist): features[f'color_{channel}_hist_bin{j}'] = float(v)
        try:
            hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
            features.update({
                'color_hue_mean': float(np.mean(hsv[:, :, 0])),
                'color_saturation_mean': float(np.mean(hsv[:, :, 1])),
                'color_value_mean': float(np.mean(hsv[:, :, 2])),
            })
        except Exception:
            features.update({'color_hue_mean': 0.0, 'color_saturation_mean': 0.0, 'color_value_mean': 0.0})
        return features

    @staticmethod
    def extract_texture_features(img):
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        features = {}
        # Optimization: Canny/Sobel can be slow on huge images.
        # The image is assumed to have been resized in extract_all_features.
        edges = cv2.Canny(gray, 50, 150)
        gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        grad_mag = np.sqrt(gx**2 + gy**2)
        features.update({
            'texture_edge_density': float(np.sum(edges > 0) / edges.size) if edges.size > 0 else 0.0,
            'texture_gradient_mean': float(np.mean(grad_mag)),
            'texture_gradient_std': float(np.std(grad_mag)),
            'texture_laplacian_var': float(np.var(cv2.Laplacian(gray, cv2.CV_64F))),
        })
        return features

    @staticmethod
    def extract_shape_features(img):
        w, h = img.size
        return {'shape_height': h, 'shape_width': w,
                'shape_aspect_ratio': w / h if h > 0 else 0.0, 'shape_area': w * h}

    @staticmethod
    def extract_brightness_features(img):
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        return {'brightness_mean': float(np.mean(gray)), 'brightness_std': float(np.std(gray))}

    @staticmethod
    def extract_frequency_features(img):
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        gray_small = cv2.resize(gray, (64, 64))
        # 2D DCT built from two orthonormal 1D passes
        dct_coeffs = dct(dct(gray_small.T, norm='ortho').T, norm='ortho')
        features = {}
        # FIX: the loop must finish before returning!
        for i, v in enumerate(dct_coeffs.flatten()[:10]):
            features[f'freq_dct_{i}'] = float(v)
        return features

    @staticmethod
    def extract_statistical_features(img):
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        hist, _ = np.histogram(gray.flatten(), bins=256, range=(0, 256))
        hist = hist / (hist.sum() + 1e-8)
        hist_nonzero = hist[hist > 0]
        entropy = -np.sum(hist_nonzero * np.log2(hist_nonzero)) if hist_nonzero.size > 0 else 0.0
        return {'stat_entropy': float(entropy), 'stat_uniformity': float(np.sum(hist**2))}
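
    # Feature-count check (from the extractors above): color 3x5 stats + 3 HSV
    # means = 18, texture 4, shape 4, brightness 2, frequency 10, statistical 2,
    # for 40 handcrafted features in total, matching the
    # num_handcrafted_features=40 expected by the ensemble checkpoints.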
    @staticmethod
    def extract_all_features(img):
        img = img.convert('RGB')
        # OPTIMIZATION: resize before handcrafted extraction to speed up Canny/Sobel
        max_size = 1024
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        features = {}
        features.update(FeatureExtractor.extract_color_features(img))
        features.update(FeatureExtractor.extract_texture_features(img))
        features.update(FeatureExtractor.extract_shape_features(img))
        features.update(FeatureExtractor.extract_brightness_features(img))
        features.update(FeatureExtractor.extract_frequency_features(img))
        features.update(FeatureExtractor.extract_statistical_features(img))
        return features

# ==============================================================================
# 2. MODEL ARCHITECTURE
# ==============================================================================
class BioCLIP2ZeroShot:
    def __init__(self, device, class_to_idx, id_to_name):
        self.device = device
        self.num_classes = len(class_to_idx)
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.id_to_name = id_to_name
        print("Loading BioCLIP-2 model...")
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
            self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
        except Exception:
            print("Warning: BioCLIP-2 load failed, trying base BioCLIP...")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
            self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
        self.model.to(self.device).eval()
        self.text_features_prototypes = self._precompute_text_features()

    def _precompute_text_features(self):
        templates = ["a photo of {}", "a herbarium specimen of {}", "a botanical photograph of {}",
                     "{} plant species", "leaves and flowers of {}"]
        class_ids = [self.idx_to_class[i] for i in range(self.num_classes)]
        class_names = [self.id_to_name.get(str(cid), str(cid)) for cid in class_ids]
        # Encode every (class, template) prompt in batches, then average the
        # template embeddings per class into one L2-normalised prototype.
        text_inputs = [t.format(name) for name in class_names for t in templates]
        all_emb = []
        bs = 64
        with torch.no_grad():
            for i in range(0, len(text_inputs), bs):
                tokens = self.tokenizer(text_inputs[i:i + bs]).to(self.device)
                all_emb.append(self.model.encode_text(tokens))
        all_text_embs = torch.cat(all_emb, dim=0).cpu().numpy()
        prototypes = np.zeros((self.num_classes, all_text_embs.shape[1]), dtype=np.float32)
        for idx in range(self.num_classes):
            start = idx * len(templates)
            avg = np.mean(all_text_embs[start:start + len(templates)], axis=0)
            norm = np.linalg.norm(avg)
            prototypes[idx] = avg / norm if norm > 0 else avg
        return torch.from_numpy(prototypes).to(self.device)

    def predict_zero_shot_logits(self, img):
        processed = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(processed)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            prototypes = self.text_features_prototypes
            try:
                logit_scale = self.model.logit_scale.exp()
            except Exception:
                logit_scale = torch.tensor(100.0).to(self.device)
            # --- FIX: Added .detach() before .numpy() ---
            logits = (logit_scale * image_features @ prototypes.T).detach().cpu().numpy().squeeze()
        return logits
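
# Note on the zero-shot scoring above: it follows the standard CLIP recipe,
# taking the cosine similarity between the L2-normalised image embedding and
# each class prototype (the re-normalised mean of the five prompt-template
# embeddings), scaled by the model's learned temperature (logit_scale).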

class EnsembleClassifier(nn.Module):
    def __init__(self, num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=100,
                 num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=768):
        super().__init__()
        self.dinov2_proj = nn.Sequential(nn.Linear(dinov2_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
        # --- FIX: Removed 3rd layer to match the training checkpoint (size-mismatch error) ---
        self.handcraft_branch = nn.Sequential(
            nn.Linear(num_handcrafted_features, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate),
        )
        self.bioclip2_branch = nn.Sequential(
            nn.Linear(bioclip2_dim, hidden_dim // 4), nn.BatchNorm1d(hidden_dim // 4), nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
        )
        fusion_input_dim = hidden_dim + hidden_dim // 2 + hidden_dim // 4
        self.fusion = nn.Sequential(
            nn.Linear(fusion_input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate),
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.prototype_proj = nn.Linear(hidden_dim, prototype_dim)

    def forward(self, handcrafted_features, dinov2_features, bioclip2_logits):
        dinov2_out = self.dinov2_proj(dinov2_features)
        handcraft_out = self.handcraft_branch(handcrafted_features)
        bioclip2_out = self.bioclip2_branch(bioclip2_logits)
        shared_features = self.fusion(torch.cat([dinov2_out, handcraft_out, bioclip2_out], dim=1))
        class_output = self.classifier(shared_features)
        projected_feature = self.prototype_proj(shared_features)
        return class_output, projected_feature
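
# Width check for the default hidden_dim=512: the DINOv2 branch emits 512 dims,
# the handcrafted branch 256 (hidden_dim // 2) and the BioCLIP branch 128
# (hidden_dim // 4), giving fusion_input_dim = 896. The ensemble members built
# below vary hidden_dim from 384 to 640, scaling these widths with it.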

# ==============================================================================
# 3. MANAGER CLASS & EXPORTED FUNCTION
# ==============================================================================
class ModelManager:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Initializing SPA Ensemble on {self.device}...")
        # --- CONFIG: the model repo ID ---
        self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"
        self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
        self.num_classes = len(self.class_to_idx)
        print(f"SPA Ensemble: Loaded {self.num_classes} classes.")
        # 1. Load DINOv2
        print("SPA Ensemble: Loading DINOv2...")
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
        self.dinov2.to(self.device).eval()
        self.dinov2_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # 2. Load BioCLIP
        self.bioclip = BioCLIP2ZeroShot(self.device, self.class_to_idx, self.id_to_name)
        # 3. Download & load the feature scaler
        print("SPA Ensemble: Downloading Scaler...")
        try:
            # Fetch scaler.joblib from the model repo
            scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
            self.scaler = joblib.load(scaler_path)
            print("✓ Scaler downloaded and loaded.")
        except Exception as e:
            print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
            print("Using dummy scaler (predictions may be inaccurate).")
            self.scaler = StandardScaler()
            # FIX: fit on 40 zeros instead of 49 to match the feature reduction
            self.scaler.fit(np.zeros((1, 40)))
        # 4. Download & load the ensemble models
        self.models = []
        hidden_dims = [384, 448, 512, 576, 640]
        dropout_rates = [0.2, 0.25, 0.3, 0.35, 0.4]
        print(f"SPA Ensemble: Downloading Models from {self.REPO_ID}...")
        for i in range(5):
            filename = f"ensemble_model_{i}.pth"
            try:
                model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
                # FIX: pass num_handcrafted_features=40 and prototype_dim=768 to match the weights
                model = EnsembleClassifier(
                    num_handcrafted_features=40,
                    dinov2_dim=1024,
                    bioclip2_dim=self.num_classes,
                    num_classes=self.num_classes,
                    hidden_dim=hidden_dims[i],
                    dropout_rate=dropout_rates[i],
                    prototype_dim=768,
                )
                state_dict = torch.load(model_path, map_location=self.device)
                model.load_state_dict(state_dict)
                model.to(self.device).eval()
                self.models.append(model)
                print(f"✓ Loaded {filename}")
            except Exception as e:
                print(f"Failed to load {filename}: {e}")

    def load_class_info(self):
        id_to_name = {}
        species_path = LIST_DIR / "species_list.txt"
        train_path = LIST_DIR / "train.txt"
        classes_set = set()
        if train_path.exists():
            with open(train_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 2:
                        classes_set.add(parts[1])
        elif species_path.exists():
            with open(species_path, 'r') as f:
                for line in f:
                    parts = line.strip().split(";", 1)
                    classes_set.add(parts[0].strip())
        else:
            classes_set = {str(i) for i in range(100)}
        sorted_classes = sorted(classes_set)
        class_to_idx = {cls: idx for idx, cls in enumerate(sorted_classes)}
        idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
        if species_path.exists():
            with open(species_path, 'r') as f:
                for line in f:
                    if ";" in line:
                        parts = line.strip().split(";", 1)
                        id_to_name[parts[0].strip()] = parts[1].strip()
        return class_to_idx, idx_to_class, id_to_name
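
    # Expected list-file formats, as inferred from the parsing above:
    #   list/train.txt        -> "<image_path> <class_id> ..." (whitespace-separated)
    #   list/species_list.txt -> "<class_id>;<species name>"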

    def predict(self, image):
        if image is None:
            return {}
        img_pil = image.convert("RGB")
        # 1. Handcrafted features (keys sorted for a deterministic order,
        # assumed to match the order the scaler was fit on)
        hc_feats = FeatureExtractor.extract_all_features(img_pil)
        hc_vector = np.array([hc_feats[k] for k in sorted(hc_feats.keys())]).reshape(1, -1)
        hc_vector = self.scaler.transform(hc_vector)
        hc_tensor = torch.FloatTensor(hc_vector).to(self.device)
        # 2. DINOv2 features
        dino_input = self.dinov2_transform(img_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            dino_feats = self.dinov2(dino_input)
            dino_feats = dino_feats / (dino_feats.norm(dim=-1, keepdim=True) + 1e-8)
        # 3. BioCLIP zero-shot logits
        bioclip_logits = self.bioclip.predict_zero_shot_logits(img_pil)
        bioclip_tensor = torch.FloatTensor(bioclip_logits).unsqueeze(0).to(self.device)
        # 4. Ensemble prediction (average softmax over the loaded heads)
        if not self.models:
            return {"Error": "SPA Models not loaded"}
        all_probs = []
        for model in self.models:
            with torch.no_grad():
                logits, _ = model(hc_tensor, dino_feats, bioclip_tensor)
                probs = F.softmax(logits, dim=1).cpu().numpy()[0]
            all_probs.append(probs)
        final_ens_probs = np.mean(all_probs, axis=0)
        # 5. Hybrid routing
        # FIX: subtract the max logit before exponentiating for a numerically stable softmax
        exp_logits = np.exp(bioclip_logits - np.max(bioclip_logits))
        bioclip_probs = exp_logits / np.sum(exp_logits)
        ens_pred_idx = np.argmax(final_ens_probs)
        ens_conf = final_ens_probs[ens_pred_idx]
        if ens_conf < CONFIDENCE_THRESHOLD:
            final_probs = (final_ens_probs + bioclip_probs) / 2
        else:
            final_probs = final_ens_probs
        # 6. Format the top-k results
        top_k = 5
        top_indices = np.argsort(final_probs)[::-1][:top_k]
        results = {}
        for idx in top_indices:
            class_id = self.idx_to_class[idx]
            name = self.id_to_name.get(class_id, class_id)
            results[f"{name} ({class_id})"] = float(final_probs[idx])
        return results

# Initialize singleton
try:
    spa_manager = ModelManager()
except Exception as e:
    print(f"CRITICAL ERROR initializing SPA: {e}")
    spa_manager = None

def predict_spa(image):
    if spa_manager is None:
        return {"Error": "SPA System failed to initialize."}
    return spa_manager.predict(image)
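
# ------------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the Space): a local smoke test against
# a single image. "test_leaf.jpg" is an assumed sample path; swap in any RGB
# image. In the deployed Space, predict_spa would instead be wired to the UI
# (e.g. a Gradio Interface) rather than called from __main__.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    test_path = Path("test_leaf.jpg")  # hypothetical sample image
    if test_path.exists():
        predictions = predict_spa(Image.open(test_path))
        for label, score in predictions.items():
            print(f"{label}: {score:.4f}")
    else:
        print(f"No sample image at {test_path}; skipping smoke test.")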