Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List | |
| from mmengine.fileio import load | |
| from mmpretrain.registry import DATASETS | |
| from .base_dataset import BaseDataset | |
class VGVQA(BaseDataset):
    """Visual Genome VQA dataset."""

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Compared to BaseDataset, the only difference is that the VQA
        annotation file is already a list of data. There is no 'metainfo'.

        Each raw element is passed through ``self.parse_data_info`` and is
        expected to come back as a dict shaped like::

            {
                "question_id": 986769,
                "question": "How many people are there?",
                "answer": "two",
                "image": "image/1.jpg",
                "dataset": "vg"
            }

        The ``'image'`` key is renamed to ``'img_path'``. When an
        ``'answer'`` key is present, ``'answer_weight'`` and
        ``'answer_count'`` are added according to the ``'dataset'`` field.

        Returns:
            List[dict]: The parsed data information, one dict per
            annotation element.

        Raises:
            TypeError: If the loaded annotation is not a list, or if a
                parsed element is not a dict.
        """
        raw_data_list = load(self.ann_file)
        if not isinstance(raw_data_list, list):
            # The check above requires a list, so the message says so too.
            raise TypeError(
                f'The VQA annotations loaded from annotation file '
                f'should be a list, but got {type(raw_data_list)}!')

        # load and parse data_infos.
        data_list = []
        for raw_data_info in raw_data_list:
            # parse raw data information to target format
            data_info = self.parse_data_info(raw_data_info)
            if not isinstance(data_info, dict):
                raise TypeError(
                    f'Each VQA data element loaded from annotation file '
                    f'should be a dict, but got {type(data_info)}!')

            # change 'image' key to 'img_path'
            # TODO: This process will be removed, after the annotation file
            # is preprocessed.
            data_info['img_path'] = data_info.pop('image')

            if 'answer' in data_info:
                # add answer_weight & answer_count, delete duplicate answer
                if data_info['dataset'] == 'vqa':
                    answers = data_info['answer']
                    num_answers = len(answers)
                    # Each occurrence contributes 1 / num_answers, so a
                    # repeated answer accumulates a larger weight. Insertion
                    # order of first occurrences is preserved by the dict.
                    answer_weight = {}
                    for answer in answers:
                        answer_weight[answer] = (
                            answer_weight.get(answer, 0) + 1 / num_answers)

                    data_info['answer'] = list(answer_weight.keys())
                    data_info['answer_weight'] = list(answer_weight.values())
                    data_info['answer_count'] = len(answer_weight)
                elif data_info['dataset'] == 'vg':
                    # NOTE(review): this branch stores the wrapped answer
                    # under 'answers' (plural) and leaves 'answer' as is,
                    # whereas the 'vqa' branch rewrites 'answer' -- confirm
                    # which key downstream transforms actually read.
                    data_info['answers'] = [data_info['answer']]
                    # NOTE(review): fixed weight for the single answer; the
                    # rationale for 0.2 is not visible in this file.
                    data_info['answer_weight'] = [0.2]
                    data_info['answer_count'] = 1

            data_list.append(data_info)

        return data_list