Bc-AI committed · verified
Commit 4ebabd8
1 Parent(s): 4318cb5

Create app.py

Files changed (1)
  1. app.py +510 -0
app.py ADDED
@@ -0,0 +1,510 @@
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
import asyncio
import gradio as gr

# ==============================================================================
# Model Architecture
# ==============================================================================

@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)

            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]

        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)

        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config

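# Informal note on the rotation above: with rotate_half([x1, x2]) = [-x2, x1],
# q * cos + rotate_half(q) * sin rotates each channel pair (q_i, q_{i+dim/2})
# by the position-dependent angle from the cached tables. This is the
# split-halves RoPE convention (as in GPT-NeoX-style implementations) rather
# than the interleaved-pairs one.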
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        # RMSNorm: x / sqrt(mean(x^2) + eps) * scale (no mean-centering, unlike LayerNorm).
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()

        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")

        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)

        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")

        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        # Pre-norm attention sub-block with residual connection.
        res = x
        y = self.pre_attn_norm(x)

        # Project and reshape to [B, n_heads, T, head_dim].
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        q, k = self.rope(q, k)

        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))

        # Causal mask: large negative values block attention to future positions.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)

        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        # Pre-norm SwiGLU feed-forward sub-block with residual connection.
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config

@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        # Accept the config either nested under 'config' (Keras deserialization)
        # or as flat kwargs (direct construction).
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")

        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }

        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)

        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config

# ==============================================================================
# Helper: Parameter Counting
# ==============================================================================

def count_parameters(model):
    total_params = 0
    non_zero_params = 0
    for weight in model.weights:
        w = weight.numpy()
        total_params += w.size
        non_zero_params += np.count_nonzero(w)
    return total_params, non_zero_params

def format_param_count(count):
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    else:
        return str(count)

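# For instance, format_param_count(2_500_000) returns "2.50M". Tracking
# non-zero weights separately lets the backends below report sparsity for
# pruned checkpoints.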

# ==============================================================================
# Backend Interface
# ==============================================================================

class ModelBackend(ABC):
    @abstractmethod
    def predict(self, input_ids): pass
    @abstractmethod
    def get_name(self): pass
    @abstractmethod
    def get_info(self): pass

class KerasBackend(ModelBackend):
    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))

    def predict(self, input_ids):
        # Run a full forward pass and return the logits for the last position only.
        inputs = np.array([input_ids], dtype=np.int32)
        logits = self.model(inputs, training=False)
        return logits[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        info = f"{self.display_name}\n"
        info += f"  Total params: {format_param_count(self.total_params)}\n"
        info += f"  Attention heads: {self.n_heads}\n"
        info += f"  FFN dimension: {self.ff_dim}\n"
        if self.sparsity > 1:
            info += f"  Sparsity: {self.sparsity:.1f}%\n"
        return info


# ==============================================================================
# Load Models & Tokenizer
# ==============================================================================

CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"

print("="*60)
print("🚀 SAM-X-1 Hybrid API + UI Loading...".center(60))
print("="*60)

# Download config/tokenizer
print(f"📦 Fetching config & tokenizer from {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")

with open(config_path, 'r') as f:
    base_config = json.load(f)

# Map the HF-style config keys onto the SAM1Model constructor arguments.
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers']
}

print("🔤 Building tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")
eos_token = "<|endoftext|>"  # GPT-2's end-of-sequence token
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)

custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])

tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print("✅ Tokenizer ready")

# Model Registry: (display name, repo id, weights file, optional custom config)
MODEL_REGISTRY = [
    ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
    ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
    ("SAM-X-1-Mini 🚀 (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights.h5", "sam1_mini_config.json"),
    ("SAM-X-1-Nano ⚡⚡ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano.weights.h5", "sam1_nano_config.json"),
]

available_models = {}
dummy_input = tf.zeros((1, 1), dtype=tf.int32)

for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\n📥 Loading {display_name}...")
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)

        model_config = base_model_config.copy()
        if config_filename:
            print(f"   Custom config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config.update(json.load(f))

        # Build the model variables via a dummy forward pass, then load weights.
        model = SAM1Model(**model_config)
        model(dummy_input)
        model.load_weights(weights_path)
        model.trainable = False

        backend = KerasBackend(model, display_name, display_name)
        available_models[display_name] = backend

        print(f"✅ Loaded: {display_name}")
        print(f"   → Params: {format_param_count(backend.total_params)} | Heads: {backend.n_heads}")

    except Exception as e:
        print(f"❌ Failed to load {display_name}: {e}")

if not available_models:
    raise RuntimeError("No models loaded!")

current_backend = list(available_models.values())[0]
print(f"\n🎉 Ready! Default model: {current_backend.get_name()}")


# ==============================================================================
# Streaming Generator
# ==============================================================================

async def generate_stream(prompt: str, backend, temperature: float) -> AsyncGenerator[str, None]:
    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()
    max_len = backend.model.cfg['max_len']
    buffer = ""

    for _ in range(512):  # hard cap on generated tokens
        await asyncio.sleep(0)  # yield control so the event loop stays responsive
        current_input = generated[-max_len:]
        next_token_logits = backend.predict(current_input)

        if temperature > 0:
            # Temperature-scaled top-k (k=50) sampling.
            next_token_logits /= temperature
            top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
            top_k_logits = next_token_logits[top_k_indices]
            top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
            top_k_probs /= top_k_probs.sum()
            next_token = np.random.choice(top_k_indices, p=top_k_probs)
        else:
            # Greedy decoding.
            next_token = int(np.argmax(next_token_logits))

        if next_token == eos_token_id:
            break

        generated.append(int(next_token))
        # Re-decode the whole completion and emit only the new suffix, so
        # characters spanning multiple tokens render correctly.
        new_text = tokenizer.decode(generated[len(input_ids):])
        if len(new_text) > len(buffer):
            new_chunk = new_text[len(buffer):]
            buffer = new_text
            yield new_chunk

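# A minimal sketch of consuming the generator directly (illustrative only;
# the prompt format mirrors chat_fn below, and _demo is a hypothetical name):
#
#   async def _demo():
#       async for chunk in generate_stream("User: hi\nSam: <think>", current_backend, 0.7):
#           print(chunk, end="", flush=True)
#
#   asyncio.run(_demo())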

# ==============================================================================
# Gradio Chat Function
# ==============================================================================

async def chat_fn(message, history, model_choice="SAM-X-1-Large", temperature=0.7):
    # generate_stream is an async generator, so it must be consumed with
    # `async for`; Gradio accepts async generator callbacks.
    backend = available_models[model_choice]
    prompt = f"User: {message}\nSam: <think>"
    response = ""
    async for chunk in generate_stream(prompt, backend, temperature):
        response += chunk
        yield response


# ==============================================================================
# FastAPI Endpoints (OpenAI-style)
# ==============================================================================

class Message(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str = list(available_models.keys())[0]
    messages: List[Message]
    temperature: float = 0.7
    stream: bool = False
    # Accepted for API compatibility; generation is capped at 512 tokens above.
    max_tokens: int = 512

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if request.model not in available_models:
        raise HTTPException(404, f"Model '{request.model}' not found.")

    backend = available_models[request.model]

    # Flatten the chat history into the "User:/Sam:" prompt format, ending
    # with an open "Sam: <think>" turn for the model to complete.
    prompt_parts = []
    for msg in request.messages:
        prefix = "User" if msg.role.lower() == "user" else "Sam"
        prompt_parts.append(f"{prefix}: {msg.content}")
    prompt_parts.append("Sam: <think>")
    prompt = "\n".join(prompt_parts)

    async def event_stream():
        async for token in generate_stream(prompt, backend, request.temperature):
            chunk = {
                "id": "chatcmpl-123",
                "object": "chat.completion.chunk",
                "created": 1677858242,
                "model": request.model,
                "choices": [{
                    "index": 0,
                    "delta": {"content": token},
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(chunk)}\n\n"
        yield "data: [DONE]\n\n"

    if request.stream:
        return StreamingResponse(event_stream(), media_type="text/event-stream")
    else:
        # Non-streaming: drain the SSE stream and concatenate the deltas.
        full = ""
        async for token in event_stream():
            if "[DONE]" not in token:
                data = json.loads(token.replace("data: ", "").strip())
                full += data["choices"][0]["delta"]["content"]
        return {"choices": [{"message": {"content": full}}]}

@app.get("/v1/models")
async def list_models():
    return {
        "data": [
            {"id": name, "object": "model", "owned_by": "SmilyAI"}
            for name in available_models.keys()
        ]
    }

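# A hedged usage sketch for the endpoints above (the host and port are
# assumptions; 7860 is Gradio's usual default):
#
#   curl -N http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "SAM-X-1-Large", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'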

# ==============================================================================
# Gradio UI
# ==============================================================================

with gr.Blocks(title="SAM-X-1 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 SAM-X-1 Multi-Model Chat")

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.ChatInterface(
                fn=chat_fn,
                additional_inputs=[
                    gr.Dropdown(
                        choices=list(available_models.keys()),
                        value=list(available_models.keys())[0],
                        label="Model"
                    ),
                    gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
                ],
                examples=[
                    "Explain quantum computing like I'm 5.",
                    "Write a haiku about a robot learning to dream."
                ]
            )

# Mount Gradio app on root
app = gr.mount_gradio_app(app, demo, path="/")
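
# Since the Gradio UI is mounted on the FastAPI app, everything is served as
# one ASGI application. A sketch of running it locally (module name and port
# are assumptions; Spaces typically expect 7860):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860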