Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved | |
| import warnings | |
| from mmengine.hooks import Hook | |
| from mmengine.model import is_model_wrapper | |
| from mmpretrain.models import BaseRetriever | |
| from mmpretrain.registry import HOOKS | |
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the prototype
    changes accordingly. So the ``prototype_vecs`` needs to be regenerated
    before the validation loop.
    """

    # NOTE(review): ``HOOKS`` is imported by this file but no registration
    # decorator is visible on this class -- confirm whether the hook is
    # registered elsewhere or whether a decorator was lost.

    def before_val(self, runner) -> None:
        """Regenerate the retriever prototype before validation starts.

        The model is read from ``runner.model`` and unwrapped through
        ``.module`` when ``is_model_wrapper`` reports it as wrapped. If the
        resulting model is a ``BaseRetriever`` that exposes
        ``prepare_prototype``, that method is called; for any other model
        type a warning is emitted and nothing else happens.

        Args:
            runner: Object exposing the model to validate as
                ``runner.model``.
        """
        model = runner.model
        if is_model_wrapper(model):
            # Work on the underlying model, not the wrapper around it.
            model = model.module

        if isinstance(model, BaseRetriever):
            if hasattr(model, 'prepare_prototype'):
                model.prepare_prototype()
        else:
            # Report this hook's actual class name so the warning points at
            # the hook the user really configured.
            warnings.warn(
                'Only the `mmpretrain.models.retrievers.BaseRetriever` '
                f'can execute `{type(self).__name__}`, but got '
                f'`{type(model)}`')