# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .model import Aggregator
from cat_seg.third_party import clip
from cat_seg.third_party import imagenet_templates

import numpy as np
import open_clip
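
# CATSegPredictor computes CLIP text embeddings for the class vocabulary and
# feeds them, together with per-pixel image features and visual guidance maps,
# into the CAT-Seg cost-aggregation Aggregator to produce the segmentation
# prediction.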


class CATSegPredictor(nn.Module):
    @configurable
    def __init__(
        self,
        *,
        train_class_json: str,
        test_class_json: str,
        clip_pretrained: str,
        prompt_ensemble_type: str,
        text_guidance_dim: int,
        text_guidance_proj_dim: int,
        appearance_guidance_dim: int,
        appearance_guidance_proj_dim: int,
        prompt_depth: int,
        prompt_length: int,
        decoder_dims: list,
        decoder_guidance_dims: list,
        decoder_guidance_proj_dims: list,
        num_heads: int,
        num_layers: tuple,
        hidden_dims: tuple,
        pooling_sizes: tuple,
        feature_resolution: tuple,
        window_sizes: tuple,
        attention_type: str,
    ):
        """
        Args:
        """
        super().__init__()

        import json
        # use class_texts in train_forward, and test_class_texts in test_forward
        #with open(train_class_json, 'r') as f_in:
        #    self.class_texts = json.load(f_in)
        #with open(test_class_json, 'r') as f_in:
        #    self.test_class_texts = json.load(f_in)
        #assert self.class_texts != None
        #if self.test_class_texts == None:
        #    self.test_class_texts = self.class_texts

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.tokenizer = None
        if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
            # for OpenCLIP models
            name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                name,
                pretrained=pretrain,
                device=device,
                force_image_size=336,
            )

            self.tokenizer = open_clip.get_tokenizer(name)
        else:
            # for OpenAI models
            clip_model, clip_preprocess = clip.load(clip_pretrained, device=device, jit=False, prompt_depth=prompt_depth, prompt_length=prompt_length)
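        # NOTE: prompt_depth / prompt_length are consumed only by the modified
        # OpenAI CLIP bundled in cat_seg.third_party (they appear to configure
        # its learnable prompt tokens); the OpenCLIP branch above does not use
        # them.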

        self.prompt_ensemble_type = prompt_ensemble_type

        if self.prompt_ensemble_type == "imagenet_select":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
        elif self.prompt_ensemble_type == "imagenet":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
        elif self.prompt_ensemble_type == "single":
            prompt_templates = ['A photo of a {} in the scene',]
        else:
            raise NotImplementedError

        #self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
        #self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
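        # NOTE: with the two lines above commented out, self.text_features and
        # self.text_features_test must be assigned by the caller before
        # forward() is used; see the usage sketch at the bottom of this file.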

        self.clip_model = clip_model.float()
        self.clip_preprocess = clip_preprocess
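
        # The Aggregator is CAT-Seg's cost-aggregation decoder: it correlates
        # image and text embeddings into a cost volume and refines it with
        # spatial and class-wise attention, guided by the projected text and
        # appearance (visual) features.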
        transformer = Aggregator(
            text_guidance_dim=text_guidance_dim,
            text_guidance_proj_dim=text_guidance_proj_dim,
            appearance_guidance_dim=appearance_guidance_dim,
            appearance_guidance_proj_dim=appearance_guidance_proj_dim,
            decoder_dims=decoder_dims,
            decoder_guidance_dims=decoder_guidance_dims,
            decoder_guidance_proj_dims=decoder_guidance_proj_dims,
            num_layers=num_layers,
            nheads=num_heads,
            hidden_dim=hidden_dims,
            pooling_size=pooling_sizes,
            feature_resolution=feature_resolution,
            window_size=window_sizes,
            attention_type=attention_type,
        )
        self.transformer = transformer

    @classmethod
    def from_config(cls, cfg):  #, in_channels, mask_classification):
        ret = {}

        ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
        ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
        ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
        ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE

        # Aggregator parameters:
        ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM
        ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM
        ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM
        ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM

        ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
        ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS
        ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS

        ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
        ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH

        ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
        ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
        ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
        ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
        ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
        ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
        ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE

        return ret
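
    # forward: `x` carries the per-pixel CLIP image embeddings that the
    # Aggregator correlates with the class text embeddings; `vis_affinity` is
    # a dict of backbone feature maps used as appearance guidance, reversed
    # here (presumably coarsest level first) before being handed to the
    # Aggregator.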
    def forward(self, x, vis_affinity):
        vis = [vis_affinity[k] for k in vis_affinity.keys()][::-1]
        text = self.text_features if self.training else self.text_features_test
        text = text.repeat(x.shape[0], 1, 1, 1)
        out = self.transformer(x, text, vis)
        return out
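
    # class_embeddings: builds prompt-ensembled CLIP text embeddings. Every
    # class name (comma-separated synonyms are expanded) is formatted into each
    # template, encoded, and L2-normalised; synonym embeddings are averaged
    # back to one vector per template and re-normalised. The result is stacked
    # to shape (num_templates, num_classes, embed_dim).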
    def class_embeddings(self, classnames, templates, clip_model):
        zeroshot_weights = []
        for classname in classnames:
            if ', ' in classname:
                classname_splits = classname.split(', ')
                texts = []
                for template in templates:
                    for cls_split in classname_splits:
                        texts.append(template.format(cls_split))
            else:
                texts = [template.format(classname) for template in templates]  # format with class
            if self.tokenizer is not None:
                texts = self.tokenizer(texts).to(self.device)
            else:
                texts = clip.tokenize(texts).to(self.device)
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            if len(templates) != class_embeddings.shape[0]:
                class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
        return zeroshot_weights
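

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the upstream file): since the JSON
# loading and text-feature computation in __init__ are commented out, the
# calling application is expected to attach the class vocabulary and its CLIP
# text features before running inference, roughly:
#
#     predictor = CATSegPredictor(cfg)  # via the detectron2 @configurable path
#     predictor.test_class_texts = ["cat", "dog", "grass"]  # hypothetical classes
#     predictor.text_features_test = predictor.class_embeddings(
#         predictor.test_class_texts,
#         imagenet_templates.IMAGENET_TEMPLATES_SELECT,
#         predictor.clip_model,
#     ).permute(1, 0, 2).float()
#     predictor.text_features = predictor.text_features_test
#
# The permute(1, 0, 2) mirrors the commented-out lines in __init__, giving
# text features of shape (num_classes, num_templates, embed_dim) as expected
# by the repeat(...) in forward().
# ---------------------------------------------------------------------------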