| import matplotlib.pyplot as plt | |
| from generic_utils import generate_visualization | |
def _plot_image_and_visualization(transform, image, class_index, method):
    """Draw ``image`` beside the output of ``generate_visualization``.

    Shared implementation behind ``do_lrp`` and ``do_partial_lrp``, which
    differ only in the ``method`` string forwarded to
    ``generate_visualization``.

    Args:
        transform: Callable applied to ``image``; its result is what gets
            passed to ``generate_visualization`` (presumably the model's
            preprocessing -- TODO confirm against callers).
        image: Original input, drawn unmodified in the left panel. Must be
            something ``Axes.imshow`` accepts.
        class_index: Forwarded unchanged to ``generate_visualization``.
        method: Mode string forwarded unchanged to ``generate_visualization``.

    Returns:
        A new matplotlib Figure with two panels and their axes hidden: the
        original image on the left, the visualization on the right.

    Raises:
        Whatever ``transform`` or ``generate_visualization`` raise. In that
        case the partially built figure is closed first, so it is not left
        registered with pyplot.
    """
    fig, axs = plt.subplots(1, 2)
    try:
        # Left panel: the untouched input, drawn before the transform runs
        # (same order as the original per-function code).
        axs[0].imshow(image)
        axs[0].axis("off")

        transformed_image = transform(image)
        viz = generate_visualization(
            transformed_image, class_index=class_index, method=method
        )

        # Right panel: the visualization of the transformed input.
        axs[1].imshow(viz)
        axs[1].axis("off")
    except BaseException:
        # plt.subplots registered the figure with pyplot's global figure
        # manager. Without this close, every failed call would leave an
        # orphaned open figure behind.
        plt.close(fig)
        raise
    return fig


def do_lrp(transform, image, class_index=None):
    """Plot ``image`` next to its full (``method="full"``) visualization.

    Args:
        transform: Callable applied to ``image`` before visualization.
        image: Input image shown as-is in the left panel.
        class_index: Forwarded to ``generate_visualization``. ``None`` is
            passed through unchanged; how it is interpreted is decided by
            ``generate_visualization``.

    Returns:
        The two-panel matplotlib Figure.
    """
    return _plot_image_and_visualization(
        transform, image, class_index, method="full"
    )


def do_partial_lrp(transform, image, class_index=None):
    """Plot ``image`` next to its ``method="last_layer"`` visualization.

    Args:
        transform: Callable applied to ``image`` before visualization.
        image: Input image shown as-is in the left panel.
        class_index: Forwarded to ``generate_visualization``. ``None`` is
            passed through unchanged; how it is interpreted is decided by
            ``generate_visualization``.

    Returns:
        The two-panel matplotlib Figure.
    """
    return _plot_image_and_visualization(
        transform, image, class_index, method="last_layer"
    )