mathaisjustin's picture
Deploy Fish Disease Detection AI
fbbdeab
"""
Grad-CAM implementation for explainable AI
Shows which parts of the image influenced the model's decision
"""
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
class GradCAM:
    """
    Grad-CAM (gradient-weighted class activation mapping) for a CNN.

    Hooks ``target_layer`` to capture its forward activations and the gradient
    of the chosen class score with respect to them, then combines the two into
    a coarse heatmap of the image regions that pushed that score up.
    Written for VGG16, but works for any model whose ``target_layer`` emits a
    ``[N, C, H, W]`` tensor. Only the first sample of a batch is visualised.

    The hooks stay attached to ``target_layer`` until :meth:`remove_hooks` is
    called; using the instance as a context manager detaches them on exit.
    """

    def __init__(self, model, target_layer):
        """
        Initialize Grad-CAM and register hooks on ``target_layer``.

        Args:
            model: Trained model (e.g. VGG16) whose output is indexed as
                ``output[sample, class]``
            target_layer: Layer to visualize (e.g., model.features[-1])
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # grad of class score w.r.t. target_layer output
        self.activations = None  # target_layer forward output
        # Keep the handles so the hooks can be detached later; without them
        # every new GradCAM on the same layer leaves another pair of hooks
        # (and a reference to itself) on the module forever.
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook, but the full variant rejects layers
        # whose output is later modified in place (e.g. a Conv2d followed by
        # ReLU(inplace=True)) -- migrate with care.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach this instance's hooks from ``target_layer`` (safe to repeat)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()
        return False  # never swallow exceptions

    def save_activation(self, module, input, output):
        """Forward hook: save the layer's output (detached)."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: save the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    @staticmethod
    def _normalize_cam(cam):
        """Min-max scale ``cam`` to [0, 1]; a constant map becomes all zeros."""
        cam = cam - cam.min()
        max_val = cam.max()
        # Guard the division: a constant map (e.g. all zeros after ReLU)
        # would otherwise become NaN.
        if max_val > 0:
            cam = cam / max_val
        return cam

    def generate_cam(self, input_tensor, class_idx):
        """
        Generate a Class Activation Map for one class.

        Args:
            input_tensor: Preprocessed input image tensor; only sample 0 is used
            class_idx: Target class index into the model output
        Returns:
            numpy array: uint8 heatmap of shape (224, 224) with values 0-255.
                A constant map (e.g. no positive evidence) is all zeros.
        Raises:
            RuntimeError: if the hooks captured no activations or gradients,
                e.g. because ``target_layer`` did not take part in the
                forward/backward pass.
        """
        # Reset so a hook that fails to fire is detected instead of silently
        # reusing tensors captured by a previous call.
        self.activations = None
        self.gradients = None

        # Make sure a graph reaches target_layer even if the trunk's
        # parameters are frozen; detach() leaves the caller's tensor untouched.
        if not input_tensor.requires_grad:
            input_tensor = input_tensor.detach().requires_grad_(True)

        # Gradients are needed even when the caller runs under torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            self.model.zero_grad()
            class_score = output[0, class_idx]
            class_score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; check that "
                "target_layer is used in the model's forward pass."
            )

        gradients = self.gradients[0]      # [C, H, W]
        activations = self.activations[0]  # [C, H, W]

        # Channel weights = global average pooling of the gradients.
        weights = gradients.mean(dim=(1, 2), keepdim=True)  # [C, 1, 1]

        # Weighted sum of activation maps; ReLU keeps only positive influence.
        cam = F.relu((weights * activations).sum(dim=0))  # [H, W]
        cam = self._normalize_cam(cam)

        cam = cv2.resize(cam.cpu().numpy(), (224, 224))
        return np.uint8(255 * cam)

    def overlay_heatmap(self, image, heatmap, alpha=0.5):
        """
        Overlay a heatmap on the original image.

        Args:
            image: Original PIL Image (any mode; converted to RGB)
            heatmap: 2-D uint8 heatmap array (0-255), e.g. from generate_cam
            alpha: Heatmap weight in the blend (0-1)
        Returns:
            PIL Image with the heatmap overlay, sized like ``heatmap``
            (224x224 for generate_cam output)
        """
        height, width = heatmap.shape[:2]
        # convert("RGB") so RGBA / grayscale inputs give the same 3-channel
        # shape as the colourised heatmap; addWeighted requires equal shapes.
        image_np = np.array(image.convert("RGB").resize((width, height)))

        # Colour the heatmap (JET: red = high activation), BGR -> RGB for PIL.
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        overlayed = cv2.addWeighted(image_np, 1 - alpha, heatmap_colored, alpha, 0)
        return Image.fromarray(overlayed)


def generate_gradcam_visualization(model, image, predicted_class_idx, transform):
    """
    Generate a Grad-CAM visualization for a prediction.

    Args:
        model: Trained VGG16 model (must expose ``model.features``)
        image: PIL Image
        predicted_class_idx: Index of predicted class
        transform: Image transformation pipeline returning a [C, H, W] tensor
    Returns:
        PIL Image: Image with heatmap overlay
    """
    model.eval()

    # Preprocess and place the input on the model's device (CPU when the
    # model has no parameters).
    input_tensor = transform(image).unsqueeze(0)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    # Hook the last module of model.features. NOTE: for torchvision's VGG16
    # this is the final MaxPool2d (the pooled output of the last conv block),
    # not a Conv2d. The `with` block removes the hooks afterwards so repeated
    # calls do not accumulate hooks on the model.
    with GradCAM(model, model.features[-1]) as gradcam:
        heatmap = gradcam.generate_cam(input_tensor, predicted_class_idx)
        visualization = gradcam.overlay_heatmap(image, heatmap, alpha=0.4)
    return visualization