"""
Model loader - Handles VGG16 loading and inference
"""

import torch
import torch.nn as nn
from torchvision import models, transforms

from .config import IMAGE_SIZE, NORMALIZE_MEAN, NORMALIZE_STD


class VGG16ModelLoader:
    """Loads and manages VGG16 model for fish disease detection.

    On construction the model architecture is built, the trained weights at
    ``model_path`` are loaded onto ``device``, the model is put in eval mode,
    and the preprocessing transform is prepared.
    """

    def __init__(self, model_path, num_classes, device='cpu'):
        """
        Initialize model loader

        Args:
            model_path: Path to trained model weights (.pth file)
            num_classes: Number of disease classes
            device: 'cpu' or 'cuda'

        Raises:
            RuntimeError: if the weights file is missing or cannot be loaded.
        """
        self.model_path = model_path
        self.num_classes = num_classes
        self.device = torch.device(device)
        self.model = None
        self.transform = None

        self._load_model()
        self._setup_transform()

    def _load_model(self):
        """Build VGG16 with a custom final layer and load the trained weights.

        Raises:
            RuntimeError: if the weights file does not exist, or if loading /
                applying the state dict fails (e.g. class-count mismatch).
                The original exception is chained as ``__cause__``.
        """
        try:
            # Build the bare architecture. Pretrained ImageNet weights are NOT
            # requested: every parameter is overwritten by the fine-tuned
            # state_dict below, so fetching them only adds startup latency and
            # a network dependency (which breaks offline deployments).
            self.model = models.vgg16(weights=None)

            # Replace the final classifier layer so it emits one logit per
            # class. Reuse the existing layer's input width instead of
            # hard-coding 4096.
            in_features = self.model.classifier[6].in_features
            self.model.classifier[6] = nn.Linear(in_features, self.num_classes)

            # Load trained weights
            state_dict = torch.load(
                self.model_path,
                map_location=self.device,
                weights_only=True,  # Security: only load weights, not arbitrary pickled objects
            )
            self.model.load_state_dict(state_dict)

            # Move to device and set eval mode (disables dropout for inference)
            self.model = self.model.to(self.device)
            self.model.eval()

            print(f"✅ Model loaded: {self.model_path}")
            print(f"✅ Device: {self.device}")

        except FileNotFoundError as e:
            raise RuntimeError(f"Model file not found: {self.model_path}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    def _setup_transform(self):
        """Setup image preprocessing pipeline (resize, to tensor, normalize)."""
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
        ])

    def predict(self, image):
        """
        Predict disease from PIL Image

        Args:
            image: PIL Image. RGB is expected; images in other PIL modes
                (grayscale, RGBA, palette, ...) are converted to RGB first.

        Returns:
            tuple: (predicted_class_idx, confidence_score, all_probabilities)
                where confidence_score is a percentage in [0, 100] and
                all_probabilities is the softmax tensor for the single image
                (still on ``self.device``).

        Raises:
            RuntimeError: if preprocessing or inference fails (original
                exception chained as ``__cause__``).
        """
        try:
            # The pipeline and the network expect exactly 3 channels; a
            # grayscale/RGBA/palette image would otherwise fail in Normalize.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")

            # Preprocess image and add a batch dimension of 1
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Run inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)

            return (
                predicted_idx.item(),
                confidence.item() * 100,  # Convert to percentage
                probabilities[0],
            )

        except Exception as e:
            raise RuntimeError(f"Prediction failed: {e}") from e