Spaces:
Sleeping
Sleeping
| print("Importing standard...") | |
| from abc import ABC, abstractmethod | |
| print("Importing external...") | |
| import torch | |
| from torch.nn.functional import binary_cross_entropy | |
| # from matplotlib import pyplot as plt | |
| print("Importing internal...") | |
| from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou | |
| ######### BINARY LOSSES ############### | |
def my_lovasz_hinge(logits, gt, downsample=False):
    """Compute the per-image Lovasz hinge loss, averaged over the batch.

    Args:
        logits: Tensor of shape (B, HW) holding signed scores.
        gt: Tensor of shape (B, HW) holding binary ground-truth labels.
        downsample: Falsy to use every position; otherwise an integer stride.
            A random starting offset in ``[0, downsample - 1]`` is drawn and
            only every ``downsample``-th position along dim 1 is kept.

    Returns:
        Scalar tensor: the mean of the per-image losses, ignoring NaN entries.
    """
    if downsample:
        stride = int(downsample)
        # The upper bound of torch.randint is exclusive, so passing ``stride``
        # covers every phase 0..stride-1 and stays valid for a stride of 1.
        offset = int(torch.randint(stride, (1,)))
        logits, gt = logits[:, offset::stride], gt[:, offset::stride]
    # B, HW
    gt = 1.0 * gt  # promote bool/int labels to float
    areas = gt.sum(dim=1, keepdim=True)  # B, 1
    # Equivalent to the reference implementation with per_image=True and
    # ignore=None.
    signs = 2 * gt - 1  # labels {0, 1} -> {-1, +1}
    errors = 1 - logits * signs  # hinge margins
    errors_sorted, perm = torch.sort(errors, dim=1, descending=True)
    gt_sorted = torch.gather(gt, 1, perm)  # B, HW
    # Lovasz gradient: discrete derivative of the Jaccard loss taken along
    # the sorted errors.
    intersection = areas - gt_sorted.cumsum(dim=1)  # B, HW
    union = areas + (1 - gt_sorted).cumsum(dim=1)  # B, HW
    jaccard = 1 - intersection / union  # B, HW
    # The right-hand side is materialised before the assignment, so the
    # overlapping slices do not interfere.
    jaccard[:, 1:] = jaccard[:, 1:] - jaccard[:, :-1]
    loss = (torch.relu(errors_sorted) * jaccard).sum(dim=1)  # B,
    return torch.nanmean(loss)
def focal_loss(scores, targets, alpha=0.25, gamma=2):
    """Return the element-wise focal loss for probability scores.

    Args:
        scores: Tensor of probabilities (already activated, not logits);
            passed directly to ``binary_cross_entropy``.
        targets: Tensor of the same shape as ``scores`` with binary labels.
        alpha: Weight applied to positive targets (``1 - alpha`` for
            negatives). The weighting is skipped unless ``alpha >= 0``.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.

    Returns:
        Tensor with the same shape as ``scores``; no reduction is applied.
    """
    bce = binary_cross_entropy(scores, targets, reduction="none")
    # Probability the model assigned to the true class of each element.
    prob_of_truth = scores * targets + (1 - scores) * (1 - targets)
    modulated = bce * ((1 - prob_of_truth) ** gamma)
    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        modulated = class_weight * modulated
    return modulated
| # also binary_cross_entropy and lovasz | |
| ########## SUBFUNCTIONS ###################### | |
def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    """Return scaled distances between every position and each reference.

    Args:
        features: Tensor of shape (B, 1, F, H*W).
        refs: Tensor of shape (B, M, F, 1) with the reference features.
        sigma: Scale; a number or a tensor broadcastable against the
            (B, M, 1, H*W) distances.
        norm_p: Order of the norm taken over the feature dimension.
        square_distances: If true, the norm is squared before scaling.
        H, W: Spatial size; only their product is used, for the reshape.

    Returns:
        Tensor of shape (B, M, H*W) holding ``d / (2 * sigma**2)``.
    """
    batch, n_refs = refs.shape[:2]
    # Broadcasting pairs every reference with every position.
    delta = features - refs
    dist = torch.norm(delta, p=norm_p, dim=2, keepdim=True)  # B, M, 1, H*W
    if square_distances:
        dist = dist**2
    scaled = dist / (2 * sigma**2)
    return scaled.reshape(batch, n_refs, H * W)
def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
    """Prepare raw network features for the distance-based losses.

    Args:
        features: Raw features. With ``masks`` given they are passed to
            ``preprocess_masks_features``; with ``masks=None`` (inference) a
            2-D tensor is expected whose smaller dimension is the feature
            dimension and whose larger one is ``H*W`` of a square image.
        masks: Ground-truth masks, or ``None`` when inferencing.
        activation: ``"sigmoid"`` or ``"symlog"``.
        use_sigma: If true, the last feature channel is split off and passed
            through softplus to serve as ``sigma``; otherwise ``sigma`` is 1.
        offset_pos: If true, row/column coordinates from ``get_row_col`` are
            added to the first two feature channels.
        ret_prediction: If true, also build a (B, 1, F, H, W) view of the
            activated features (only returned when ``masks`` is given).

    Returns:
        With ``masks=None``: ``(features, sigma, H, W)`` where ``features``
        has shape (F, H*W).
        Otherwise: ``(features, masks, sigma, prediction, B, M, F, H, W)``
        with ``features`` of shape (B, 1, F, H*W).
    """
    # sigmoid is very similar to exp
    # prepare features
    assert activation in ["sigmoid", "symlog"]
    if masks is None:  # when inferencing
        B, M = 1, 1
        # The smaller dimension is taken as F, the larger one as H*W.
        F, N = sorted(features.shape)
        H, W = [int(N ** (0.5))] * 2  # assumes a square image
        features = features.reshape(1, 1, -1, H * W)
    else:
        masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    if use_sigma:
        sigma = torch.nn.functional.softplus(features)[:, :, -1:]  # B, 1, 1, H*W
        features = features[:, :, :-1]
        F = features.shape[2]
    else:
        sigma = 1
    features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
    if offset_pos:
        assert F >= 2
        row, col = get_row_col(H, W, features.device)
        row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        positional_features = torch.cat([row, col], dim=2)  # B, 1, 2, H*W
        # Build a new tensor instead of writing into ``features``: the
        # activation output may be saved for backward (torch.sigmoid does
        # this), and an in-place write would invalidate it for autograd.
        shifted = (features[:, :, :2] + positional_features).to(features.dtype)
        features = torch.cat([shifted, features[:, :, 2:]], dim=2)
    prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
    if masks is None:
        features = features.reshape(-1, H * W)
        sigma = sigma.reshape(-1, H * W) if use_sigma else 1
        return features, sigma, H, W
    return features, masks, sigma, prediction, B, M, F, H, W
class AbstractLoss(ABC):
    """Interface implemented by the loss classes of this module.

    Both methods are static: they take no ``self`` and are meant to be
    looked up on the class itself. Concrete subclasses must override both.
    """

    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        """Compute the training loss of ``features`` against ``masks``.

        The base implementation does nothing and returns ``None``.
        """

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        """Predict a mask from ``features`` for the query index ``sindex``.

        The base implementation does nothing and returns ``None``.
        """
class IISLoss(AbstractLoss):
    """Loss built on feature distances to pixels sampled inside each mask.

    Both methods are static so they behave the same whether they are looked
    up on the class or on an instance.
    """

    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        """Return the focal + Lovasz hinge loss for ``K`` samples per mask.

        Args:
            features: Raw features, forwarded to ``activate`` together with
                ``masks`` (symlog activation, no sigma, no positional offset).
            masks: Ground-truth masks; used as boolean indices after
                preprocessing.
            ret_prediction: Forwarded to ``activate``; controls whether
                ``prediction`` is a tensor or ``None``.
            K: Number of pixels sampled inside every mask. Every mask must
                contain at least ``K`` pixels, otherwise stacking the
                samples fails.
            logger: Unused.

        Returns:
            Tuple ``(loss, prediction)``.
        """
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        # One shared random permutation of all pixels; for every mask the
        # first K permuted pixels that fall inside it become its samples.
        rindices = torch.randperm(H * W, device=masks.device)
        # the following should work if all masks have more than K pixels
        sindices = torch.stack(
            [
                torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
                for b in range(B)
            ]
        )  # B, M, K
        # Look up the feature vector at every sampled pixel.
        feats_at_sindices = torch.gather(
            features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
            dim=1,
            index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
        )  # B, M, K, F
        feats_at_sindices = feats_at_sindices.reshape(B, M, K, F, 1)  # B, M, K, F, 1
        dists = get_distances(
            features, feats_at_sindices.reshape(B, M * K, F, 1), sigma, 2, True, H, W
        )
        score = torch.exp(-dists)  # B, M*K, H*W [0, 1]
        # Every sample of a mask shares that mask as its target.
        targets = (
            masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
        )  # B, M*K, H*W
        floss = focal_loss(score, targets).mean()
        # Map the [0, 1] scores to [-1, 1] so they act as signed logits.
        lloss = my_lovasz_hinge(
            score.view(B * M * K, H * W) * 2 - 1,
            targets.view(B * M * K, H * W),
        )
        loss = floss + lloss
        return loss, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        """Predict the mask containing the query pixel ``sindex``.

        Args:
            features: Raw features of a single image, forwarded to
                ``activate`` in inference mode (``masks=None``).
            sindex: Index of the query pixel along the flattened H*W axis.

        Returns:
            Boolean tensor of shape (1, 1, H*W), true where
            ``exp(-distance)`` to the query feature exceeds 0.5.
        """
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        F = features.shape[0]
        query_feat = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, F, H * W),
            query_feat.reshape(1, 1, F, 1),
            1,
            2,
            True,
            H,
            W,
        )
        score = torch.exp(-dists)  # 1, 1, H*W
        pred = score > 0.5
        return pred
def iis_iou(features, masks, get_mask_from_query, K=20):
    """Average the IoU of query-based predictions over sampled pixels.

    For every mask, ``K`` pixels inside it are sampled; each one is used as a
    query for ``get_mask_from_query`` and the prediction is scored against
    that mask with ``calculate_iou``.

    Args:
        features: Raw features, preprocessed together with ``masks``.
        masks: Ground-truth masks; every mask must contain at least ``K``
            pixels, otherwise stacking the samples fails.
        get_mask_from_query: Callable ``(features_of_one_image, sindex)``
            returning a predicted mask.
        K: Number of query pixels sampled per mask.

    Returns:
        The sum of all IoU values divided by their count (B * M * K).
    """
    masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    shuffled = torch.randperm(H * W).to(masks.device)
    # For each mask keep the first K shuffled pixels lying inside it.
    per_image = []
    for b in range(B):
        per_mask = [shuffled[masks[b, m, 0, shuffled]][:K] for m in range(M)]
        per_image.append(torch.stack(per_mask))
    sindices = torch.stack(per_image)  # B, M, K
    ious = [
        calculate_iou(
            get_mask_from_query(features[b, 0], sindices[b, m, k]),
            masks[b, m, 0, :],
        )
        for b in range(B)
        for m in range(M)
        for k in range(K)
    ]
    return sum(ious) / len(ious)
# Names accepted by get_loss_class.
losses_names = ["iis"]
| # | |
def get_loss_class(loss_name):
    """Return the loss class registered under ``loss_name``.

    Args:
        loss_name: Name of the loss; currently only ``"iis"`` is known.

    Returns:
        The matching loss class (not an instance).

    Raises:
        NotImplementedError: If ``loss_name`` is not a known loss.
    """
    if loss_name == "iis":
        return IISLoss
    # Name the rejected value so a typo in a config is easy to spot.
    raise NotImplementedError(f"Unknown loss name: {loss_name!r}")
def get_get_mask_from_query(loss_name):
    """Return the ``get_mask_from_query`` function of the named loss."""
    return get_loss_class(loss_name).get_mask_from_query
def get_loss(loss_name):
    """Return the ``loss`` function of the named loss."""
    return get_loss_class(loss_name).loss