Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from collections import OrderedDict | |
| from typing import List | |
| import mmengine | |
| from mmengine import get_file_backend | |
| from mmpretrain.registry import DATASETS | |
| from .base_dataset import BaseDataset | |
class Flickr30kRetrieval(BaseDataset):
    """Flickr30k image-text retrieval dataset.

    Reads one annotation file that lists images together with their
    captions, keeps only the images belonging to the requested split and
    exposes them either caption-by-caption or image-by-image (see
    :meth:`load_data_list`).

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images. It is handed to
            :class:`BaseDataset` as ``dict(img_path=data_prefix)``.
        ann_file (str): Annotation file path.
        split (str): 'train', 'val' or 'test'. Only images whose ``split``
            field equals this value are loaded.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwarg):
        assert split in ('train', 'val', 'test'), \
            '`split` must be train, val or test'
        # ``load_data_list`` reads ``self.split``, so it is stored before the
        # base constructor runs.
        # NOTE(review): presumably the base constructor can trigger loading
        # -- confirm against BaseDataset.
        self.split = split

        super().__init__(
            data_root=data_root,
            data_prefix={'img_path': data_prefix},
            ann_file=ann_file,
            **kwarg,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Also records the number of kept images in ``self.img_size`` and the
        number of kept captions in ``self.text_size``.

        Returns:
            List[dict]: If ``self.test_mode`` is true, one dict per image
            with the keys ``ori_id``, ``image_id``, ``img_path``, ``text``,
            ``gt_text_id`` and ``gt_image_id``. Otherwise one dict per
            caption with the keys ``text``, ``ori_id``, ``text_id``,
            ``image_ori_id``, ``image_id``, ``img_path`` and
            ``is_matched``.
        """
        image_root = self.data_prefix['img_path']
        backend = get_file_backend(image_root)
        records = mmengine.load(self.ann_file)

        # Each entry of ``records['images']`` is expected to look like:
        # {
        #     "sentids": [0, 1, 2],
        #     "imgid": 0,
        #     "sentences": [
        #         {"raw": "Two men in green shirts standing in a yard.",
        #          "imgid": 0, "sentid": 0},
        #         {"raw": "A man in a blue shirt standing in a garden.",
        #          "imgid": 0, "sentid": 1},
        #         {"raw": "Two friends enjoy time spent together.",
        #          "imgid": 0, "sentid": 2}
        #     ],
        #     "split": "train",
        #     "filename": "1000092795.jpg"
        # }
        selected = (record for record in records['images']
                    if record['split'] == self.split)

        # original image id -> per-image entry, in file order
        images_by_ori_id = OrderedDict()
        # flat list with one entry per caption
        caption_items = []

        # ``new_image_id`` is a fresh consecutive index over the kept images.
        for new_image_id, record in enumerate(selected):
            ori_image_id = record['imgid']
            image_path = backend.join_path(image_root, record['filename'])

            image_item = {
                'ori_id': ori_image_id,
                'image_id': new_image_id,  # used for evaluation
                'img_path': image_path,
                'text': [],
                'gt_text_id': [],
                'gt_image_id': [],
            }

            for sentence in record['sentences']:
                caption = {
                    'text': sentence['raw'],
                    'ori_id': sentence['sentid'],
                    # Consecutive caption index: equals the number of
                    # captions collected so far. Used for evaluation.
                    'text_id': len(caption_items),
                    'image_ori_id': ori_image_id,
                    'image_id': new_image_id,
                    'img_path': image_path,
                    'is_matched': True,
                }

                # 1. caption-level item (returned outside test mode)
                caption_items.append(caption)

                # 2. image-level item (returned in test mode)
                image_item['text'].append(caption['text'])
                image_item['gt_text_id'].append(caption['text_id'])
                image_item['gt_image_id'].append(caption['image_id'])

            images_by_ori_id[ori_image_id] = image_item

        self.img_size = len(images_by_ori_id)
        self.text_size = len(caption_items)

        # return needed format data list
        return (list(images_by_ori_id.values())
                if self.test_mode else caption_items)