"""
Model loader - Handles VGG16 loading and inference
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
from .config import IMAGE_SIZE, NORMALIZE_MEAN, NORMALIZE_STD

class VGG16ModelLoader:
    """Loads and manages VGG16 model for fish disease detection"""
    
    def __init__(self, model_path, num_classes, device='cpu'):
        """
        Initialize model loader
        
        Args:
            model_path: Path to trained model weights (.pth file)
            num_classes: Number of disease classes
            device: 'cpu' or 'cuda'
        """
        self.model_path = model_path
        self.num_classes = num_classes
        self.device = torch.device(device)
        self.model = None
        self.transform = None
        
        self._load_model()
        self._setup_transform()
    
    def _load_model(self):
        """Load VGG16 with custom classifier"""
        try:
            # Create VGG16 architecture; no pretrained download is needed
            # because the trained weights loaded below replace every parameter
            self.model = models.vgg16(weights=None)
            
            # Replace final layer for our classes
            self.model.classifier[6] = nn.Linear(4096, self.num_classes)
            
            # Load trained weights
            state_dict = torch.load(
                self.model_path, 
                map_location=self.device,
                weights_only=True  # Security: only load weights
            )
            self.model.load_state_dict(state_dict)
            
            # Move to device and set eval mode
            self.model = self.model.to(self.device)
            self.model.eval()
            
            print(f"βœ… Model loaded: {self.model_path}")
            print(f"βœ… Device: {self.device}")
            
        except FileNotFoundError:
            raise RuntimeError(f"Model file not found: {self.model_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}")
    
    def _setup_transform(self):
        """Setup image preprocessing pipeline"""
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
        ])
    
    def predict(self, image):
        """
        Predict disease from PIL Image
        
        Args:
            image: PIL Image in RGB format
            
        Returns:
            tuple: (predicted_class_idx, confidence_score, all_probabilities)
        """
        try:
            # Preprocess image
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)
            
            # Run inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)
            
            return (
                predicted_idx.item(),
                confidence.item() * 100,  # Convert to percentage
                probabilities[0]
            )
            
        except Exception as e:
            raise RuntimeError(f"Prediction failed: {e}")
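
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this loader might be wired up. The model path,
# class count, and image filename below are placeholders for demonstration,
# not values taken from the real project configuration.
if __name__ == "__main__":
    from PIL import Image

    loader = VGG16ModelLoader(
        model_path="models/vgg16_fish_disease.pth",  # placeholder path
        num_classes=7,                               # placeholder class count
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    image = Image.open("sample_fish.jpg").convert("RGB")  # placeholder image
    predicted_idx, confidence, probabilities = loader.predict(image)
    print(f"Predicted class index: {predicted_idx} ({confidence:.1f}% confidence)")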