""" Modified from DETR https://github.com/facebookresearch/detr """ from typing import List import torch import torch.nn as nn import torch.nn.functional as F from .utils import MLP from torch import Tensor from einops import rearrange, repeat class FPNSpatialDecoder(nn.Module): """ An FPN-like spatial decoder. Generates high-res, semantically rich features which serve as the base for creating instance segmentation masks. """ def __init__(self, context_dim, fpn_dims, mask_kernels_dim=8): super().__init__() # from low to high inter_dims = [context_dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16] self.lay1 = torch.nn.Conv2d(context_dim, inter_dims[0], 3, padding=1) self.gn1 = torch.nn.GroupNorm(8, inter_dims[0]) self.lay2 = torch.nn.Conv2d(inter_dims[0], inter_dims[1], 3, padding=1) self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) self.context_dim = context_dim self.mask_dim = mask_kernels_dim self.add_extra_layer = len(fpn_dims) == 3 if self.add_extra_layer: self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) self.out_lay = torch.nn.Conv2d(inter_dims[4], mask_kernels_dim, 3, padding=1) else: self.out_lay = torch.nn.Conv2d(inter_dims[3], mask_kernels_dim, 3, padding=1) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_uniform_(m.weight, a=1) nn.init.constant_(m.bias, 0) def forward(self, x: Tensor, layer_features: List[Tensor]): x = self.lay1(x) x = self.gn1(x) x = F.relu(x) x = self.lay2(x) x = self.gn2(x) x = F.relu(x) cur_fpn = self.adapter1(layer_features[0]) x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") x = self.lay3(x) x = self.gn3(x) x = F.relu(x) cur_fpn = self.adapter2(layer_features[1]) x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") x = self.lay4(x) x = self.gn4(x) x = F.relu(x) if self.add_extra_layer: cur_fpn = self.adapter3(layer_features[2]) x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") x = self.lay5(x) x = self.gn5(x) x = F.relu(x) x = self.out_lay(x) return x def num_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) class DynamicSegmentationHead(nn.Module): def __init__(self, hidden_dim, config): super().__init__() self.mask_out_stride = 4 self.mask_feat_stride = 4 n_layers = config.controller_layers in_channels = config.mask_dim self.dynamic_mask_channels = config.dynamic_mask_channels self.rel_coord = config.rel_coord weight_nums, bias_nums = [], [] for l in range(n_layers): if l == 0: if self.rel_coord: weight_nums.append((in_channels + 2) * self.dynamic_mask_channels) else: weight_nums.append(in_channels * self.dynamic_mask_channels) bias_nums.append(self.dynamic_mask_channels) elif l == n_layers - 1: weight_nums.append(self.dynamic_mask_channels * 1) # output layer c -> 1 bias_nums.append(1) else: weight_nums.append(self.dynamic_mask_channels * self.dynamic_mask_channels) bias_nums.append(self.dynamic_mask_channels) self.weight_nums = weight_nums self.bias_nums = bias_nums num_gen_params = sum(weight_nums) + sum(bias_nums) 
class DynamicSegmentationHead(nn.Module):
    """Generates per-query dynamic conv parameters and applies them to the mask features
    (optionally concatenated with relative coordinates) to produce instance masks."""
    def __init__(self, hidden_dim, config):
        super().__init__()
        self.mask_out_stride = 4
        self.mask_feat_stride = 4

        n_layers = config.controller_layers
        in_channels = config.mask_dim
        self.dynamic_mask_channels = config.dynamic_mask_channels
        self.rel_coord = config.rel_coord

        # number of dynamic conv weights / biases generated per query, layer by layer
        weight_nums, bias_nums = [], []
        for l in range(n_layers):
            if l == 0:
                if self.rel_coord:
                    weight_nums.append((in_channels + 2) * self.dynamic_mask_channels)
                else:
                    weight_nums.append(in_channels * self.dynamic_mask_channels)
                bias_nums.append(self.dynamic_mask_channels)
            elif l == n_layers - 1:
                weight_nums.append(self.dynamic_mask_channels * 1)  # output layer c -> 1
                bias_nums.append(1)
            else:
                weight_nums.append(self.dynamic_mask_channels * self.dynamic_mask_channels)
                bias_nums.append(self.dynamic_mask_channels)

        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        num_gen_params = sum(weight_nums) + sum(bias_nums)

        self.controller = MLP(hidden_dim, hidden_dim, num_gen_params, 1)
        for layer in self.controller.layers:
            if isinstance(layer, nn.Linear):
                nn.init.zeros_(layer.bias)
                nn.init.xavier_uniform_(layer.weight)

    def forward(self, hs, mask_features, references, targets):
        T = len(targets[0]['frames_idx'])
        mask_features = rearrange(mask_features, '(b t) c h w -> b t c h w', t=T)

        dynamic_mask_head_params = self.controller(hs)  # [B*T, Q, num_params]
        dynamic_mask_head_params = rearrange(dynamic_mask_head_params, '(b t) q n -> b (t q) n', t=T)
        lvl_references = references[..., :2]
        lvl_references = rearrange(lvl_references, '(b t) q n -> b (t q) n', t=T)
        outputs_seg_mask = self.dynamic_mask_with_coords(mask_features, dynamic_mask_head_params, lvl_references, targets)
        outputs_seg_mask = rearrange(outputs_seg_mask, 'b (t q) h w -> (b t) q h w', t=T)
        return outputs_seg_mask

    def dynamic_mask_with_coords(self, mask_features, mask_head_params, reference_points, targets):
        """
        Add the relative coordinates to the mask_features channel dimension and perform the dynamic mask conv.

        Args:
            mask_features: [batch_size, time, c, h, w]
            mask_head_params: [batch_size, time * num_queries_per_frame, num_params]
            reference_points: [batch_size, time * num_queries_per_frame, 2], cxcy
            targets (list[dict]): length is batch size; the key 'size' is needed for computing locations.
        Returns:
            outputs_seg_mask: [batch_size, time * num_queries_per_frame, h, w]
        """
        device = mask_features.device
        b, t, c, h, w = mask_features.shape
        # this is the total query number in all frames
        _, num_queries = reference_points.shape[:2]
        q = num_queries // t  # num_queries_per_frame

        # prepare reference points in image size (the size is the input size to the model)
        new_reference_points = []
        for i in range(b):
            img_h, img_w = targets[i]['size']
            scale_f = torch.stack([img_w, img_h], dim=0)
            tmp_reference_points = reference_points[i] * scale_f[None, :]
            new_reference_points.append(tmp_reference_points)
        new_reference_points = torch.stack(new_reference_points, dim=0)
        # [batch_size, time * num_queries_per_frame, 2], in image size
        reference_points = new_reference_points

        # prepare the mask features
        if self.rel_coord:
            reference_points = rearrange(reference_points, 'b (t q) n -> b t q n', t=t, q=q)
            locations = compute_locations(h, w, device=device, stride=self.mask_feat_stride)
            relative_coords = reference_points.reshape(b, t, q, 1, 1, 2) - \
                              locations.reshape(1, 1, 1, h, w, 2)
            # [batch_size, time, num_queries_per_frame, h, w, 2]
            relative_coords = relative_coords.permute(0, 1, 2, 5, 3, 4)
            # [batch_size, time, num_queries_per_frame, 2, h, w]

            # concat features
            mask_features = repeat(mask_features, 'b t c h w -> b t q c h w', q=q)
            # [batch_size, time, num_queries_per_frame, c, h, w]
            mask_features = torch.cat([mask_features, relative_coords], dim=3)
        else:
            mask_features = repeat(mask_features, 'b t c h w -> b t q c h w', q=q)
            # [batch_size, time, num_queries_per_frame, c, h, w]
        mask_features = mask_features.reshape(1, -1, h, w)

        # parse dynamic params
        mask_head_params = mask_head_params.flatten(0, 1)
        weights, biases = parse_dynamic_params(
            mask_head_params, self.dynamic_mask_channels,
            self.weight_nums, self.bias_nums
        )

        # dynamic mask conv
        mask_logits = self.mask_heads_forward(mask_features, weights, biases, mask_head_params.shape[0])
        mask_logits = mask_logits.reshape(-1, 1, h, w)

        # upsample predicted masks
        assert self.mask_feat_stride >= self.mask_out_stride
        assert self.mask_feat_stride % self.mask_out_stride == 0
        mask_logits = aligned_bilinear(mask_logits, int(self.mask_feat_stride / self.mask_out_stride))

        mask_logits = mask_logits.reshape(b, num_queries, mask_logits.shape[-2], mask_logits.shape[-1])
        return mask_logits  # [batch_size, time * num_queries_per_frame, h, w]

    def mask_heads_forward(self, features, weights, biases, num_insts):
        '''
        :param features: flattened mask features, shape [1, num_insts * in_channels, h, w]
        :param weights: list of per-layer conv weights [w0, w1, ...]
        :param biases: list of per-layer conv biases [b0, b1, ...]
        :param num_insts: number of instances, used as the group count of the grouped conv
        :return: mask logits
        '''
        assert features.dim() == 4
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, w, bias=b,
                stride=1, padding=0,
                groups=num_insts
            )
            if i < n_layers - 1:
                x = F.relu(x)
        return x
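# --- Worked example of the generated parameter count (illustrative only) ---
# The config values controller_layers=3, mask_dim=8, dynamic_mask_channels=8 and
# rel_coord=True are assumptions, not values read from this repository. With them
# the controller generates three 1x1 conv layers per query:
#   layer 0: (8 + 2) * 8 = 80 weights,  8 biases
#   layer 1:  8 * 8      = 64 weights,  8 biases
#   layer 2:  8 * 1      =  8 weights,  1 bias
# so num_gen_params = (80 + 64 + 8) + (8 + 8 + 1) = 169, and the controller MLP
# maps each query embedding to a 169-dimensional vector of conv parameters.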
def dice_loss(inputs, targets, num_masks, valid=None):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_masks: Normalization factor (number of masks in the batch).
        valid: (optional) boolean tensor; positions where it is True are zeroed
               in the sigmoid predictions before the loss is computed.
    """
    inputs = inputs.sigmoid()
    if valid is not None:
        inputs = inputs.masked_fill(valid, 0)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, valid=None):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normalization factor (number of boxes in the batch).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default = 0.25; a negative value disables weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        valid: (optional) boolean tensor; positions where it is True are excluded
               from the per-example average.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if valid is None:
        loss = loss.mean(1)
    else:
        loss = (loss.masked_fill(valid, 0)).sum(1) / (~valid).sum(1)
    return loss.sum() / num_boxes


def sigmoid_focal_loss_refer(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default = 0.25; a negative value disables weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum()
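# --- Usage sketch for the mask losses (illustrative, not part of the original code) ---
# The shapes below are assumptions: both losses reduce over dim 1, so predictions
# and targets are expected to be flattened to [num_masks, num_pixels] by the caller.
#
#   src_masks = torch.randn(3, 64 * 64)                     # predicted mask logits
#   tgt_masks = torch.randint(0, 2, (3, 64 * 64)).float()   # binary ground-truth masks
#   loss_dice = dice_loss(src_masks, tgt_masks, num_masks=3)
#   loss_focal = sigmoid_focal_loss(src_masks, tgt_masks, num_boxes=3)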
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    # Split the flat per-instance parameter vectors produced by the controller into
    # per-layer conv weights and biases for the dynamic mask head.
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    params_splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))

    weight_splits = params_splits[:num_layers]
    bias_splits = params_splits[num_layers:]

    for l in range(num_layers):
        if l < num_layers - 1:
            # out_channels x in_channels x 1 x 1
            weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
            bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
        else:
            # out_channels x in_channels x 1 x 1
            weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1)
            bias_splits[l] = bias_splits[l].reshape(num_insts)

    return weight_splits, bias_splits


def aligned_bilinear(tensor, factor):
    # Upsample an NCHW tensor by an integer factor with bilinear interpolation,
    # padded so that the output stays aligned with the input pixel centers.
    assert tensor.dim() == 4
    assert factor >= 1
    assert int(factor) == factor

    if factor == 1:
        return tensor

    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh = factor * h + 1
    ow = factor * w + 1
    tensor = F.interpolate(
        tensor, size=(oh, ow),
        mode='bilinear',
        align_corners=True
    )
    tensor = F.pad(
        tensor, pad=(factor // 2, 0, factor // 2, 0),
        mode="replicate"
    )
    return tensor[:, :, :oh - 1, :ow - 1]


def compute_locations(h, w, device, stride=1):
    # Return the (x, y) centers of an h x w feature map in input-image coordinates,
    # given the feature stride; output shape is [h * w, 2].
    shifts_x = torch.arange(
        0, w * stride, step=stride,
        dtype=torch.float32, device=device
    )
    shifts_y = torch.arange(
        0, h * stride, step=stride,
        dtype=torch.float32, device=device
    )
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
    return locations
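# --- Worked example for compute_locations (illustrative only) ---
# With h=2, w=3, stride=4 the function returns the centers of a 2x3 feature map
# in input-image coordinates (each center is offset by stride // 2 = 2):
#
#   compute_locations(2, 3, device='cpu', stride=4)
#   # tensor([[ 2.,  2.],
#   #         [ 6.,  2.],
#   #         [10.,  2.],
#   #         [ 2.,  6.],
#   #         [ 6.,  6.],
#   #         [10.,  6.]])
#
# These are the locations that dynamic_mask_with_coords subtracts from the scaled
# reference points to build the two relative-coordinate channels.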