# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import re
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Tuple, Union

from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model


class ModelHub:
    """A hub to host the meta information of all pre-defined models."""
    _models_dict = {}
    __mmpretrain_registered = False

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike, None] = None):
        """Parse the model-index file and register all models.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file.
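
        Examples:
            >>> # A hypothetical index file, for illustration only; any file
            >>> # in the OpenMMLab model-index format can be registered.
            >>> ModelHub.register_model_index(
            ...     'my_project/model-index.yml',
            ...     config_prefix='my_project/configs')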
| """ | |
| model_index = load(str(model_index_path)) | |
| model_index.build_models_with_collections() | |
| for metainfo in model_index.models: | |
| model_name = metainfo.name.lower() | |
| if metainfo.name in cls._models_dict: | |
| raise ValueError( | |
| 'The model name {} is conflict in {} and {}.'.format( | |
| model_name, osp.abspath(metainfo.filepath), | |
| osp.abspath(cls._models_dict[model_name].filepath))) | |
| metainfo.config = cls._expand_config_path(metainfo, config_prefix) | |
| cls._models_dict[model_name] = metainfo | |

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        Args:
            model_name (str): The name of the model.

        Returns:
            modelindex.models.Model: The metainfo of the specified model.
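
        Examples:
            >>> # Assumes the mmpretrain package and its model-index are
            >>> # installed; names are matched case-insensitively.
            >>> metainfo = ModelHub.get('resnet50_8xb32_in1k')
            >>> print(metainfo.name)
            resnet50_8xb32_in1k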
| """ | |
| cls._register_mmpretrain_models() | |
| # lazy load config | |
| metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower())) | |
| if metainfo is None: | |
| raise ValueError( | |
| f'Failed to find model "{model_name}". please use ' | |
| '`mmpretrain.list_models` to get all available names.') | |
| if isinstance(metainfo.config, str): | |
| metainfo.config = Config.fromfile(metainfo.config) | |
| return metainfo | |

    @staticmethod
    def _expand_config_path(metainfo: Model,
                            config_prefix: Union[str, PathLike, None] = None):
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))

        return config_path

    @classmethod
    def _register_mmpretrain_models(cls):
        # Register the models shipped with the installed mmpretrain package.
        if not cls.__mmpretrain_registered:
            from importlib_metadata import distribution
            root = distribution('mmpretrain').locate_file('mmpretrain')
            model_index_path = root / '.mim' / 'model-index.yml'
            ModelHub.register_model_index(
                model_index_path, config_prefix=root / '.mim')
            cls.__mmpretrain_registered = True

    @classmethod
    def has(cls, model_name):
        """Whether a model name is in the ModelHub."""
        # Lowercase the query to match the case-insensitive registry keys.
        return model_name.lower() in cls._models_dict


def get_model(model: Union[str, Config],
              pretrained: Union[str, bool] = False,
              device=None,
              device_map=None,
              offload_folder=None,
              url_mapping: Tuple[str, str] = None,
              **kwargs):
    """Get a pre-defined model or create a model from config.

    Args:
        model (str | Config): The name of the model, a config file path or a
            config instance.
        pretrained (bool | str): When using a name to specify the model, you
            can use ``True`` to load the pre-defined pretrained weights, or a
            string to specify the path or link of the weights to load.
            Defaults to False.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to cover every
            parameter/buffer name; once a module name is included, every
            submodule of it will be sent to the same device. You can use
            ``device_map="auto"`` to automatically generate the device map.
            Defaults to None.
        offload_folder (str | None): If the ``device_map`` contains the value
            ``"disk"``, the folder where the weights will be offloaded.
            Defaults to None.
        url_mapping (Tuple[str, str], optional): The mapping of the pretrained
            checkpoint link. For example, load the checkpoint from a local
            directory instead of downloading it, by
            ``('https://.*/', './checkpoint')``. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Examples:
        Get a ResNet-50 model and extract image features:

        >>> import torch
        >>> from mmpretrain import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get a Swin Transformer model with pre-trained weights and run
        inference:

        >>> from mmpretrain import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'
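
        Redirect the checkpoint link to a local directory via ``url_mapping``
        (the local path below is illustrative):

        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True,
        ...                   url_mapping=('https://.*/', './checkpoint/'))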
| """ # noqa: E501 | |
| if device_map is not None: | |
| from .utils import dispatch_model | |
| dispatch_model._verify_require() | |

    metainfo = None
    if isinstance(model, Config):
        config = copy.deepcopy(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
        config = Config.fromfile(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, str):
        metainfo = ModelHub.get(model)
        config = metainfo.config
        if pretrained is True and metainfo.weights is not None:
            pretrained = metainfo.weights
    else:
        raise TypeError('model must be a name, a path or a Config object, '
                        f'but got {type(model)}')

    if pretrained is True:
        warnings.warn('Unable to find pre-defined checkpoint of the model.')
        pretrained = None
    elif pretrained is False:
        pretrained = None

    if kwargs:
        config.merge_from_dict({'model': kwargs})
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))

    from mmengine.registry import DefaultScope

    from mmpretrain.registry import MODELS
    with DefaultScope.overwrite_default_scope('mmpretrain'):
        model = MODELS.build(config.model)

    dataset_meta = {}
    if pretrained:
        # Mapping the weights to GPU may cause an unexpected video memory
        # leak; see https://github.com/open-mmlab/mmdetection/pull/6405
        from mmengine.runner import load_checkpoint
        if url_mapping is not None:
            pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
        checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
        if 'dataset_meta' in checkpoint.get('meta', {}):
            # mmpretrain 1.x checkpoint
            dataset_meta = checkpoint['meta']['dataset_meta']
        elif 'CLASSES' in checkpoint.get('meta', {}):
            # mmcls 0.x checkpoint
            dataset_meta = {'classes': checkpoint['meta']['CLASSES']}

    if len(dataset_meta) == 0 and 'test_dataloader' in config:
        from mmpretrain.registry import DATASETS
        dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
        dataset_meta = getattr(dataset_class, 'METAINFO', {})

    if device_map is not None:
        model = dispatch_model(
            model, device_map=device_map, offload_folder=offload_folder)
    elif device is not None:
        model.to(device)

    model._dataset_meta = dataset_meta  # save the dataset meta
    model._config = config  # save the config in the model
    model._metainfo = metainfo  # save the metainfo in the model
    model.eval()
    return model


def init_model(config, checkpoint=None, device=None, **kwargs):
    """Initialize a classifier from a config file (deprecated).

    It's only kept for compatibility; please use :func:`get_model` instead.

    Args:
        config (str | :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights. Defaults to None.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        nn.Module: The constructed model.
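
    Examples:
        >>> # Equivalent to ``get_model(config, checkpoint, device)``; both
        >>> # paths below are illustrative.
        >>> model = init_model('configs/resnet/resnet50_8xb32_in1k.py',
        ...                    checkpoint='./checkpoint/resnet50.pth')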
| """ | |
| return get_model(config, checkpoint, device, **kwargs) | |


def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
    """List all models available in MMPretrain.

    Args:
        pattern (str | None): A wildcard pattern to match model names.
            Defaults to None.
        exclude_patterns (list | None): A list of wildcard patterns to
            exclude names from the matched names. Defaults to None.
        task (str | None): The evaluation task of the model. Use ``'null'``
            to match models without any evaluation results. Defaults to None.

    Returns:
        List[str]: A list of model names.

    Examples:
        List all models:

        >>> from mmpretrain import list_models
        >>> list_models()

        List ResNet-50 models on the ImageNet-1k dataset:

        >>> from mmpretrain import list_models
        >>> list_models('resnet*in1k')
        ['resnet50_8xb32_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k']

        List Swin-Transformer models trained from scratch, excluding
        Swin-Transformer-V2 models:

        >>> from mmpretrain import list_models
        >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
        ['swin-base_16xb64_in1k',
         'swin-base_3rdparty_in1k',
         'swin-base_3rdparty_in1k-384',
         'swin-large_8xb8_cub-384px',
         'swin-small_16xb64_in1k',
         'swin-small_3rdparty_in1k',
         'swin-tiny_16xb64_in1k',
         'swin-tiny_3rdparty_in1k']

        List all EVA models for the image classification task:

        >>> from mmpretrain import list_models
        >>> list_models('eva', task='Image Classification')
        ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
         'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
         'eva-l-p14_mim-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-pre_3rdparty_in1k-336px']
    """
    ModelHub._register_mmpretrain_models()
    matches = set(ModelHub._models_dict.keys())

    if pattern is not None:
        # Always match keys with any postfix.
        matches = set(fnmatch.filter(matches, pattern + '*'))

    exclude_patterns = exclude_patterns or []
    for exclude_pattern in exclude_patterns:
        exclude = set(fnmatch.filter(matches, exclude_pattern + '*'))
        matches = matches - exclude

    if task is not None:
        task_matches = []
        for key in matches:
            metainfo = ModelHub._models_dict[key]
            if metainfo.results is None and task == 'null':
                task_matches.append(key)
            elif metainfo.results is None:
                continue
            elif task in [result.task for result in metainfo.results]:
                task_matches.append(key)
        matches = task_matches

    return sorted(list(matches))


def inference_model(model, *args, **kwargs):
    """Run inference on an image with an automatically selected inferencer.

    The inferencer is selected according to the task of the model. It's a
    shortcut for a quick start; for advanced usage, please use the
    corresponding inferencer class directly.

    Here is the mapping from task to inferencer:

    - Image Classification: :class:`ImageClassificationInferencer`
    - Image Retrieval: :class:`ImageRetrievalInferencer`
    - Image Caption: :class:`ImageCaptionInferencer`
    - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
    - Visual Grounding: :class:`VisualGroundingInferencer`
    - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
    - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
    - NLVR: :class:`NLVRInferencer`

    Args:
        model (BaseModel | str | Config): The loaded model, the model
            name or the config of the model.
        *args: Positional arguments to call the inferencer.
        **kwargs: Other keyword arguments to initialize and call the
            corresponding inferencer.

    Returns:
        result (dict): The inference results.
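
    Examples:
        >>> # Select the inferencer by model name; the demo image ships with
        >>> # the mmpretrain repository.
        >>> from mmpretrain import inference_model
        >>> result = inference_model('swin-base_16xb64_in1k', 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'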
| """ # noqa: E501 | |
| from mmengine.model import BaseModel | |
| if isinstance(model, BaseModel): | |
| metainfo = getattr(model, '_metainfo', None) | |
| else: | |
| metainfo = ModelHub.get(model) | |
| from inspect import signature | |
| from .image_caption import ImageCaptionInferencer | |
| from .image_classification import ImageClassificationInferencer | |
| from .image_retrieval import ImageRetrievalInferencer | |
| from .multimodal_retrieval import (ImageToTextRetrievalInferencer, | |
| TextToImageRetrievalInferencer) | |
| from .nlvr import NLVRInferencer | |
| from .visual_grounding import VisualGroundingInferencer | |
| from .visual_question_answering import VisualQuestionAnsweringInferencer | |
| task_mapping = { | |
| 'Image Classification': ImageClassificationInferencer, | |
| 'Image Retrieval': ImageRetrievalInferencer, | |
| 'Image Caption': ImageCaptionInferencer, | |
| 'Visual Question Answering': VisualQuestionAnsweringInferencer, | |
| 'Visual Grounding': VisualGroundingInferencer, | |
| 'Text-To-Image Retrieval': TextToImageRetrievalInferencer, | |
| 'Image-To-Text Retrieval': ImageToTextRetrievalInferencer, | |
| 'NLVR': NLVRInferencer, | |
| } | |

    inferencer_type = None

    if metainfo is not None and metainfo.results is not None:
        tasks = set(result.task for result in metainfo.results)
        inferencer_type = [
            task_mapping.get(task) for task in tasks if task in task_mapping
        ]
        if len(inferencer_type) > 1:
            inferencer_names = [cls.__name__ for cls in inferencer_type]
            warnings.warn('The model supports multiple tasks; auto-selected '
                          f'{inferencer_names[0]}. You can also use the other '
                          f'inferencers {inferencer_names} directly.')
        # Fall through to the error below when no supported task matched.
        inferencer_type = inferencer_type[0] if inferencer_type else None

    if inferencer_type is None:
        raise NotImplementedError('No available inferencer for the model')

    init_kwargs = {
        k: kwargs.pop(k)
        for k in list(kwargs)
        if k in signature(inferencer_type).parameters.keys()
    }
    inferencer = inferencer_type(model, **init_kwargs)
    return inferencer(*args, **kwargs)[0]