Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import copy | |
| from typing import List, Union | |
| from mmcv.transforms import BaseTransform | |
| PIPELINE_TYPE = List[Union[dict, BaseTransform]] | |
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
    """Locate the first transform in a pipeline whose class name matches.

    Each pipeline entry may be a config dict (whose ``'type'`` value is
    either a class or a name) or an already constructed transform object.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): Class name of the transform to look for.

    Returns:
        int: Index of the first matching transform, or -1 if none matches.
    """
    for position, entry in enumerate(pipeline):
        if not isinstance(entry, dict):
            # Constructed transform object: identify it by its class name.
            candidate = entry.__class__.__name__
        else:
            declared = entry['type']
            # The config may hold the class itself or just its name.
            if isinstance(declared, type):
                candidate = declared.__name__
            else:
                candidate = declared
        if candidate == target:
            return position
    return -1
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
    """Drop every transform of the given class name from a pipeline.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): Class name of the transform to remove.
        inplace (bool): If True, mutate ``pipeline`` itself; otherwise work
            on a deep copy and leave the argument untouched.

    Returns:
        The pipeline (the original list when ``inplace`` is True, else the
        deep copy) with all matching transforms removed.
    """
    # Look up the first match on the input before any copy is made.
    match_idx = get_transform_idx(pipeline, target)
    working = pipeline if inplace else copy.deepcopy(pipeline)

    # Indices shift after each removal, so search again every iteration.
    while match_idx >= 0:
        del working[match_idx]
        match_idx = get_transform_idx(working, target)
    return working