|
|
""" |
|
|
Model loader - Handles VGG16 loading and inference
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models, transforms |
|
|
from .config import IMAGE_SIZE, NORMALIZE_MEAN, NORMALIZE_STD |
|
|
|
|
|
class VGG16ModelLoader:
    """Loads and manages VGG16 model for fish disease detection.

    Constructing an instance builds a VGG16 network whose last classifier
    layer outputs ``num_classes`` scores, loads the trained weights from
    ``model_path``, moves the network to ``device`` in eval mode, and sets up
    the image preprocessing pipeline. Call :meth:`predict` to classify a
    single image.
    """

    def __init__(self, model_path, num_classes, device='cpu'):
        """
        Initialize model loader

        Args:
            model_path: Path to trained model weights (.pth file)
            num_classes: Number of disease classes
            device: 'cpu' or 'cuda' (passed to ``torch.device``)

        Raises:
            RuntimeError: If the weights file is missing or cannot be loaded.
        """
        self.model_path = model_path
        self.num_classes = num_classes
        self.device = torch.device(device)
        self.model = None
        self.transform = None

        self._load_model()
        self._setup_transform()

    def _load_model(self):
        """Load VGG16 with custom classifier

        Raises:
            RuntimeError: If the weights file does not exist, or if building
                the network or loading the state dict fails for any other
                reason. The underlying exception is chained as ``__cause__``.
        """
        try:
            # The checkpoint loaded below replaces every parameter
            # (load_state_dict is strict by default), so there is no point
            # in downloading the ImageNet weights first.
            model = models.vgg16(weights=None)

            # Swap the final classifier layer so the output size matches
            # the number of disease classes; reuse the existing input width
            # instead of hard-coding it.
            head_in_features = model.classifier[6].in_features
            model.classifier[6] = nn.Linear(head_in_features, self.num_classes)

            # weights_only=True restricts unpickling to tensor data.
            state_dict = torch.load(
                self.model_path,
                map_location=self.device,
                weights_only=True,
            )
            model.load_state_dict(state_dict)

            model = model.to(self.device)
            model.eval()
        except FileNotFoundError as e:
            raise RuntimeError(
                f"Model file not found: {self.model_path}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

        # Only publish the model once it has loaded completely.
        self.model = model

        # "\u2705" is the check-mark emoji, written as an escape so the
        # source stays ASCII and cannot be mangled by a wrong file encoding.
        # Printed outside the try block so a console encoding problem is not
        # misreported as a model-loading failure.
        print(f"\u2705 Model loaded: {self.model_path}")
        print(f"\u2705 Device: {self.device}")

    def _setup_transform(self):
        """Setup image preprocessing pipeline

        The pipeline resizes to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, converts the
        image to a tensor, and normalizes it with ``NORMALIZE_MEAN`` /
        ``NORMALIZE_STD`` from the package config.
        """
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
        ])

    def predict(self, image):
        """
        Predict disease from PIL Image

        Args:
            image: PIL Image. Images exposing a string ``mode`` other than
                'RGB' (e.g. 'RGBA', 'L', 'P') are converted to RGB first.

        Returns:
            tuple: (predicted_class_idx, confidence_score, all_probabilities)
                - predicted_class_idx: int index of the most probable class
                - confidence_score: float, top probability as a percentage
                  (0-100)
                - all_probabilities: 1-D tensor of softmax probabilities for
                  the single input image, left on ``self.device``

        Raises:
            RuntimeError: If preprocessing or inference fails. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            # VGG16's first convolution takes 3 input channels, so bring
            # any other PIL mode to RGB before preprocessing.
            mode = getattr(image, "mode", None)
            if isinstance(mode, str) and mode != "RGB":
                image = image.convert("RGB")

            # Add a leading batch dimension of size 1.
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)

            return (
                predicted_idx.item(),
                confidence.item() * 100,
                probabilities[0],
            )
        except Exception as e:
            raise RuntimeError(f"Prediction failed: {e}") from e
|
|
|