FrAnKu34t23 committed
Commit e7bed33 · verified · 1 Parent(s): 2423cc3

Upload spa_ensemble.py

Files changed (1):
  1. new_approach/spa_ensemble.py +351 -0
new_approach/spa_ensemble.py ADDED
@@ -0,0 +1,351 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import os
+ from pathlib import Path
+ from scipy import stats
+ from scipy.fftpack import dct
+ from sklearn.preprocessing import StandardScaler
+ import torchvision.transforms as transforms
+ import open_clip
+ import joblib
+ from huggingface_hub import hf_hub_download
+
+ # --- CONFIGURATION ---
+ # Below this ensemble-confidence threshold, predictions are blended with the
+ # BioCLIP zero-shot distribution (see ModelManager.predict).
+ CONFIDENCE_THRESHOLD = 0.99
+ # The list directory remains in the root of the Space.
+ LIST_DIR = Path("list")
+
+ # ==============================================================================
+ # 1. FEATURE EXTRACTOR
+ # ==============================================================================
+ class FeatureExtractor:
+     @staticmethod
+     def extract_color_features(img):
+         img_np = np.array(img)
+         features = {}
+         for i, channel in enumerate(['R', 'G', 'B']):
+             ch = img_np[:, :, i].flatten()
+             if len(ch) > 0:
+                 features.update({
+                     f'color_{channel}_mean': float(np.mean(ch)),
+                     f'color_{channel}_std': float(np.std(ch)),
+                     f'color_{channel}_skew': float(stats.skew(ch)),
+                     f'color_{channel}_min': float(np.min(ch)),
+                     f'color_{channel}_max': float(np.max(ch)),
+                 })
+             else:
+                 features.update({
+                     f'color_{channel}_mean': 0.0,
+                     f'color_{channel}_std': 0.0,
+                     f'color_{channel}_skew': 0.0,
+                     f'color_{channel}_min': 0.0,
+                     f'color_{channel}_max': 0.0,
+                 })
+
+         # FIX: Histogram extraction (9 features) removed to match the 40
+         # features expected by the .pth checkpoints.
+         # hist, _ = np.histogram(ch, bins=3, range=(0, 256)); hist = hist / (hist.sum() + 1e-8)
+         # for j, v in enumerate(hist): features[f'color_{channel}_hist_bin{j}'] = float(v)
+
+         try:
+             hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
+             features.update({
+                 'color_hue_mean': float(np.mean(hsv[:, :, 0])),
+                 'color_saturation_mean': float(np.mean(hsv[:, :, 1])),
+                 'color_value_mean': float(np.mean(hsv[:, :, 2])),
+             })
+         except Exception:
+             features.update({'color_hue_mean': 0.0, 'color_saturation_mean': 0.0, 'color_value_mean': 0.0})
+         return features
+
+     @staticmethod
+     def extract_texture_features(img):
+         img_np = np.array(img)
+         gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
+         features = {}
+         # Optimization: Canny/Sobel can be slow on huge images; the image is
+         # assumed to have been resized in extract_all_features.
+         edges = cv2.Canny(gray, 50, 150)
+         gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+         gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+         grad_mag = np.sqrt(gx**2 + gy**2)
+         features.update({
+             'texture_edge_density': float(np.sum(edges > 0) / edges.size) if edges.size > 0 else 0.0,
+             'texture_gradient_mean': float(np.mean(grad_mag)),
+             'texture_gradient_std': float(np.std(grad_mag)),
+             'texture_laplacian_var': float(np.var(cv2.Laplacian(gray, cv2.CV_64F)))
+         })
+         return features
+
+     @staticmethod
+     def extract_shape_features(img):
+         # Note: called after the thumbnail resize in extract_all_features, so
+         # these reflect the resized working image, not the original file size.
+         w, h = img.size
+         return {
+             'shape_height': h,
+             'shape_width': w,
+             'shape_aspect_ratio': w / h if h > 0 else 0.0,
+             'shape_area': w * h,
+         }
+
+     @staticmethod
+     def extract_brightness_features(img):
+         gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
+         return {'brightness_mean': float(np.mean(gray)), 'brightness_std': float(np.std(gray))}
+
+     @staticmethod
+     def extract_frequency_features(img):
+         img_np = np.array(img)
+         gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
+         gray_small = cv2.resize(gray, (64, 64))
+         # 2D DCT via two orthonormal 1D passes.
+         dct_coeffs = dct(dct(gray_small.T, norm='ortho').T, norm='ortho')
+         features = {}
+         # FIX: the loop must finish before returning.
+         for i, v in enumerate(dct_coeffs.flatten()[:10]):
+             features[f'freq_dct_{i}'] = float(v)
+         return features
+
+     @staticmethod
+     def extract_statistical_features(img):
+         gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
+         hist, _ = np.histogram(gray.flatten(), bins=256, range=(0, 256))
+         hist = hist / (hist.sum() + 1e-8)
+         hist_nonzero = hist[hist > 0]
+         entropy = -np.sum(hist_nonzero * np.log2(hist_nonzero)) if hist_nonzero.size > 0 else 0.0
+         return {'stat_entropy': float(entropy), 'stat_uniformity': float(np.sum(hist**2))}
+
+     @staticmethod
+     def extract_all_features(img):
+         img = img.convert('RGB')
+         # OPTIMIZATION: resize before handcrafted features to speed up Canny/Sobel.
+         max_size = 1024
+         if max(img.size) > max_size:
+             img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
+
+         features = {}
+         features.update(FeatureExtractor.extract_color_features(img))
+         features.update(FeatureExtractor.extract_texture_features(img))
+         features.update(FeatureExtractor.extract_shape_features(img))
+         features.update(FeatureExtractor.extract_brightness_features(img))
+         features.update(FeatureExtractor.extract_frequency_features(img))
+         features.update(FeatureExtractor.extract_statistical_features(img))
+         return features
+
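+ # The extractor yields 18 color + 4 texture + 4 shape + 2 brightness + 10 DCT
+ # + 2 statistical features = 40, matching num_handcrafted_features below. A
+ # minimal sketch to verify the count against a dummy image (illustrative only):
+ #
+ #     probe = Image.fromarray(np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8))
+ #     assert len(FeatureExtractor.extract_all_features(probe)) == 40
+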
+ # ==============================================================================
+ # 2. MODEL ARCHITECTURE
+ # ==============================================================================
+ class BioCLIP2ZeroShot:
+     def __init__(self, device, class_to_idx, id_to_name):
+         self.device = device
+         self.num_classes = len(class_to_idx)
+         self.idx_to_class = {v: k for k, v in class_to_idx.items()}
+         self.id_to_name = id_to_name
+         print("Loading BioCLIP-2 model...")
+         try:
+             self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
+             self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
+         except Exception:
+             print("Warning: BioCLIP-2 load failed, trying base BioCLIP...")
+             self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
+             self.tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
+         self.model.to(self.device).eval()
+         self.text_features_prototypes = self._precompute_text_features()
+
+     def _precompute_text_features(self):
+         templates = ["a photo of {}", "a herbarium specimen of {}", "a botanical photograph of {}", "{} plant species", "leaves and flowers of {}"]
+         class_ids = [self.idx_to_class[i] for i in range(self.num_classes)]
+         class_names = [self.id_to_name.get(str(cid), str(cid)) for cid in class_ids]
+         all_emb = []
+         bs = 64
+         # Templates vary fastest, so each class occupies a contiguous block of
+         # len(templates) rows in the encoded matrix.
+         text_inputs = [t.format(name) for name in class_names for t in templates]
+         with torch.no_grad():
+             for i in range(0, len(text_inputs), bs):
+                 tokens = self.tokenizer(text_inputs[i:i+bs]).to(self.device)
+                 emb = self.model.encode_text(tokens)
+                 all_emb.append(emb)
+         all_text_embs = torch.cat(all_emb, dim=0).cpu().numpy()
+         prototypes = np.zeros((self.num_classes, all_text_embs.shape[1]), dtype=np.float32)
+         for idx in range(self.num_classes):
+             start = idx * len(templates)
+             avg = np.mean(all_text_embs[start:start + len(templates)], axis=0)
+             norm = np.linalg.norm(avg)
+             prototypes[idx] = avg / norm if norm > 0 else avg
+         return torch.from_numpy(prototypes).to(self.device)
+
+     def predict_zero_shot_logits(self, img):
+         processed = self.preprocess(img).unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             image_features = self.model.encode_image(processed)
+             image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+             prototypes = self.text_features_prototypes
+             try:
+                 logit_scale = self.model.logit_scale.exp()
+             except Exception:
+                 logit_scale = torch.tensor(100.0).to(self.device)
+             # FIX: detach before converting to numpy.
+             logits = (logit_scale * image_features @ prototypes.T).detach().cpu().numpy().squeeze()
+         return logits
+
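+ # Minimal standalone usage sketch for the zero-shot head (hypothetical names;
+ # assumes the class maps built by ModelManager.load_class_info and a PIL image):
+ #
+ #     zs = BioCLIP2ZeroShot(torch.device('cpu'), class_to_idx, id_to_name)
+ #     logits = zs.predict_zero_shot_logits(Image.open('leaf.jpg'))
+ #     probs = np.exp(logits - logits.max()); probs /= probs.sum()  # stable softmax
+ #     print(idx_to_class[int(np.argmax(probs))])
+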
+ class EnsembleClassifier(nn.Module):
+     def __init__(self, num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=100,
+                  num_classes=100, hidden_dim=512, dropout_rate=0.3, prototype_dim=768):
+         super().__init__()
+         self.dinov2_proj = nn.Sequential(nn.Linear(dinov2_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
+
+         # FIX: third layer removed to match the training checkpoint (size-mismatch error otherwise).
+         self.handcraft_branch = nn.Sequential(
+             nn.Linear(num_handcrafted_features, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout_rate),
+             nn.Linear(128, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout_rate)
+         )
+
+         self.bioclip2_branch = nn.Sequential(
+             nn.Linear(bioclip2_dim, hidden_dim // 4), nn.BatchNorm1d(hidden_dim // 4), nn.ReLU(), nn.Dropout(dropout_rate * 0.5)
+         )
+         fusion_input_dim = hidden_dim + hidden_dim // 2 + hidden_dim // 4
+         self.fusion = nn.Sequential(
+             nn.Linear(fusion_input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate)
+         )
+         self.classifier = nn.Linear(hidden_dim, num_classes)
+         self.prototype_proj = nn.Linear(hidden_dim, prototype_dim)
+
+     def forward(self, handcrafted_features, dinov2_features, bioclip2_logits):
+         dinov2_out = self.dinov2_proj(dinov2_features)
+         handcraft_out = self.handcraft_branch(handcrafted_features)
+         bioclip2_out = self.bioclip2_branch(bioclip2_logits)
+         shared_features = self.fusion(torch.cat([dinov2_out, handcraft_out, bioclip2_out], dim=1))
+         class_output = self.classifier(shared_features)
+         projected_feature = self.prototype_proj(shared_features)
+         return class_output, projected_feature
+
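+ # Shape sanity check for the fusion model above (a sketch with dummy tensors;
+ # batch size 2 is arbitrary, eval mode avoids BatchNorm batch-statistics issues):
+ #
+ #     m = EnsembleClassifier().eval()
+ #     with torch.no_grad():
+ #         logits, proto = m(torch.randn(2, 40), torch.randn(2, 1024), torch.randn(2, 100))
+ #     assert logits.shape == (2, 100) and proto.shape == (2, 768)
+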
+ # ==============================================================================
+ # 3. MANAGER CLASS & EXPORTED FUNCTION
+ # ==============================================================================
+ class ModelManager:
+     def __init__(self):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         print(f"Initializing SPA Ensemble on {self.device}...")
+
+         # --- CONFIG: model repo ID ---
+         self.REPO_ID = "FrAnKu34t23/ensemble_models_plant"
+
+         self.class_to_idx, self.idx_to_class, self.id_to_name = self.load_class_info()
+         self.num_classes = len(self.class_to_idx)
+         print(f"SPA Ensemble: Loaded {self.num_classes} classes.")
+
+         # 1. Load DINOv2
+         print("SPA Ensemble: Loading DINOv2...")
+         self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
+         self.dinov2.to(self.device).eval()
+         self.dinov2_transform = transforms.Compose([
+             transforms.Resize(256), transforms.CenterCrop(224),
+             transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ])
+
+         # 2. Load BioCLIP
+         self.bioclip = BioCLIP2ZeroShot(self.device, self.class_to_idx, self.id_to_name)
+
+         # 3. Download & load the feature scaler
+         print("SPA Ensemble: Downloading Scaler...")
+         try:
+             # Fetch scaler.joblib from the model repo.
+             scaler_path = hf_hub_download(repo_id=self.REPO_ID, filename="scaler.joblib")
+             self.scaler = joblib.load(scaler_path)
+             print("✓ Scaler downloaded and loaded.")
+         except Exception as e:
+             print(f"Warning: Could not download scaler from {self.REPO_ID}: {e}.")
+             print("Using dummy scaler (predictions may be inaccurate).")
+             self.scaler = StandardScaler()
+             # FIX: fit on 40 zeros instead of 49 to match the feature reduction.
+             self.scaler.fit(np.zeros((1, 40)))
+
+         # 4. Download & load the five ensemble members (diverse widths/dropout rates)
+         self.models = []
+         hidden_dims = [384, 448, 512, 576, 640]
+         dropout_rates = [0.2, 0.25, 0.3, 0.35, 0.4]
+
+         print(f"SPA Ensemble: Downloading Models from {self.REPO_ID}...")
+         for i in range(5):
+             filename = f"ensemble_model_{i}.pth"
+             try:
+                 model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
+                 # FIX: pass num_handcrafted_features=40 and prototype_dim=768 to match the weights.
+                 model = EnsembleClassifier(
+                     num_handcrafted_features=40, dinov2_dim=1024, bioclip2_dim=self.num_classes,
+                     num_classes=self.num_classes, hidden_dim=hidden_dims[i], dropout_rate=dropout_rates[i],
+                     prototype_dim=768
+                 )
+                 state_dict = torch.load(model_path, map_location=self.device)
+                 model.load_state_dict(state_dict)
+                 model.to(self.device).eval()
+                 self.models.append(model)
+                 print(f"✓ Loaded {filename}")
+             except Exception as e:
+                 print(f"Failed to load {filename}: {e}")
+
+     def load_class_info(self):
+         id_to_name = {}
+
+         species_path = LIST_DIR / "species_list.txt"
+         train_path = LIST_DIR / "train.txt"
+
+         classes_set = set()
+
+         if train_path.exists():
+             with open(train_path, 'r') as f:
+                 for line in f:
+                     parts = line.strip().split()
+                     if len(parts) >= 2:
+                         classes_set.add(parts[1])
+         elif species_path.exists():
+             with open(species_path, 'r') as f:
+                 for line in f:
+                     parts = line.strip().split(";", 1)
+                     classes_set.add(parts[0].strip())
+         else:
+             classes_set = {str(i) for i in range(100)}
+
+         sorted_classes = sorted(classes_set)
+         class_to_idx = {cls: idx for idx, cls in enumerate(sorted_classes)}
+         idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
+
+         if species_path.exists():
+             with open(species_path, 'r') as f:
+                 for line in f:
+                     if ";" in line:
+                         parts = line.strip().split(";", 1)
+                         id_to_name[parts[0].strip()] = parts[1].strip()
+         return class_to_idx, idx_to_class, id_to_name
+
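+     # Expected list-file formats, inferred from the parsing above:
+     #   list/train.txt        -> "<image_path> <class_id>"  (whitespace-separated)
+     #   list/species_list.txt -> "<class_id>;<Species name>"
+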
+     def predict(self, image):
+         if image is None:
+             return {}
+         img_pil = image.convert("RGB")
+
+         # 1. Handcrafted features (scaled to match training statistics)
+         hc_feats = FeatureExtractor.extract_all_features(img_pil)
+         hc_vector = np.array([hc_feats[k] for k in sorted(hc_feats.keys())]).reshape(1, -1)
+         hc_vector = self.scaler.transform(hc_vector)
+         hc_tensor = torch.FloatTensor(hc_vector).to(self.device)
+
+         # 2. DINOv2 features (L2-normalised)
+         dino_input = self.dinov2_transform(img_pil).unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             dino_feats = self.dinov2(dino_input)
+             dino_feats = dino_feats / (dino_feats.norm(dim=-1, keepdim=True) + 1e-8)
+
+         # 3. BioCLIP zero-shot logits
+         bioclip_logits = self.bioclip.predict_zero_shot_logits(img_pil)
+         bioclip_tensor = torch.FloatTensor(bioclip_logits).unsqueeze(0).to(self.device)
+
+         # 4. Ensemble prediction (mean of per-model softmax distributions)
+         if not self.models:
+             return {"Error": "SPA Models not loaded"}
+
+         all_probs = []
+         for model in self.models:
+             with torch.no_grad():
+                 logits, _ = model(hc_tensor, dino_feats, bioclip_tensor)
+                 all_probs.append(F.softmax(logits, dim=1).cpu().numpy()[0])
+
+         final_ens_probs = np.mean(all_probs, axis=0)
+
+         # 5. Hybrid routing
+         # FIX: subtract the max before exponentiating; the scaled CLIP logits
+         # (up to ~100) overflow float32 under a plain np.exp.
+         exp_logits = np.exp(bioclip_logits - np.max(bioclip_logits))
+         bioclip_probs = exp_logits / np.sum(exp_logits)
+
+         ens_pred_idx = np.argmax(final_ens_probs)
+         ens_conf = final_ens_probs[ens_pred_idx]
+
+         if ens_conf < CONFIDENCE_THRESHOLD:
+             final_probs = (final_ens_probs + bioclip_probs) / 2
+         else:
+             final_probs = final_ens_probs
+
+         # 6. Format the top-k results
+         top_k = 5
+         top_indices = np.argsort(final_probs)[::-1][:top_k]
+         results = {}
+         for idx in top_indices:
+             class_id = self.idx_to_class[idx]
+             name = self.id_to_name.get(class_id, class_id)
+             results[f"{name} ({class_id})"] = float(final_probs[idx])
+
+         return results
+
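+     # Routing example (illustrative numbers): with CONFIDENCE_THRESHOLD = 0.99,
+     # an ensemble top-1 of 0.60 triggers blending, so a class scored 0.60 by
+     # the ensemble and 0.80 by BioCLIP ends up at (0.60 + 0.80) / 2 = 0.70; a
+     # top-1 of 0.995 leaves the ensemble distribution untouched.
+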
+ # Initialize singleton (importing this module triggers all downloads and model loads)
+ try:
+     spa_manager = ModelManager()
+ except Exception as e:
+     print(f"CRITICAL ERROR initializing SPA: {e}")
+     spa_manager = None
+
+ def predict_spa(image):
+     if spa_manager is None:
+         return {"Error": "SPA System failed to initialize."}
+     return spa_manager.predict(image)
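+
+ # A minimal smoke test, assuming a local image path (hypothetical; not part of
+ # the Space's entry point):
+ if __name__ == "__main__":
+     import sys
+     test_path = sys.argv[1] if len(sys.argv) > 1 else "sample.jpg"
+     if os.path.exists(test_path):
+         for label, score in predict_spa(Image.open(test_path)).items():
+             print(f"{label}: {score:.4f}")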