Nihal2000 committed (verified)
Commit a45df4e · Parent: f1e72e8

Update app.py

Files changed (1):
  app.py  +93 -78
app.py CHANGED
@@ -1,56 +1,89 @@
 import os
 import gradio as gr
-
-from src.model_manager import ModelManager
-from src.inference_engine import InferenceEngine
-
-ASSETS_DIR = "assets"
-MODELS_DIR = os.path.join(ASSETS_DIR, "models")
-
-os.makedirs(ASSETS_DIR, exist_ok=True)
-os.makedirs(MODELS_DIR, exist_ok=True)
-
-manager = ModelManager(MODELS_DIR)
-_ENGINE_CACHE = {}
-
-def list_models():
-    return manager.get_available_models()
-
-def load_engine(model_name: str) -> InferenceEngine:
-    if model_name in _ENGINE_CACHE:
-        return _ENGINE_CACHE[model_name]
-    session, tokenizer, config = manager.load_model(model_name)
-    engine = InferenceEngine(session, tokenizer, config)
-    _ENGINE_CACHE[model_name] = engine
-    return engine
-
-def chat_fn(message, history, model_name, max_tokens, temperature, top_p, top_k):
-    if not model_name:
-        history = history + [{"role": "assistant", "content": "No model selected. Please choose an ONNX model."}]
-        return history
-    try:
-        engine = load_engine(model_name)
-        reply = engine.generate_response(
-            message,
-            max_tokens=int(max_tokens),
-            temperature=float(temperature),
-            top_p=float(top_p),
-            top_k=int(top_k),
-        )
-    except Exception as e:
-        reply = f"Error during inference: {e}"
+import onnxruntime as ort
+import numpy as np
+from transformers import AutoTokenizer
+
+ONNX_PATH = os.path.join("assets", "automotive_slm.onnx")
+
+# Load tokenizer (must match training tokenizer)
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+# Create ONNX session
+providers = ["CPUExecutionProvider"]
+so = ort.SessionOptions()
+so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+session = ort.InferenceSession(ONNX_PATH, providers=providers, sess_options=so)
+
+# Infer IO names
+INPUT_NAME = session.get_inputs()[0].name
+OUTPUT_NAME = session.get_outputs()[0].name
+
+def generate_onnx(prompt: str, max_tokens=64, temperature=0.8, top_p=0.9, top_k=50) -> str:
+    tokens = tokenizer.encode(prompt)
+    input_ids = np.array([tokens], dtype=np.int64)
+    generated = []
+
+    for _ in range(int(max_tokens)):
+        outputs = session.run([OUTPUT_NAME], {INPUT_NAME: input_ids})
+        logits = outputs[0][0, -1, :]
+
+        # Temperature
+        if temperature and temperature > 0:
+            logits = logits / max(float(temperature), 1e-6)
+
+        # Top-k
+        if top_k and int(top_k) > 0:
+            k = min(int(top_k), logits.shape[-1])
+            idx = np.argpartition(logits, -k)[-k:]
+            filt = np.full_like(logits, -np.inf)
+            filt[idx] = logits[idx]
+            logits = filt
+
+        # Softmax
+        exps = np.exp(logits - np.max(logits))
+        probs = exps / np.sum(exps)
+
+        # Top-p
+        if top_p is not None and 0 < float(top_p) < 1.0:
+            sort_idx = np.argsort(probs)[::-1]
+            sorted_probs = probs[sort_idx]
+            cumsum = np.cumsum(sorted_probs)
+            cutoff = np.searchsorted(cumsum, float(top_p)) + 1
+            mask = np.zeros_like(probs)
+            keep = sort_idx[:cutoff]
+            mask[keep] = probs[keep]
+            s = mask.sum()
+            if s > 0:
+                probs = mask / s
+
+        next_token = int(np.random.choice(len(probs), p=probs))
+        if next_token == tokenizer.eos_token_id:
+            break
+
+        generated.append(next_token)
+        input_ids = np.concatenate([input_ids, [[next_token]]], axis=1)
+
+    text = tokenizer.decode(generated, skip_special_tokens=True).strip()
+    if not text:
+        return "I couldn't generate a response."
+    if text.startswith(prompt):
+        text = text[len(prompt):].strip()
+    return text
+
+def chat_fn(message, history, max_tokens, temperature, top_p, top_k):
+    reply = generate_onnx(message, max_tokens, temperature, top_p, top_k)
     history = history + [
         {"role": "user", "content": message},
         {"role": "assistant", "content": reply},
     ]
     return history
 
-def clear_chat():
-    return []
-
 with gr.Blocks(title="Automotive SLM Chatbot (ONNX)") as demo:
     gr.Markdown("# 🚗 Automotive SLM Chatbot (ONNX-only)")
-    gr.Markdown("Place your .onnx models in assets/models and select one to chat.")
+    gr.Markdown("Using model at assets/automotive_slm.onnx")
 
     with gr.Row():
         with gr.Column(scale=3):
@@ -61,41 +94,23 @@ with gr.Blocks(title="Automotive SLM Chatbot (ONNX)") as demo:
             clear_btn = gr.Button("Clear")
 
         with gr.Column(scale=2):
-            gr.Markdown("### Model settings")
-            available = list_models()
-            if not available:
-                gr.Markdown("No ONNX models found in assets/models. Please add .onnx files and refresh.")
-                model_dropdown = gr.Dropdown(choices=[], value=None, label="Model", interactive=False)
-                max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens", interactive=False)
-                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature", interactive=False)
-                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p", interactive=False)
-                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k", interactive=False)
-            else:
-                # Optional labels with size
-                def size_mb(path):
-                    try: return os.path.getsize(path) / (1024 * 1024)
-                    except Exception: return 0.0
-                labels = [f"{n} ({size_mb(os.path.join(MODELS_DIR, n)):.1f} MB)" for n in available]
-                choices = list(zip(labels, available))
-
-                model_dropdown = gr.Dropdown(choices=choices, value=available[0], label="Model")
-                max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens")
-                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
-                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
-                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
-
-    if available:
-        send_btn.click(
-            fn=chat_fn,
-            inputs=[msg, chatbot, model_dropdown, max_tokens, temperature, top_p, top_k],
-            outputs=[chatbot]
-        )
-        msg.submit(
-            fn=chat_fn,
-            inputs=[msg, chatbot, model_dropdown, max_tokens, temperature, top_p, top_k],
-            outputs=[chatbot]
-        )
-        clear_btn.click(clear_chat, None, chatbot)
+            gr.Markdown("### Generation settings")
+            max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens")
+            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
+            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
+
+    send_btn.click(
+        fn=chat_fn,
+        inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k],
+        outputs=[chatbot]
+    )
+    msg.submit(
+        fn=chat_fn,
+        inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k],
+        outputs=[chatbot]
+    )
+    clear_btn.click(lambda: [], None, chatbot)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
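The sampling path added in generate_onnx applies, in order: temperature scaling, top-k filtering, a numerically stable softmax, and top-p (nucleus) renormalisation. Below is a minimal standalone sketch of that same filtering chain, handy for sanity-checking it on a toy distribution; the 6-token vocabulary, the seed, and the sample_next name are illustrative, while the filtering logic mirrors the new app.py.

import numpy as np

def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9, seed=0):
    rng = np.random.default_rng(seed)
    logits = logits / max(float(temperature), 1e-6)    # temperature scaling
    k = min(int(top_k), logits.shape[-1])              # top-k: keep the k largest logits
    idx = np.argpartition(logits, -k)[-k:]
    filt = np.full_like(logits, -np.inf)
    filt[idx] = logits[idx]
    exps = np.exp(filt - np.max(filt))                 # stable softmax; masked logits get prob 0
    probs = exps / np.sum(exps)
    sort_idx = np.argsort(probs)[::-1]                 # top-p: smallest prefix whose mass reaches top_p
    cutoff = np.searchsorted(np.cumsum(probs[sort_idx]), float(top_p)) + 1
    mask = np.zeros_like(probs)
    mask[sort_idx[:cutoff]] = probs[sort_idx[:cutoff]]
    probs = mask / mask.sum()                          # renormalise over the surviving nucleus
    return int(rng.choice(len(probs), p=probs))

logits = np.array([3.2, 2.9, 1.0, 0.3, -0.5, -2.0])
print(sample_next(logits, top_k=3, top_p=0.9))         # only token 0 or 1 can be drawn here

With top_k=3, tokens 3-5 are masked before the softmax; the kept probabilities come out to roughly 0.57, 0.39 and 0.04, so top_p=0.9 then trims token 2 from the nucleus as well, which is why low-probability tokens can never be sampled regardless of the random draw.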
 
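Note that the new module-level session setup assumes the exported graph exposes exactly one input (int64 token ids) and that its first output is the full logits tensor of shape (batch, seq_len, vocab_size). A quick way to confirm that before launching the app; only the assets/automotive_slm.onnx path comes from this commit, the rest is an illustrative check:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("assets/automotive_slm.onnx",
                               providers=["CPUExecutionProvider"])
for t in session.get_inputs():
    print("input :", t.name, t.type, t.shape)
for t in session.get_outputs():
    print("output:", t.name, t.type, t.shape)

# One dummy step: the logits should come back as (1, 3, vocab_size).
dummy = np.array([[1, 2, 3]], dtype=np.int64)
out = session.run(None, {session.get_inputs()[0].name: dummy})
print(out[0].shape)

If the model was instead exported with an attention_mask input as well, session.run would raise here, and generate_onnx would need to pass that second tensor in its feed dict too.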