import json
import os
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Tuple

import onnxruntime as ort
from transformers import AutoTokenizer


@dataclass
class AutomotiveSLMConfig:
    """Architecture and generation defaults for the automotive edge SLM."""

    model_name: str = "Automotive-SLM-Edge-3M"

    # Transformer backbone
    d_model: int = 256
    n_layer: int = 4
    n_head: int = 4
    vocab_size: int = 50257  # GPT-2 vocabulary
    n_positions: int = 256

    # Mixture-of-experts feed-forward blocks
    use_moe: bool = True
    n_experts: int = 4
    expert_capacity: int = 2
    moe_intermediate_size: int = 384
    router_aux_loss_coef: float = 0.01

    # Rotary position embeddings and regularization
    rotary_dim: int = 64
    rope_base: float = 10000.0
    dropout: float = 0.05
    layer_norm_epsilon: float = 1e-5

    # Generation defaults
    max_gen_length: int = 50
    temperature: float = 0.8
    top_p: float = 0.9
    top_k: int = 50
    repetition_penalty: float = 1.1
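
# Example config.json override consumed by ModelManager._load_config below
# (illustrative values, not shipped defaults; unknown keys are ignored and
# omitted fields keep the dataclass defaults above):
#
#   {"d_model": 512, "n_layer": 6, "use_moe": false, "temperature": 0.7}
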

class ModelManager:
    """Discovers, loads, and caches ONNX sessions with tokenizer and config."""

    def __init__(self, models_path: str):
        if not isinstance(models_path, str) or not models_path:
            raise ValueError(f"models_path must be a non-empty string, got: {models_path!r}")
        self.models_path = models_path
        # filename -> (session, tokenizer, config), so repeat loads are free
        self.cache: Dict[str, Tuple[Any, Any, AutomotiveSLMConfig]] = {}
        os.makedirs(self.models_path, exist_ok=True)

    def get_available_models(self) -> List[str]:
        """Return the sorted .onnx filenames directly under models_path."""
        if not os.path.isdir(self.models_path):
            return []
        files = []
        for fname in os.listdir(self.models_path):
            path = os.path.join(self.models_path, fname)
            if not os.path.isfile(path):
                continue
            if os.path.splitext(fname)[1].lower() == ".onnx":
                files.append(fname)
        return sorted(files)

    def _load_config(self) -> AutomotiveSLMConfig:
        # config.json lives in the parent of the models directory; fall back
        # to the dataclass defaults when it is absent.
        assets_root = os.path.dirname(self.models_path)
        cfg_path = os.path.join(assets_root, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
            # Filter out unknown keys so a stale or hand-edited config.json
            # cannot crash the dataclass constructor with a TypeError.
            known = {f.name for f in fields(AutomotiveSLMConfig)}
            return AutomotiveSLMConfig(**{k: v for k, v in cfg.items() if k in known})
        return AutomotiveSLMConfig()

    def load_model(self, model_filename: str) -> Tuple[Any, Any, AutomotiveSLMConfig]:
        if not isinstance(model_filename, str) or not model_filename:
            raise ValueError(f"model_filename must be a non-empty string, got: {model_filename!r}")

        if model_filename in self.cache:
            return self.cache[model_filename]

        model_path = os.path.join(self.models_path, model_filename)
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # The model uses the GPT-2 vocabulary (vocab_size=50257), so the stock
        # GPT-2 tokenizer applies; GPT-2 has no pad token, so reuse EOS.
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # CPU-only execution keeps the session portable across edge targets.
        providers = ["CPUExecutionProvider"]
        so = ort.SessionOptions()
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        session = ort.InferenceSession(model_path, sess_options=so, providers=providers)

        config = self._load_config()

        self.cache[model_filename] = (session, tokenizer, config)
        return session, tokenizer, config
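
# Minimal usage sketch, guarded so importing this module stays side-effect
# free. The "assets/models" path and the single "input_ids"-style input are
# assumptions about the exported graph, not guarantees; inspect
# session.get_inputs() for the real names and dtypes.
if __name__ == "__main__":
    manager = ModelManager(os.path.join("assets", "models"))
    available = manager.get_available_models()
    if not available:
        print(f"No .onnx models found under {manager.models_path}")
    else:
        session, tokenizer, config = manager.load_model(available[0])
        enc = tokenizer("Check the tire pressure", return_tensors="np")
        input_name = session.get_inputs()[0].name  # often "input_ids"
        outputs = session.run(None, {input_name: enc["input_ids"]})
        print(config.model_name, [o.shape for o in outputs])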