| from __future__ import annotations | |
| import inspect | |
| import logging | |
| import os | |
| from collections.abc import Callable | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any | |
| from sentence_transformers.backend import load_onnx_model, load_openvino_model | |
| try: | |
| from typing import Self | |
| except ImportError: | |
| from typing_extensions import Self | |
| import torch | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config | |
| from transformers.utils.import_utils import is_peft_available | |
| from transformers.utils.peft_utils import find_adapter_config_file | |
| from sentence_transformers.models.InputModule import InputModule | |
| logger = logging.getLogger(__name__) | |
| if TYPE_CHECKING and is_peft_available(): | |
| from peft import PeftConfig | |
| from sentence_transformers.models import Transformer | |
class C2LLMTransformer(Transformer):
    """``Transformer`` input module whose backbone returns a ready-made sentence embedding.

    ``forward`` is overridden so that the ``"sentence_embedding"`` entry of the
    ``self.auto_model`` output is copied straight into the features dict.
    """

    # Name of the JSON config file associated with this module.
    config_file_name: str = "sentence_bert_config.json"
    # Attribute names presumably persisted to ``config_file_name`` — confirm against the base class.
    config_keys: list[str] = ["max_seq_length", "do_lower_case"]
    # NOTE(review): presumably saves this module at the model root rather than a subfolder; verify in InputModule.
    save_in_root: bool = True

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Run the backbone and attach its sentence embedding to ``features``.

        Args:
            features: Model inputs keyed by name. Only the entries whose key is in
                ``self.model_forward_params`` are passed on to ``self.auto_model``.
            **kwargs: Extra keyword arguments forwarded unchanged to ``self.auto_model``.

        Returns:
            The same ``features`` dict, mutated in place, with a
            ``"sentence_embedding"`` entry taken from the model output.
        """
        model_inputs = {}
        for name, tensor in features.items():
            if name in self.model_forward_params:
                model_inputs[name] = tensor

        # Request a dict-style output so the embedding can be looked up by key below.
        model_output = self.auto_model(**model_inputs, **kwargs, return_dict=True)
        features["sentence_embedding"] = model_output["sentence_embedding"]
        return features