Spaces:
Runtime error
Runtime error
| import logging | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| from typing import Union, List | |
| import torch | |
| import torchvision | |
| def check_is_valid_torchvision_architecture(architecture: str): | |
| """Raises an ValueError if architecture is not part of available torchvision models | |
| """ | |
| available = sorted( | |
| name | |
| for name in torchvision.models.__dict__ | |
| if name.islower() | |
| and not name.startswith("__") | |
| and callable(torchvision.models.__dict__[name]) | |
| ) | |
| if architecture not in available: | |
| raise ValueError(f"{architecture} not in {available}") | |
| def build_base_model(arch: str): | |
| model = torchvision.models.__dict__[arch](pretrained=True) | |
| # get input dimension before classification layer | |
| if arch in ["mobilenet_v2"]: | |
| nfeatures = model.classifier[-1].in_features | |
| model = torch.nn.Sequential(*list(model.children())[:-1]) | |
| elif arch in ["densenet121", "densenet161", "densenet169"]: | |
| nfeatures = model.classifier.in_features | |
| model = torch.nn.Sequential(*list(model.children())[:-1]) | |
| elif "resne" in arch: | |
| # usually all ResNet variants | |
| nfeatures = model.fc.in_features | |
| model = torch.nn.Sequential(*list(model.children())[:-2]) | |
| else: | |
| raise NotImplementedError | |
| model.avgpool = torch.nn.AdaptiveAvgPool2d(1) | |
| model.flatten = torch.nn.Flatten(start_dim=1) | |
| return model, nfeatures | |
| def load_weights_if_available( | |
| model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path] | |
| ): | |
| checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage) | |
| state_dict_features = OrderedDict() | |
| state_dict_classifier = OrderedDict() | |
| for k, w in checkpoint["state_dict"].items(): | |
| if k.startswith("model"): | |
| state_dict_features[k.replace("model.", "")] = w | |
| elif k.startswith("classifier"): | |
| state_dict_classifier[k.replace("classifier.", "")] = w | |
| else: | |
| logging.warning(f"Unexpected prefix in state_dict: {k}") | |
| model.load_state_dict(state_dict_features, strict=True) | |
| return model, classifier | |
| def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt): | |
| R = 6371 | |
| factor_rad = 0.01745329252 | |
| longitudes = factor_rad * longitudes | |
| longitudes_gt = factor_rad * longitudes_gt | |
| latitudes = factor_rad * latitudes | |
| latitudes_gt = factor_rad * latitudes_gt | |
| delta_long = longitudes_gt - longitudes | |
| delta_lat = latitudes_gt - latitudes | |
| subterm0 = torch.sin(delta_lat / 2) ** 2 | |
| subterm1 = torch.cos(latitudes) * torch.cos(latitudes_gt) | |
| subterm2 = torch.sin(delta_long / 2) ** 2 | |
| subterm1 = subterm1 * subterm2 | |
| a = subterm0 + subterm1 | |
| c = 2 * torch.asin(torch.sqrt(a)) | |
| gcd = R * c | |
| return gcd | |
| def gcd_threshold_eval(gc_dists, thresholds=[1, 25, 200, 750, 2500]): | |
| # calculate accuracy for given gcd thresolds | |
| results = {} | |
| for thres in thresholds: | |
| results[thres] = torch.true_divide( | |
| torch.sum(gc_dists <= thres), len(gc_dists) | |
| ).item() | |
| return results | |
| def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)): | |
| def _accuracy(output, target, topk=(1,)): | |
| """Computes the accuracy over the k top predictions for the specified values of k""" | |
| with torch.no_grad(): | |
| maxk = max(topk) | |
| batch_size = target.size(0) | |
| _, pred = output.topk(maxk, 1, True, True) | |
| pred = pred.t() | |
| correct = pred.eq(target.view(1, -1).expand_as(pred)) | |
| res = {} | |
| for k in topk: | |
| correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | |
| res[k] = correct_k / batch_size | |
| return res | |
| with torch.no_grad(): | |
| out_dict = {} | |
| for i, pname in enumerate(partitioning_shortnames): | |
| res_dict = _accuracy(output[i], target[i], topk=topk) | |
| for k, v in res_dict.items(): | |
| out_dict[f"acc{k}_val/{pname}"] = v | |
| return out_dict | |
| def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None): | |
| gcd_dict = {} | |
| metric_names = [f"gcd_{p}_val" for p in pnames] | |
| if hierarchy is not None: | |
| metric_names.append("gcd_hierarchy_val") | |
| for metric_name in metric_names: | |
| distances_flat = [output[metric_name] for output in outputs] | |
| distances_flat = torch.cat(distances_flat, dim=0) | |
| gcd_results = gcd_threshold_eval(distances_flat) | |
| for gcd_thres, acc in gcd_results.items(): | |
| gcd_dict[f"{metric_name}/{gcd_thres}"] = acc | |
| return gcd_dict | |
| def summarize_test_gcd(pnames, outputs, hierarchy=None): | |
| def _eval(output): | |
| # calculate acc@km for a list of given thresholds | |
| accuracy_outputs = {} | |
| if hierarchy is not None: | |
| pnames.append("hierarchy") | |
| for pname in pnames: | |
| # concat batches of distances | |
| distances_flat = torch.cat([x[pname] for x in output], dim=0) | |
| # acc for all distances | |
| acc_dict = gcd_threshold_eval(distances_flat) | |
| accuracy_outputs[f"acc_test/{pname}"] = acc_dict | |
| return accuracy_outputs | |
| result = {} | |
| if isinstance(outputs[0], dict): # only one testset | |
| result = _eval(outputs) | |
| elif isinstance(outputs[0], list): # multiple testsets | |
| for testset_index, output in enumerate(outputs): | |
| result[testset_index] = _eval(output) | |
| else: | |
| raise TypeError | |
| return result | |
| def summarize_loss_acc_stats(pnames: List[str], outputs, topk=[1, 5, 10]): | |
| loss_acc_dict = {} | |
| metric_names = [] | |
| for k in topk: | |
| accuracy_names = [f"acc{k}_val/{p}" for p in pnames] | |
| metric_names.extend(accuracy_names) | |
| metric_names.extend([f"loss_val/{p}" for p in pnames]) | |
| for metric_name in ["loss_val/total", *metric_names]: | |
| metric_total = 0 | |
| for output in outputs: | |
| metric_value = output[metric_name] | |
| metric_total += metric_value | |
| loss_acc_dict[metric_name] = metric_total / len(outputs) | |
| return loss_acc_dict |