Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from timm import create_model | |
| class Identity(nn.Module): | |
| def __init__(self): | |
| super(Identity, self).__init__() | |
| def forward(self, x): | |
| return x | |
| class SwinT(nn.Module): | |
| def __init__(self, model_name='swin_base_patch4_window7_224', global_pool='avg', pretrained=True): | |
| super(SwinT, self).__init__() | |
| self.swin_model = create_model( | |
| model_name, pretrained=pretrained, global_pool=global_pool | |
| ) | |
| self.swin_model.head = Identity() # Remove classification head | |
| self.global_pool = global_pool | |
| def forward(self, x): | |
| features = self.swin_model(x) # Shape: (batch_size, 7, 7, 1024) | |
| if self.global_pool == 'avg': | |
| features = features.mean(dim=[1, 2]) # Global pool | |
| return features | |
| def extract_features_swint_pool(video, model, device): | |
| swint_feature_list = [] | |
| with torch.cuda.amp.autocast(): | |
| for segment in video: | |
| # Flatten the segment into a batch of frames | |
| frames = segment.squeeze(0).to(device) # Shape: (32, 3, 224, 224) | |
| swint_features = model(frames) # Shape: (32, feature_dim) | |
| swint_feature_list.append(swint_features) | |
| # Concatenate features across segments | |
| features = torch.cat(swint_feature_list, dim=0) # Shape: (num_frames, feature_dim) | |
| return features | |