import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from gradcam import do_gradcam
from lrp import do_lrp, do_partial_lrp
from rollout import do_rollout
from tiba import do_tiba
# Per-channel statistics used to normalize the RGB tensor: each channel is
# mapped to (x - 0.5) / 0.5.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

# Preprocessing pipeline handed to every explainability method: resize so the
# shorter edge is 256 px, take the central 224x224 crop, convert the PIL image
# to a tensor, then apply the channel normalization above.
TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
# Registry of explainability methods: the key is the name offered in the UI
# dropdown, the value is the function that produces the visualization. Each
# function is called as fn(TRANSFORM, image, class_index=...).
METHOD_MAP = dict(
    tiba=do_tiba,
    gradcam=do_gradcam,
    lrp=do_lrp,
    partial_lrp=do_partial_lrp,
    rollout=do_rollout,
)
def generate_viz(image, method, class_index=None):
    """Run the selected explainability method and return its figure as an image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input.
        method: Name of the explainability method; must be a key of
            ``METHOD_MAP``.
        class_index: Optional class index from the ``gr.Number`` input. It is
            cast to ``int`` when given; ``None`` is forwarded unchanged to the
            selected method.

    Returns:
        The figure returned by the method, rendered to PNG and converted to an
        RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was supplied or *method* is not in ``METHOD_MAP``
            (shown to the user as an error message in the Gradio UI).
    """
    # Fail with a readable message instead of AttributeError / KeyError when
    # the user submits without an image or without choosing a method.
    if image is None:
        raise gr.Error("Please provide an input image.")
    if method not in METHOD_MAP:
        raise gr.Error(
            f"Unknown method {method!r}; choose one of: {', '.join(METHOD_MAP)}."
        )

    # gr.Number delivers a float; the methods expect an integer class index.
    if class_index is not None:
        class_index = int(class_index)
    print(f"Image: {image.size}")
    print(f"Method: {method}")
    print(f"Class: {class_index}")

    viz_method = METHOD_MAP[method]
    viz = viz_method(TRANSFORM, image, class_index=class_index)

    # Render in memory instead of round-tripping through a fixed
    # "visualization.png": a shared file on disk can be overwritten by another
    # request if the app serves requests concurrently, and it leaves a stray
    # file in the working directory.
    # NOTE(review): viz appears to be a matplotlib Figure (it has savefig); if
    # the methods create it through pyplot it is never closed here, so figures
    # may accumulate across requests — confirm and close it if so.
    with io.BytesIO() as buffer:
        viz.savefig(buffer, format="png")
        buffer.seek(0)
        # convert() decodes the whole image, so the buffer can be closed after.
        return Image.open(buffer).convert("RGB")
# Heading shown at the top of the Gradio interface.
# NOTE(review): the trailing "π€" looks like a UTF-8 emoji that was decoded with
# a legacy (Greek, cp1253-style) codepage — confirm the intended character and
# fix the literal.
title = "Compare different methods of explaining ViTs π€"

# Markdown text rendered below the interface, linking to the source paper.
article = (
    "Different methods for explaining Vision Transformers as explored by "
    "Chefer et al. in [Transformer Interpretability Beyond Attention "
    "Visualization, a novel method to visualize classifications by "
    "Transformer based networks](https://arxiv.org/abs/2012.09838)."
)
# Input widgets, in the same order as generate_viz's parameters:
# (image, method, class_index).
input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Dropdown(
        list(METHOD_MAP),
        label="Method",
        info="Explainability method to investigate.",
    ),
    gr.Number(label="Class Index", info="Class index to inspect"),
]

# Example rows: [image path, method name, class index or None].
# With cache_examples=True these are run once when the interface is built.
example_rows = [
    ["Transformer-Explainability/samples/catdog.png", "tiba", None],
    ["Transformer-Explainability/samples/catdog.png", "rollout", 243],
    ["Transformer-Explainability/samples/el2.png", "tiba", None],
    ["Transformer-Explainability/samples/el2.png", "gradcam", 340],
    ["Transformer-Explainability/samples/dogbird.png", "lrp", 161],
]

iface = gr.Interface(
    fn=generate_viz,
    inputs=input_components,
    outputs=gr.Image(),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=example_rows,
)
# Starts the web server as soon as this module runs; debug=True blocks and
# prints errors to the console.
iface.launch(debug=True)