# comparative-explainability/Transformer-Explainability/baselines/ViT/pertubation_eval_from_hdf5.py
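"""Positive/negative perturbation evaluation of ViT saliency maps stored in HDF5.

For each image, pixels are zeroed out in order of decreasing saliency (or
increasing saliency with --neg), and the model's accuracy, logits, and
class-probability margins are re-measured at each perturbation step.
"""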
import argparse

# from models.vgg import vgg19
import glob
import os

import numpy as np
import torch
from dataset.expl_hdf5 import ImagenetResults
from tqdm import tqdm

# Import saliency methods and models
from ViT_explanation_generator import Baselines  # noqa: F401 (not used in this script)
from ViT_new import vit_base_patch16_224


def normalize(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    # In-place channel-wise normalization; note that it mutates its input,
    # so callers pass a clone.
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
    return tensor
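

# NOTE: eval() reads imagenet_ds, sample_loader, model, and device from module
# scope; they are defined in the __main__ block at the bottom of the file.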
@torch.no_grad()  # inference only; the perturbation test needs no gradients
def eval(args):
    num_samples = 0
    num_correct_model = np.zeros(len(imagenet_ds))
    dissimilarity_model = np.zeros(len(imagenet_ds))
    model_index = 0

    if args.scale == "per":
        # Steps are fractions of all 224 * 224 input pixels.
        base_size = 224 * 224
        perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    elif args.scale == "100":
        # Steps are counted in units of 100 pixels.
        base_size = 100
        perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
    else:
        raise ValueError("scale not valid")
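
    # Per-step results: rows index the 9 perturbation steps, columns index
    # samples; columns past perturb_index are trimmed before saving.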
    num_correct_pertub = np.zeros((9, len(imagenet_ds)))
    dissimilarity_pertub = np.zeros((9, len(imagenet_ds)))
    logit_diff_pertub = np.zeros((9, len(imagenet_ds)))
    prob_diff_pertub = np.zeros((9, len(imagenet_ds)))
    perturb_index = 0

    for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)):
        # Update the number of samples
        num_samples += len(data)

        data = data.to(device)
        vis = vis.to(device)
        target = target.to(device)
        norm_data = normalize(data.clone())

        # Compute model accuracy on the unperturbed input
        pred = model(norm_data)
        pred_probabilities = torch.softmax(pred, dim=1)
        pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1)
        pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
        pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1)
        tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
        num_correct_model[model_index : model_index + len(tgt_pred)] = tgt_pred

        # Dissimilarity: log of the target-class probability over the
        # second-highest class probability (a confidence margin).
        probs = torch.softmax(pred, dim=1)
        target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
        second_probs = probs.data.topk(2, dim=1)[0][:, 1]
        temp = torch.log(target_probs / second_probs).data.cpu().numpy()
        dissimilarity_model[model_index : model_index + len(temp)] = temp
        # Advance the whole-batch index here, before --wrong can shrink the batch.
        model_index += len(temp)
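
        # With --wrong, keep only the samples the unperturbed model misclassified.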
        if args.wrong:
            wid = np.argwhere(tgt_pred == 0).flatten()
            if len(wid) == 0:
                continue
            wid = torch.from_numpy(wid).to(vis.device)
            vis = vis.index_select(0, wid)
            data = data.index_select(0, wid)
            target = target.index_select(0, wid)

        # Save original shape
        org_shape = data.shape

        if args.neg:
            vis = -vis

        vis = vis.reshape(org_shape[0], -1)
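
        # At each step, zero out the top-k most salient pixels (the least
        # salient ones when --neg flipped the sign above) in all channels.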
        for i in range(len(perturbation_steps)):
            _data = data.clone()

            _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1)
            idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
            _data = _data.reshape(org_shape[0], org_shape[1], -1)
            _data = _data.scatter_(-1, idx, 0)
            _data = _data.reshape(*org_shape)

            _norm_data = normalize(_data)
            out = model(_norm_data)

            # Probability and logit drops relative to the unperturbed prediction
            pred_probabilities = torch.softmax(out, dim=1)
            pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
            diff = (pred_prob - pred_org_prob).data.cpu().numpy()
            prob_diff_pertub[i, perturb_index : perturb_index + len(diff)] = diff

            pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1)
            diff = (pred_logit - pred_org_logit).data.cpu().numpy()
            logit_diff_pertub[i, perturb_index : perturb_index + len(diff)] = diff

            # Accuracy and confidence margin after perturbation
            target_class = out.data.max(1, keepdim=True)[1].squeeze(1)
            temp = (target == target_class).type(target.type()).data.cpu().numpy()
            num_correct_pertub[i, perturb_index : perturb_index + len(temp)] = temp

            probs_pertub = torch.softmax(out, dim=1)
            target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
            second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
            temp = torch.log(target_probs / second_probs).data.cpu().numpy()
            dissimilarity_pertub[i, perturb_index : perturb_index + len(temp)] = temp

        perturb_index += len(target)
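
    # Persist per-sample results; the perturbation arrays are trimmed to the
    # number of samples actually evaluated (relevant with --wrong).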
    np.save(os.path.join(args.experiment_dir, "model_hits.npy"), num_correct_model)
    np.save(
        os.path.join(args.experiment_dir, "model_dissimilarities.npy"),
        dissimilarity_model,
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_hits.npy"),
        num_correct_pertub[:, :perturb_index],
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_dissimilarities.npy"),
        dissimilarity_pertub[:, :perturb_index],
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_logit_diff.npy"),
        logit_diff_pertub[:, :perturb_index],
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_prob_diff.npy"),
        prob_diff_pertub[:, :perturb_index],
    )

    # Console summary: model accuracy and margin, then per-step statistics.
    print(np.mean(num_correct_model), np.std(num_correct_model))
    print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
    print(perturbation_steps)
    print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1))
    print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))


if __name__ == "__main__":
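    # Example invocation (one plausible configuration; adjust flags as needed):
    #   python pertubation_eval_from_hdf5.py --method transformer_attribution \
    #       --vis-class top --scale per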
    parser = argparse.ArgumentParser(
        description="Perturbation evaluation of ViT explanation methods"
    )
    parser.add_argument("--batch-size", type=int, default=16, help="batch size")
    # NOTE: argparse's type=bool treats any non-empty string as True, so
    # "--neg False" is still truthy; the same applies to --is-ablation below.
    parser.add_argument(
        "--neg",
        type=bool,
        default=True,
        help="negative perturbation: remove the least relevant pixels first",
    )
    # --value is parsed but not referenced anywhere in this script.
    parser.add_argument("--value", action="store_true", default=False, help="")
    parser.add_argument(
        "--scale",
        type=str,
        default="per",
        choices=["per", "100"],
        help="step unit: fraction of all 224x224 pixels ('per') or 100 pixels ('100')",
    )
    # NOTE: the default "grad_rollout" is not among the choices; argparse does
    # not validate defaults, so pass --method explicitly.
    parser.add_argument(
        "--method",
        type=str,
        default="grad_rollout",
        choices=[
            "rollout",
            "lrp",
            "transformer_attribution",
            "full_lrp",
            "v_gradcam",
            "lrp_last_layer",
            "lrp_second_layer",
            "gradcam",
            "attn_last_layer",
            "attn_gradcam",
            "input_grads",
        ],
        help="explanation method whose saliency maps are evaluated",
    )
    parser.add_argument(
        "--vis-class",
        type=str,
        default="top",
        choices=["top", "target", "index"],
        help="class the saliency maps were generated for",
    )
    parser.add_argument(
        "--wrong",
        action="store_true",
        default=False,
        help="evaluate only the samples the model initially misclassified",
    )
    parser.add_argument(
        "--class-id",
        type=int,
        default=0,
        help="class index used when --vis-class is 'index'",
    )
    parser.add_argument("--is-ablation", type=bool, default=False, help="")
    args = parser.parse_args()
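
    # Use 'spawn' so DataLoader worker processes start with a clean interpreter
    # state (fork is unsafe once CUDA has been initialized).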
    torch.multiprocessing.set_start_method("spawn")

    # PATH variables
    PATH = os.path.dirname(os.path.abspath(__file__)) + "/"
    dataset = PATH + "dataset/"
    os.makedirs(os.path.join(PATH, "experiments"), exist_ok=True)
    os.makedirs(os.path.join(PATH, "experiments/perturbations"), exist_ok=True)

    exp_name = args.method
    exp_name += "_neg" if args.neg else "_pos"
    print(exp_name)

    if args.vis_class == "index":
        args.runs_dir = os.path.join(
            PATH,
            "experiments/perturbations/{}/{}_{}".format(
                exp_name, args.vis_class, args.class_id
            ),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        args.runs_dir = os.path.join(
            PATH,
            "experiments/perturbations/{}/{}/{}".format(
                exp_name, args.vis_class, ablation_fold
            ),
        )
        # args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}'.format(
        #     exp_name, args.vis_class))

    if args.wrong:
        args.runs_dir += "_wrong"

    # Sort numerically so that experiment_10 sorts after experiment_9.
    experiments = sorted(
        glob.glob(os.path.join(args.runs_dir, "experiment_*")),
        key=lambda path: int(path.split("_")[-1]),
    )
    experiment_id = int(experiments[-1].split("_")[-1]) + 1 if experiments else 0
    args.experiment_dir = os.path.join(
        args.runs_dir, "experiment_{}".format(str(experiment_id))
    )
    os.makedirs(args.experiment_dir, exist_ok=True)

    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
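
    # Locate the saliency maps generated earlier for this method/class setting.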
    if args.vis_class == "index":
        vis_method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}_{}".format(
                args.method, args.vis_class, args.class_id
            ),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        vis_method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}/{}".format(
                args.method, args.vis_class, ablation_fold
            ),
        )
        # vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}'.format(
        #     args.method, args.vis_class))

    # imagenet_ds = ImagenetResults('visualizations/{}'.format(args.method))
    imagenet_ds = ImagenetResults(vis_method_dir)
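    # imagenet_ds yields (image, saliency map, label) triples, matching the
    # unpacking in eval()'s loop above.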

    # Model
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()

    save_path = PATH + "results/"

    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, num_workers=2, shuffle=False
    )

    eval(args)