from __future__ import annotations

import inspect
import logging
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any

from sentence_transformers.backend import load_onnx_model, load_openvino_model

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config
from transformers.utils.import_utils import is_peft_available
from transformers.utils.peft_utils import find_adapter_config_file

from sentence_transformers.models.InputModule import InputModule

logger = logging.getLogger(__name__)

if TYPE_CHECKING and is_peft_available():
    from peft import PeftConfig

from sentence_transformers.models import Transformer


class C2LLMTransformer(Transformer):
    """Transformer module that exposes the wrapped model's sentence embedding.

    The forward pass reads the ``"sentence_embedding"`` entry from the output
    of ``self.auto_model`` and stores it in the features dict under the same
    key.
    """

    # NOTE(review): these three attributes are presumably consumed by the
    # parent ``Transformer`` save/load logic -- confirm against the base class.
    config_file_name: str = "sentence_bert_config.json"
    config_keys: list[str] = ["max_seq_length", "do_lower_case"]
    save_in_root: bool = True

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Run the wrapped model and attach its sentence embedding to ``features``.

        Args:
            features: Input dict. Only the entries whose key appears in
                ``self.model_forward_params`` are passed on to the model.
            **kwargs: Extra keyword arguments passed to ``self.auto_model``
                unchanged.

        Returns:
            The same ``features`` dict (mutated in place), now containing a
            ``"sentence_embedding"`` entry taken from the model output.
        """
        # Keep only the inputs listed in ``self.model_forward_params``.
        model_inputs = {}
        for name, tensor in features.items():
            if name in self.model_forward_params:
                model_inputs[name] = tensor

        model_output = self.auto_model(**model_inputs, **kwargs, return_dict=True)

        # The embedding is read directly from the model output by key.
        features["sentence_embedding"] = model_output["sentence_embedding"]
        return features