Update src/model_manager.py
Reorders the imports into stdlib and third-party groups, makes rope_base an explicit float, adds validation and documented fallbacks to _load_config, reworks the model-file listing loop, loads the torch checkpoint before building the model, and enables graph optimizations on the ONNX Runtime session.

src/model_manager.py  CHANGED  (+31 -22)
--- a/src/model_manager.py
+++ b/src/model_manager.py
@@ -1,11 +1,13 @@
 import os
-import torch
 import json
-import onnxruntime as ort
-from transformers import AutoTokenizer
 from dataclasses import dataclass
 from typing import Tuple, Any, List
 
+import torch
+import onnxruntime as ort
+from transformers import AutoTokenizer
+
+
 @dataclass
 class AutomotiveSLMConfig:
     model_name: str = "Automotive-SLM-Edge-3M"
@@ -20,15 +22,17 @@ class AutomotiveSLMConfig:
     moe_intermediate_size: int = 384
     router_aux_loss_coef: float = 0.01
     rotary_dim: int = 64
-    rope_base: float = 10000
+    rope_base: float = 10000.0
     dropout: float = 0.05
     layer_norm_epsilon: float = 1e-5
+    # UI defaults
     max_gen_length: int = 50
     temperature: float = 0.8
     top_p: float = 0.9
     top_k: int = 50
     repetition_penalty: float = 1.1
 
+
 class ModelManager:
     def __init__(self, models_path: str):
         if not isinstance(models_path, str) or not models_path:
@@ -41,32 +45,35 @@ class ModelManager:
         if not os.path.isdir(self.models_path):
             return []
         files = []
-        for
-            path = os.path.join(self.models_path,
+        for fname in os.listdir(self.models_path):
+            path = os.path.join(self.models_path, fname)
             if not os.path.isfile(path):
                 continue
-            ext = os.path.splitext(
+            ext = os.path.splitext(fname)[1].lower()
             if ext in [".pt", ".pth", ".onnx"]:
-                files.append(
+                files.append(fname)
         return sorted(files)
 
     def _load_config(self, checkpoint_path: str) -> AutomotiveSLMConfig:
-
-
-
-
+        if not isinstance(checkpoint_path, str) or not checkpoint_path:
+            raise ValueError(f"checkpoint_path must be a non-empty string, got: {checkpoint_path!r}")
+
+        # Prefer assets/config.json
+        assets_root = os.path.dirname(self.models_path)  # assets/
         cfg_path = os.path.join(assets_root, "config.json")
-        if
+        if os.path.exists(cfg_path):
             with open(cfg_path, "r") as f:
                 cfg = json.load(f)
             return AutomotiveSLMConfig(**cfg)
-
+
+        # Fallback: read config from torch checkpoint if present
         ext = os.path.splitext(checkpoint_path)[1].lower()
         if ext in [".pt", ".pth"] and os.path.exists(checkpoint_path):
             ckpt = torch.load(checkpoint_path, map_location="cpu")
             if isinstance(ckpt, dict) and "config" in ckpt:
                 return AutomotiveSLMConfig(**ckpt["config"])
-
+
+        # Last resort defaults
         return AutomotiveSLMConfig()
 
     def load_model(self, model_filename: str) -> Tuple[Any, Any, AutomotiveSLMConfig]:
@@ -80,7 +87,7 @@ class ModelManager:
         if not os.path.isfile(model_path):
             raise FileNotFoundError(f"Model file not found: {model_path}")
 
-        # tokenizer
+        # Load tokenizer (GPT-2 per training)
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
@@ -89,20 +96,22 @@ class ModelManager:
         config = self._load_config(model_path)
 
         if ext in [".pt", ".pth"]:
+            # Import only when needed to avoid circular deps
            from src.model_architecture import AutomotiveSLM
-
-            model = AutomotiveSLM(config)
+            checkpoint = torch.load(model_path, map_location="cpu")
             state = checkpoint.get("model_state_dict", checkpoint)
+            model = AutomotiveSLM(config)
             model.load_state_dict(state, strict=True)
             model.eval()
+
         elif ext == ".onnx":
             providers = ["CPUExecutionProvider"]
-
-
-            model = ort.InferenceSession(model_path,
+            sess_options = ort.SessionOptions()
+            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+            model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
+
         else:
             raise ValueError(f"Unsupported model format: {ext}")
 
         self.cache[model_filename] = (model, tokenizer, config)
         return model, tokenizer, config
-
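For reference, a minimal usage sketch of the API this diff settles on. It is not part of the commit: the models directory path, the checkpoint filename, and the list_models method name are assumptions for illustration (the hunks show the listing loop's body but not the method's signature).

# Usage sketch (hypothetical; not part of the commit).
# Assumes the layout implied by _load_config: an assets/ root holding
# config.json next to the models directory passed to ModelManager.
# "assets/models", "automotive_slm.pt", and list_models() are illustrative names.
from src.model_manager import ModelManager

manager = ModelManager("assets/models")

# The listing loop keeps only regular files with .pt/.pth/.onnx extensions
# and returns them sorted.
for fname in manager.list_models():
    print(fname)

# load_model() resolves the config in order (assets/config.json, then a
# "config" dict embedded in a torch checkpoint, then dataclass defaults),
# loads the GPT-2 tokenizer, and caches (model, tokenizer, config) per filename.
model, tokenizer, config = manager.load_model("automotive_slm.pt")
print(config.model_name, config.rope_base)  # Automotive-SLM-Edge-3M 10000.0

Note that the two load paths return different object types: an AutomotiveSLM module in eval mode for .pt/.pth, but an onnxruntime.InferenceSession for .onnx, so callers must branch on the model type before running inference.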