# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from os import PathLike
from typing import Optional, Sequence

import mmengine
from mmcv.transforms import Compose
from mmengine.fileio import get_file_backend

from .builder import DATASETS


def expanduser(path):
    if isinstance(path, (str, PathLike)):
        return osp.expanduser(path)
    else:
        return path


def isabs(uri):
    return osp.isabs(uri) or ('://' in uri)
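

# Illustration (an assumed sketch, not part of the original module): unlike
# ``osp.isabs``, the helper above also treats remote URIs as absolute, so
# they bypass ``data_root`` joining in ``MultiTaskDataset._join_root``:
#
#   isabs('/data/a.jpg')          # True  (OS-absolute path)
#   isabs('s3://openmmlab/xxx/')  # True  (remote URI)
#   isabs('train/a.jpg')          # False (relative, joined with data_root)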


@DATASETS.register_module()
class MultiTaskDataset:
    """Custom dataset for multi-task learning.

    To use the dataset, please generate and provide an annotation file in the
    below format:

    .. code-block:: json

        {
          "metainfo": {
            "tasks": [
              "gender",
              "wear"
            ]
          },
          "data_list": [
            {
              "img_path": "a.jpg",
              "gt_label": {
                "gender": 0,
                "wear": [1, 0, 1, 0]
              }
            },
            {
              "img_path": "b.jpg",
              "gt_label": {
                "gender": 1,
                "wear": [1, 0, 1, 0]
              }
            }
          ]
        }
    Assume we put our dataset in the ``data/mydataset`` folder in the
    repository and organize it as the below format: ::

        mmpretrain/
        └── data
            └── mydataset
                ├── annotation
                │   ├── train.json
                │   ├── test.json
                │   └── val.json
                ├── train
                │   ├── a.jpg
                │   └── ...
                ├── test
                │   ├── b.jpg
                │   └── ...
                └── val
                    ├── c.jpg
                    └── ...
    We can use the below config to build datasets:

    .. code-block:: python

        >>> from mmpretrain.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="annotation/train.json",
        ...     data_root="data/mydataset",
        ...     # The `img_path` field in the train annotation file is
        ...     # relative to the `train` folder.
        ...     data_prefix='train',
        ... )
        >>> train_dataset = build_dataset(train_cfg)
    Or we can put all files in the same folder: ::

        mmpretrain/
        └── data
            └── mydataset
                ├── train.json
                ├── test.json
                ├── val.json
                ├── a.jpg
                ├── b.jpg
                ├── c.jpg
                └── ...
    And we can use the below config to build datasets:

    .. code-block:: python

        >>> from mmpretrain.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="train.json",
        ...     data_root="data/mydataset",
        ...     # The `data_prefix` is not required since all paths are
        ...     # relative to the `data_root`.
        ... )
        >>> train_dataset = build_dataset(train_cfg)
    Args:
        ann_file (str): The annotation file path. It can be either an
            absolute path or a path relative to the ``data_root``.
        metainfo (dict, optional): Extra meta information. It should be
            a dict with the same format as the ``"metainfo"`` field in the
            annotation file. Defaults to None.
        data_root (str, optional): The root path of the data directory. It
            is the prefix of the ``data_prefix`` and the ``ann_file``, and
            it can be a remote path like "s3://openmmlab/xxx/".
            Defaults to None.
        data_prefix (str, optional): The base folder relative to the
            ``data_root`` for the ``"img_path"`` field in the annotation
            file. Defaults to None.
        pipeline (Sequence[dict]): A list of dicts, where each element
            represents an operation defined in
            :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
        test_mode (bool): Whether the dataset is in test mode rather than
            train mode. Defaults to False.
    """

    METAINFO = dict()

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: Optional[str] = None,
                 pipeline: Sequence = (),
                 test_mode: bool = False):

        self.data_root = expanduser(data_root)

        # Infer the file backend from the data root (supports local paths
        # and remote URIs such as s3://).
        if self.data_root is not None:
            self.file_backend = get_file_backend(uri=self.data_root)
        else:
            self.file_backend = None

        self.ann_file = self._join_root(expanduser(ann_file))
        self.data_prefix = self._join_root(data_prefix)

        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        self.data_list = self.load_data_list(self.ann_file, metainfo)
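
    # Illustration (assumed values, not part of the original module): with
    # ``data_root='data/mydataset'`` and ``ann_file='annotation/train.json'``,
    # ``self.ann_file`` resolves to ``data/mydataset/annotation/train.json``;
    # an absolute path or a remote URI (e.g. ``s3://openmmlab/xxx/train.json``)
    # is kept unchanged by ``_join_root``.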

    def _join_root(self, path):
        """Join ``self.data_root`` with the specified path.

        If the path is an absolute path or a remote URI, return it directly.
        If the path is None, return ``self.data_root``.

        Examples:
            >>> self.data_root = 'a/b/c'
            >>> self._join_root('d/e/')
            'a/b/c/d/e'
            >>> self._join_root('https://openmmlab.com')
            'https://openmmlab.com'
            >>> self._join_root(None)
            'a/b/c'
        """
        if path is None:
            return self.data_root
        if isabs(path):
            return path

        joined_path = self.file_backend.join_path(self.data_root, path)
        return joined_path

    @classmethod
    def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            in_metainfo (dict): Meta information dict.

        Returns:
            dict: Parsed meta information.
        """
        # `cls.METAINFO` will be overwritten by `in_metainfo`.
        metainfo = copy.deepcopy(cls.METAINFO)
        if in_metainfo is None:
            return metainfo

        metainfo.update(in_metainfo)
        return metainfo
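
    # Illustration (assumed values, not part of the original module): if a
    # subclass sets ``METAINFO = dict(tasks=[])`` and the annotation file
    # provides ``{'tasks': ['gender', 'wear']}``, the merged result is
    # ``{'tasks': ['gender', 'wear']}``: values from the annotation file (or
    # the ``metainfo`` argument) take precedence over ``cls.METAINFO``.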

    def load_data_list(self, ann_file, metainfo_override=None):
        """Load annotations from an annotation file.

        Args:
            ann_file (str): Absolute annotation file path if
                ``self.data_root=None`` or relative path if
                ``self.data_root=/path/to/data/``.

        Returns:
            list[dict]: A list of annotations.
        """
        annotations = mmengine.load(ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations:
            raise ValueError('The annotation file must have the `data_list` '
                             'field.')
        metainfo = annotations.get('metainfo', {})
        raw_data_list = annotations['data_list']

        # Set meta information.
        assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
            f'annotation file should be a dict, but got {type(metainfo)}'
        if metainfo_override is not None:
            assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
                f'argument should be a dict, but got {type(metainfo_override)}'
            metainfo.update(metainfo_override)
        self._metainfo = self._get_meta_info(metainfo)

        data_list = []
        for i, raw_data in enumerate(raw_data_list):
            try:
                data_list.append(self.parse_data_info(raw_data))
            except AssertionError as e:
                raise RuntimeError(
                    f'The format check fails while parsing item {i} of '
                    f'the annotation file with error: {e}')
        return data_list

    def parse_data_info(self, raw_data):
        """Parse raw annotation to target format.

        This method will return a dict which contains the data information
        of a sample.

        Args:
            raw_data (dict): Raw data information loaded from ``ann_file``.

        Returns:
            dict: Parsed annotation.
        """
        assert isinstance(raw_data, dict), \
            f'The item should be a dict, but got {type(raw_data)}'
        assert 'img_path' in raw_data, \
            "The item doesn't have `img_path` field."

        data = dict(
            img_path=self._join_root(raw_data['img_path']),
            gt_label=raw_data['gt_label'],
        )
        return data
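
    # Illustration (assumed values, not part of the original module): with
    # ``data_root='data/mydataset'`` and ``data_prefix='train'``, a raw item
    #   {'img_path': 'a.jpg', 'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]}}
    # is parsed into
    #   {'img_path': 'data/mydataset/train/a.jpg',
    #    'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]}}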

    @property
    def metainfo(self) -> dict:
        """Get meta information of dataset.

        Returns:
            dict: Meta information collected from ``cls.METAINFO``, the
            annotation file and the metainfo argument during instantiation.
        """
        return copy.deepcopy(self._metainfo)

    def prepare_data(self, idx):
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        results = copy.deepcopy(self.data_list[idx])
        return self.pipeline(results)

    def __len__(self):
        """Get the length of the whole dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data_list)

    def __getitem__(self, idx):
        """Get the idx-th image and data information of the dataset after
        ``self.pipeline``.

        Args:
            idx (int): The index of the data.

        Returns:
            dict: The idx-th image and data information after
            ``self.pipeline``.
        """
        return self.prepare_data(idx)

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = [f'Number of samples: \t{self.__len__()}']
        if self.data_root is not None:
            body.append(f'Root location: \t{self.data_root}')
        body.append(f'Annotation file: \t{self.ann_file}')
        if self.data_prefix is not None:
            body.append(f'Prefix of images: \t{self.data_prefix}')

        # -------------------- extra repr --------------------
        tasks = self.metainfo['tasks']
        body.append(f'For {len(tasks)} tasks')
        for task in tasks:
            body.append(f'    {task}')
        # ----------------------------------------------------

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f'    {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)
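

# A minimal usage sketch (assumptions, not part of the original module: the
# annotation file and images from the class docstring exist on disk, and
# ``build_dataset`` is importable from ``mmpretrain.datasets``):
#
#   from mmpretrain.datasets import build_dataset
#
#   dataset = build_dataset(
#       dict(
#           type='MultiTaskDataset',
#           ann_file='annotation/train.json',
#           data_root='data/mydataset',
#           data_prefix='train',
#       ))
#   print(len(dataset))  # number of parsed samples in ``data_list``
#   sample = dataset[0]  # the first sample, after ``self.pipeline``
#   print(sample['img_path'], sample['gt_label'])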