# autoSLM / src/model_manager.py
import os
import json
from dataclasses import dataclass, fields
from typing import Tuple, Any, List
import onnxruntime as ort
from transformers import AutoTokenizer


@dataclass
class AutomotiveSLMConfig:
    model_name: str = "Automotive-SLM-Edge-3M"
    d_model: int = 256
    n_layer: int = 4
    n_head: int = 4
    vocab_size: int = 50257
    n_positions: int = 256
    use_moe: bool = True
    n_experts: int = 4
    expert_capacity: int = 2
    moe_intermediate_size: int = 384
    router_aux_loss_coef: float = 0.01
    rotary_dim: int = 64
    rope_base: float = 10000.0
    dropout: float = 0.05
    layer_norm_epsilon: float = 1e-5

    # UI defaults
    max_gen_length: int = 50
    temperature: float = 0.8
    top_p: float = 0.9
    top_k: int = 50
    repetition_penalty: float = 1.1


class ModelManager:
    """Discovers, loads, and caches ONNX sessions plus their tokenizer and config."""

    def __init__(self, models_path: str):
        if not isinstance(models_path, str) or not models_path:
            raise ValueError(f"models_path must be a non-empty string, got: {models_path!r}")
        self.models_path = models_path
        self.cache = {}
        os.makedirs(self.models_path, exist_ok=True)

    def get_available_models(self) -> List[str]:
        """Return a sorted list of .onnx filenames found directly under models_path."""
        if not os.path.isdir(self.models_path):
            return []
        files = []
        for fname in os.listdir(self.models_path):
            path = os.path.join(self.models_path, fname)
            if not os.path.isfile(path):
                continue
            if os.path.splitext(fname)[1].lower() == ".onnx":
                files.append(fname)
        return sorted(files)

    def _load_config(self) -> AutomotiveSLMConfig:
        # Prefer assets/config.json if present (for UI defaults and documentation).
        assets_root = os.path.dirname(self.models_path)  # e.g. assets/
        cfg_path = os.path.join(assets_root, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, "r") as f:
                cfg = json.load(f)
            # Drop any keys the dataclass does not define, so a config.json that
            # also carries training-only fields cannot raise a TypeError here.
            known = {fld.name for fld in fields(AutomotiveSLMConfig)}
            return AutomotiveSLMConfig(**{k: v for k, v in cfg.items() if k in known})
        return AutomotiveSLMConfig()

    def load_model(self, model_filename: str) -> Tuple[Any, Any, AutomotiveSLMConfig]:
        """Load (and cache) an ONNX session, the GPT-2 tokenizer, and the config."""
        if not isinstance(model_filename, str) or not model_filename:
            raise ValueError(f"model_filename must be a non-empty string, got: {model_filename!r}")
        if model_filename in self.cache:
            return self.cache[model_filename]
        model_path = os.path.join(self.models_path, model_filename)
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Tokenizer: GPT-2, matching the training and inference setup.
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # ONNX Runtime session, CPU-only with full graph optimizations.
        providers = ["CPUExecutionProvider"]
        so = ort.SessionOptions()
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        session = ort.InferenceSession(model_path, sess_options=so, providers=providers)

        config = self._load_config()
        self.cache[model_filename] = (session, tokenizer, config)
        return session, tokenizer, config
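

if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only: it assumes an assets/models/
    # directory containing at least one exported .onnx file. That path is an
    # assumption for this demo, not part of the module's contract.
    manager = ModelManager(os.path.join("assets", "models"))
    available = manager.get_available_models()
    print("Available models:", available)
    if available:
        session, tokenizer, config = manager.load_model(available[0])
        print("Loaded:", config.model_name)
        # The graph's input/output names depend on how the model was exported,
        # so inspect them here instead of hard-coding them.
        print("ONNX inputs:", [i.name for i in session.get_inputs()])
        print("ONNX outputs:", [o.name for o in session.get_outputs()])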