remove cuda
Browse files- generic_utils.py +5 -5
generic_utils.py
CHANGED
|
@@ -24,10 +24,10 @@ def show_cam_on_image(img, mask):
|
|
| 24 |
|
| 25 |
|
| 26 |
# initialize ViT pretrained
|
| 27 |
-
model = vit_LRP(pretrained=True)
|
| 28 |
model.eval()
|
| 29 |
attribution_generator = LRP(model)
|
| 30 |
-
model_baseline = vit(pretrained=True)
|
| 31 |
model_baseline.eval()
|
| 32 |
baselines_generator = Baselines(model_baseline)
|
| 33 |
|
|
@@ -37,16 +37,16 @@ def generate_visualization(
|
|
| 37 |
):
|
| 38 |
if LRP:
|
| 39 |
transformer_attribution = attribution_generator.generate_LRP(
|
| 40 |
-
original_image.unsqueeze(0)
|
| 41 |
).detach()
|
| 42 |
else:
|
| 43 |
if method == "gradcam":
|
| 44 |
transformer_attribution = baselines_generator.generate_cam_attn(
|
| 45 |
-
original_image.unsqueeze(0)
|
| 46 |
).detach()
|
| 47 |
else:
|
| 48 |
transformer_attribution = baselines_generator.generate_rollout(
|
| 49 |
-
original_image.unsqueeze(0)
|
| 50 |
).detach()
|
| 51 |
if method != "full":
|
| 52 |
transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
# initialize ViT pretrained
|
| 27 |
+
model = vit_LRP(pretrained=True)
|
| 28 |
model.eval()
|
| 29 |
attribution_generator = LRP(model)
|
| 30 |
+
model_baseline = vit(pretrained=True)
|
| 31 |
model_baseline.eval()
|
| 32 |
baselines_generator = Baselines(model_baseline)
|
| 33 |
|
|
|
|
| 37 |
):
|
| 38 |
if LRP:
|
| 39 |
transformer_attribution = attribution_generator.generate_LRP(
|
| 40 |
+
original_image.unsqueeze(0), method=method, index=class_index
|
| 41 |
).detach()
|
| 42 |
else:
|
| 43 |
if method == "gradcam":
|
| 44 |
transformer_attribution = baselines_generator.generate_cam_attn(
|
| 45 |
+
original_image.unsqueeze(0), index=class_index
|
| 46 |
).detach()
|
| 47 |
else:
|
| 48 |
transformer_attribution = baselines_generator.generate_rollout(
|
| 49 |
+
original_image.unsqueeze(0)
|
| 50 |
).detach()
|
| 51 |
if method != "full":
|
| 52 |
transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
|