Bc-AI committed
Commit f59a624 · verified · 1 Parent(s): d0f7483

Create app.py

Files changed (1)
  1. app.py +632 -0
app.py ADDED
@@ -0,0 +1,632 @@
"""
SAM-Z-1 Worker Node - Complete Implementation
Loads model and processes generation requests
"""

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
import json
import os
from tokenizers import Tokenizer
import numpy as np
import time
from typing import List
import asyncio

app = FastAPI(title="SAM-Z-1 Worker", version="1.0.0")

# ============================================================================
# Model Architecture Definitions
# ============================================================================

@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Build RoPE cache on first forward pass"""
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)

            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
            self.built_cache = True

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        self._build_cache()

        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]

        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)

        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config

@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()

        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")

        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)

        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")

        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        # Attention
        res = x
        y = self.pre_attn_norm(x)

        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        q, k = self.rope(q, k)

        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))

        # Causal mask: each position may attend only to itself and earlier tokens
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)

        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        # FFN (SwiGLU)
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config

@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")

        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }

        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)

        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)

        for block in self.blocks:
            x = block(x, training=training)

        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config

# ============================================================================
# Global Variables
# ============================================================================

model = None
tokenizer = None
config = None
eos_token_id = None
fast_forward = None

MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
CACHE_DIR = "./model_cache"

# ============================================================================
# Request Models
# ============================================================================

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False

# ============================================================================
# Generation Functions
# ============================================================================

def generate_tokens(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1
):
    """Core generation function (yields token IDs)"""
    global model, tokenizer, config, eos_token_id, fast_forward

    # Tokenize
    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]

    if len(input_ids) == 0:
        return

    if len(input_ids) > config['max_position_embeddings'] - max_tokens:
        input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):]

    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    token_freq = {}

    for step in range(max_tokens):
        # Get logits
        logits = fast_forward(input_tensor)
        next_token_logits = logits[0, -1, :].numpy()

        # Temperature
        next_token_logits = next_token_logits / temperature

        # Repetition penalty
        if repetition_penalty != 1.0:
            for token_id, freq in token_freq.items():
                if token_id < len(next_token_logits):
                    next_token_logits[token_id] /= (repetition_penalty ** freq)

        # Top-k filtering
        if top_k > 0:
            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
            top_k_logits = next_token_logits[top_k_indices]
            top_k_probs = tf.nn.softmax(top_k_logits).numpy()

            # Top-p sampling
            if top_p < 1.0:
                sorted_indices = np.argsort(top_k_probs)[::-1]
                cumsum = np.cumsum(top_k_probs[sorted_indices])
                cutoff_idx = np.searchsorted(cumsum, top_p)
                nucleus_indices = sorted_indices[:cutoff_idx + 1]

                nucleus_logits = top_k_logits[nucleus_indices]
                nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()

                sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
                next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
            else:
                sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs)
                next_token_id = int(top_k_indices[sampled_idx])
        else:
            probs = tf.nn.softmax(next_token_logits).numpy()
            next_token_id = np.random.choice(len(probs), p=probs)

        # Stop on EOS
        if next_token_id == eos_token_id:
            break

        token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1

        # Yield token
        yield next_token_id

        # Update input
        input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)

        if input_tensor.shape[1] > config['max_position_embeddings']:
            input_tensor = input_tensor[:, -config['max_position_embeddings']:]

def format_chat_prompt(messages: List[ChatMessage]) -> str:
    """Format chat messages into prompt"""
    prompt = ""
    for msg in messages:
        if msg.role == "user":
            prompt += f"<|im_start|>user\n{msg.content}<|im_end|>\n"
        elif msg.role == "assistant":
            prompt += f"<|im_start|>assistant\n{msg.content}<|im_end|>\n"

    prompt += "<|im_start|>assistant\n"
    return prompt

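# Example (illustrative): for messages = [{"role": "user", "content": "Hello!"},
# {"role": "assistant", "content": "Hi there!"}], format_chat_prompt returns the
# following prompt, ending with an open assistant turn for the model to complete:
#
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant
#   Hi there!<|im_end|>
#   <|im_start|>assistant
#
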
# ============================================================================
# API Endpoints
# ============================================================================

@app.get("/")
async def root():
    """Worker info"""
    return {
        "name": "SAM-Z-1 Worker",
        "status": "ready" if model is not None else "loading",
        "model": MODEL_REPO,
        "endpoints": {
            "generate": "/generate",
            "chat": "/chat",
            "health": "/health"
        }
    }

@app.get("/health")
async def health():
    """Health check"""
    return {
        "status": "healthy" if model is not None else "loading",
        "model_loaded": model is not None
    }

@app.post("/generate")
async def generate(request: GenerateRequest):
    """Generate text from prompt"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet, please wait")

    start_time = time.time()

    if request.stream:
        # Streaming response
        async def stream_tokens():
            generated_text = ""
            token_count = 0

            try:
                for token_id in generate_tokens(
                    request.prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty
                ):
                    token_text = tokenizer.decode([token_id])
                    generated_text += token_text
                    token_count += 1

                    # Send chunk
                    yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"

                    # Small delay
                    await asyncio.sleep(0.001)

                # Send final stats
                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'tokens_per_sec': token_count/elapsed if elapsed > 0 else 0})}\n\n"

            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")

    else:
        # Non-streaming response
        generated_text = ""
        token_count = 0

        try:
            for token_id in generate_tokens(
                request.prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                repetition_penalty=request.repetition_penalty
            ):
                token_text = tokenizer.decode([token_id])
                generated_text += token_text
                token_count += 1

            elapsed = time.time() - start_time

            return {
                "text": generated_text,
                "tokens": token_count,
                "time": elapsed,
                "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
            }

        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

@app.post("/chat")
async def chat(request: ChatRequest):
    """Chat completion"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet, please wait")

    # Format prompt
    prompt = format_chat_prompt(request.messages)

    start_time = time.time()

    if request.stream:
        # Streaming
        async def stream_tokens():
            generated_text = ""
            token_count = 0

            try:
                for token_id in generate_tokens(
                    prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty
                ):
                    token_text = tokenizer.decode([token_id])
                    generated_text += token_text
                    token_count += 1

                    # Stop at end tag
                    if "<|im_end|>" in generated_text:
                        generated_text = generated_text.split("<|im_end|>")[0]
                        break

                    yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n"
                    await asyncio.sleep(0.001)

                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'tokens_per_sec': token_count/elapsed if elapsed > 0 else 0})}\n\n"

            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")

    else:
        # Non-streaming
        generated_text = ""
        token_count = 0

        try:
            for token_id in generate_tokens(
                prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                repetition_penalty=request.repetition_penalty
            ):
                token_text = tokenizer.decode([token_id])
                generated_text += token_text
                token_count += 1

                if "<|im_end|>" in generated_text:
                    generated_text = generated_text.split("<|im_end|>")[0]
                    break

            elapsed = time.time() - start_time

            return {
                "message": {
                    "role": "assistant",
                    "content": generated_text.strip()
                },
                "tokens": token_count,
                "time": elapsed,
                "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
            }

        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

# ============================================================================
# Startup: Load Model
# ============================================================================

@app.on_event("startup")
async def load_model():
    """Load model on startup"""
    global model, tokenizer, config, eos_token_id, fast_forward

    print("🚀 Loading SAM-Z-1 Model...")

    try:
        # Download model files
        config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)

        # Try checkpoint first
        try:
            weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
            print("✅ Found checkpoint weights")
            use_checkpoint = True
        except Exception:
            print("⚠️ Checkpoint not found, using model.keras")
            model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
            use_checkpoint = False

        # Load config
        with open(config_path, 'r') as f:
            config = json.load(f)

        print(f"📦 Config loaded: {config['num_hidden_layers']} layers")

        # Create tokenizer
        print("📦 Creating tokenizer...")
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})

        os.makedirs("./temp_tokenizer", exist_ok=True)
        hf_tokenizer.save_pretrained("./temp_tokenizer")
        tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")

        eos_token_id = config.get('eos_token_id', 50256)

        print(f"✅ Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")

        # Load model
        print("🔄 Loading model...")

        if use_checkpoint:
            # Build from config
            model_config = {
                'vocab_size': config['vocab_size'],
                'd_model': config['hidden_size'],
                'n_layers': config['num_hidden_layers'],
                'n_heads': config['num_attention_heads'],
                'ff_mult': config['intermediate_size'] / config['hidden_size'],
                'max_len': config['max_position_embeddings'],
                'dropout': 0.1,
                'rope_theta': config['rope_theta']
            }

            model = SAM1Model(config=model_config)

            # Build with dummy input
            dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
            _ = model(dummy_input, training=False)

            print(f"✅ Architecture built: {model.count_params():,} parameters")

            # Load weights
            model.load_weights(weights_path)
            print("✅ Weights loaded!")

        else:
            # Load full model
            model = keras.models.load_model(model_path, compile=False)
            print("✅ Model loaded!")

        # Create optimized inference function
        @tf.function(reduce_retracing=True)
        def optimized_forward(input_tensor):
            return model(input_tensor, training=False)

        fast_forward = optimized_forward

        print("✅ SAM-Z-1 Worker ready for inference! 🚀")

    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        raise

# ============================================================================
# Launch
# ============================================================================

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )
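For reference, a minimal client sketch against the endpoints defined above. It assumes the worker is running locally on port 7860 (the default in the launch block) and that the `requests` package is installed; payload fields mirror GenerateRequest and ChatRequest.

import requests

BASE_URL = "http://localhost:7860"  # assumption: worker reachable locally on its default port

# Plain completion via POST /generate (non-streaming)
gen = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Once upon a time", "max_tokens": 64, "temperature": 0.8},
    timeout=300,
)
gen.raise_for_status()
result = gen.json()
print(result["text"])
print(f'{result["tokens"]} tokens at {result["tokens_per_second"]:.1f} tok/s')

# Chat completion via POST /chat (non-streaming)
chat = requests.post(
    f"{BASE_URL}/chat",
    json={
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "max_tokens": 128,
    },
    timeout=300,
)
chat.raise_for_status()
print(chat.json()["message"]["content"])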