remove cuda().
Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py CHANGED
@@ -50,7 +50,7 @@ class LRP:
         one_hot[0, index] = 1
         one_hot_vector = one_hot
         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
-        one_hot = torch.sum(one_hot.cuda() * output)
+        one_hot = torch.sum(one_hot * output)
 
         self.model.zero_grad()
         one_hot.backward(retain_graph=True)
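For context beyond the diff: multiplying the logits by a one-hot mask and summing reduces `output` to the single target-class score, which is the scalar that `one_hot.backward(retain_graph=True)` differentiates a few lines later. A minimal sketch of that selection step, with hypothetical stand-in values:

    import numpy as np
    import torch

    # Stand-in logits; in the repository, `output` comes from the ViT model.
    output = torch.randn(1, 1000, requires_grad=True)
    index = int(np.argmax(output.detach().numpy()))  # predicted class

    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)

    # Masking and summing is equivalent to selecting output[0, index],
    # yielding a scalar suitable for backward().
    score = torch.sum(one_hot * output)
    assert torch.allclose(score, output[0, index])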
@@ -70,14 +70,14 @@ class Baselines:
         self.model.eval()
 
     def generate_cam_attn(self, input, index=None):
-        output = self.model(input.cuda(), register_hook=True)
+        output = self.model(input, register_hook=True)
         if index == None:
             index = np.argmax(output.cpu().data.numpy())
 
         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
         one_hot[0][index] = 1
         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
-        one_hot = torch.sum(one_hot.cuda() * output)
+        one_hot = torch.sum(one_hot * output)
 
         self.model.zero_grad()
         one_hot.backward(retain_graph=True)
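Dropping the hard-coded `.cuda()` calls lets these methods run on CPU-only machines, but it also assumes `one_hot` and `output` end up on the same device. A hedged sketch of a fully device-agnostic alternative (the helper name `target_score` is hypothetical, not part of the repository):

    import torch

    def target_score(output: torch.Tensor, index: int) -> torch.Tensor:
        # Build the one-hot mask directly on output's device and dtype,
        # so the same code runs unchanged on CPU, CUDA, or MPS.
        one_hot = torch.zeros_like(output)
        one_hot[0, index] = 1
        return torch.sum(one_hot * output)

With a helper like this, the round-trip through NumPy becomes unnecessary; the gradients these explanation methods consume flow through `output` into the hooked layers, not through the mask itself.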
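Taken together, both changed methods follow the same recipe: forward pass, one-hot for the chosen class, reduce to a scalar, then backpropagate with retain_graph=True so gradients captured by the registered hooks can be read out afterwards. A self-contained sketch of that recipe with a stand-in classifier (the nn.Sequential model below is hypothetical, not the repository's ViT):

    import numpy as np
    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the ViT classifier.
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    model.eval()

    input = torch.randn(1, 3, 8, 8)
    output = model(input)

    index = np.argmax(output.cpu().data.numpy())  # predicted class
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0][index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot * output)

    model.zero_grad()
    # Scalar backward; in the repository, hooks on the attention layers
    # capture their gradients at this point.
    one_hot.backward(retain_graph=True)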