|
|
""" |
|
|
Grad-CAM implementation for explainable AI |
|
|
Shows which parts of the image influenced the model's decision |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
|
|
|
class GradCAM:
    """
    Grad-CAM for VGG16 to visualize model decisions.

    The target layer is hooked only for the duration of ``generate_cam``,
    so creating instances never leaves permanent hooks on the model.
    """

    def __init__(self, model, target_layer):
        """
        Initialize Grad-CAM.

        Args:
            model: Trained VGG16 model
            target_layer: Layer to visualize (e.g., model.features[-1])
        """
        self.model = model
        self.target_layer = target_layer
        # Filled in by the hooks during the most recent generate_cam() call.
        self.gradients = None
        self.activations = None

    def save_activation(self, module, input, output):
        """Save forward activation"""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Save backward gradient"""
        self.gradients = grad_output[0].detach()

    def _capture(self, module, inputs, output):
        """Forward hook: store the activation and arrange to store its gradient."""
        self.save_activation(module, inputs, output)
        if output.requires_grad:
            # A tensor hook on the layer output replaces the deprecated
            # Module.register_backward_hook. The lambda returns None so the
            # gradient flowing through the graph is left untouched.
            output.register_hook(
                lambda grad: self.save_gradient(module, None, (grad,))
            )

    def generate_cam(self, input_tensor, class_idx):
        """
        Generate Class Activation Map.

        Only the first sample of the batch is used.

        Args:
            input_tensor: Preprocessed input image tensor; assumed to be
                laid out as (N, C, H, W)
            class_idx: Target class index

        Returns:
            numpy array: uint8 heatmap (0-255) with the same height and
            width as ``input_tensor``

        Raises:
            RuntimeError: If the target layer produced no activation or
                received no gradient during the pass.
        """
        # Drop anything left over from a previous call so stale tensors can
        # never be mistaken for the result of this one.
        self.activations = None
        self.gradients = None

        handle = self.target_layer.register_forward_hook(self._capture)
        try:
            # Gradients are required even when the caller is inside
            # torch.no_grad().
            with torch.enable_grad():
                # Tracking the input keeps the graph connected up to the
                # target layer even when the model's weights are frozen.
                # detach() makes this a separate tensor object, so the
                # caller's tensor is not modified.
                tracked_input = input_tensor.detach().requires_grad_(True)
                output = self.model(tracked_input)

                self.model.zero_grad()

                class_score = output[0, class_idx]
                class_score.backward()
        finally:
            # Always unhook, so repeated use does not pile hooks on the model.
            handle.remove()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM could not capture activations/gradients for the "
                "target layer; make sure it is used in the model's forward pass"
            )

        gradients = self.gradients[0]
        activations = self.activations[0]

        # One weight per channel: the spatial mean of that channel's gradient.
        weights = gradients.mean(dim=(1, 2), keepdim=True)

        cam = (weights * activations).sum(dim=0)
        cam = F.relu(cam)

        cam = cam - cam.min()
        peak = cam.max()
        # An all-zero map (no positive evidence for the class) would give
        # 0/0 = NaN; leave it at zero instead.
        if peak > 0:
            cam = cam / peak

        # Upsample to the input resolution rather than a fixed 224x224.
        # Bilinear with align_corners=False uses half-pixel centres, the same
        # convention as cv2.resize's default interpolation.
        height, width = input_tensor.shape[-2:]
        cam = F.interpolate(
            cam[None, None],
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        cam = cam.clamp(0, 1).cpu().numpy()

        return np.uint8(255 * cam)

    def overlay_heatmap(self, image, heatmap, alpha=0.5):
        """
        Overlay heatmap on original image.

        Args:
            image: Original PIL Image (any mode; converted to RGB)
            heatmap: Heatmap numpy array (0-255), e.g. from ``generate_cam``
            alpha: Transparency (0-1); weight given to the heatmap

        Returns:
            PIL Image with heatmap overlay, sized like ``heatmap``
        """
        height, width = heatmap.shape[:2]

        # The blend needs both arrays to share shape and channel count, so
        # force three channels and match the heatmap size (PIL takes (w, h)).
        image_np = np.array(image.convert("RGB").resize((width, height)))

        # OpenCV colour maps come back as BGR; PIL expects RGB.
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        overlayed = cv2.addWeighted(image_np, 1 - alpha, heatmap_colored, alpha, 0)

        return Image.fromarray(overlayed)
|
|
|
|
|
|
|
|
def generate_gradcam_visualization(model, image, predicted_class_idx, transform):
    """
    Generate Grad-CAM visualization for a prediction.

    The model is switched to eval mode and left in that mode. The last
    module of ``model.features`` is used as the Grad-CAM target layer.

    Args:
        model: Trained VGG16 model (must expose an indexable ``features``)
        image: PIL Image
        predicted_class_idx: Index of predicted class
        transform: Image transformation pipeline; called on ``image`` and
            expected to return a tensor without a batch dimension

    Returns:
        PIL Image: Image with heatmap overlay
    """
    model.eval()

    # Add the batch dimension the model expects.
    input_tensor = transform(image).unsqueeze(0)

    # The transform's output is not guaranteed to be on the model's device;
    # move it there so a GPU-resident model does not fail with a device
    # mismatch. Models without parameters are left alone.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    gradcam = GradCAM(model, model.features[-1])

    heatmap = gradcam.generate_cam(input_tensor, predicted_class_idx)

    visualization = gradcam.overlay_heatmap(image, heatmap, alpha=0.4)

    return visualization
|
|
|