Bc-AI committed
Commit fa2c7b6 · verified · 1 Parent(s): f59a624

Update app.py

Files changed (1)
  1. app.py +89 -80
app.py CHANGED
@@ -1,6 +1,6 @@
 """
-SAM-Z-1 Worker Node - Complete Implementation
-Loads model and processes generation requests
+SAM-Z-1 Smart Worker Node
+Supports both full generation and gen/decode split modes
 """
 
 from fastapi import FastAPI, HTTPException
@@ -14,13 +14,13 @@ import os
 from tokenizers import Tokenizer
 import numpy as np
 import time
-from typing import List
+from typing import List, Optional
 import asyncio
 
-app = FastAPI(title="SAM-Z-1 Worker", version="1.0.0")
+app = FastAPI(title="SAM-Z-1 Smart Worker", version="3.0.0")
 
 # ============================================================================
-# Model Architecture Definitions
+# Model Architecture (same as before)
 # ============================================================================
 
 @keras.saving.register_keras_serializable()
@@ -36,7 +36,6 @@ class RotaryEmbedding(keras.layers.Layer):
         super().build(input_shape)
 
     def _build_cache(self):
-        """Build RoPE cache on first forward pass"""
         if not self.built_cache:
             inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
             t = tf.range(self.max_len, dtype=tf.float32)
@@ -53,7 +52,6 @@ class RotaryEmbedding(keras.layers.Layer):
 
     def call(self, q, k):
         self._build_cache()
-
         seq_len = tf.shape(q)[2]
         dtype = q.dtype
         cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
@@ -69,7 +67,6 @@ class RotaryEmbedding(keras.layers.Layer):
         config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
         return config
 
-
 @keras.saving.register_keras_serializable()
 class RMSNorm(keras.layers.Layer):
     def __init__(self, epsilon=1e-5, **kwargs):
@@ -88,7 +85,6 @@ class RMSNorm(keras.layers.Layer):
         config.update({"epsilon": self.epsilon})
         return config
 
-
 @keras.saving.register_keras_serializable()
 class TransformerBlock(keras.layers.Layer):
     def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
@@ -122,7 +118,6 @@ class TransformerBlock(keras.layers.Layer):
         B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
         dtype = x.dtype
 
-        # Attention
         res = x
         y = self.pre_attn_norm(x)
 
@@ -133,7 +128,6 @@ class TransformerBlock(keras.layers.Layer):
         q, k = self.rope(q, k)
 
         scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
-
         mask = tf.where(
             tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
             tf.constant(-1e9, dtype=dtype),
@@ -145,7 +139,6 @@ class TransformerBlock(keras.layers.Layer):
         attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
         x = res + self.dropout(self.out_proj(attn), training=training)
 
-        # FFN (SwiGLU)
         res = x
         y = self.pre_ffn_norm(x)
         ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
@@ -165,7 +158,6 @@ class TransformerBlock(keras.layers.Layer):
         })
         return config
 
-
 @keras.saving.register_keras_serializable()
 class SAM1Model(keras.Model):
     def __init__(self, **kwargs):
@@ -199,10 +191,8 @@ class SAM1Model(keras.Model):
 
     def call(self, input_ids, training=None):
         x = self.embed(input_ids)
-
         for block in self.blocks:
             x = block(x, training=training)
-
         return self.lm_head(self.norm(x))
 
     def get_config(self):
@@ -235,6 +225,7 @@ class GenerateRequest(BaseModel):
     top_p: float = 0.9
     repetition_penalty: float = 1.1
     stream: bool = False
+    return_token_ids: bool = False  # NEW: for gen/decode split
 
 class ChatMessage(BaseModel):
     role: str
@@ -248,6 +239,10 @@ class ChatRequest(BaseModel):
     top_p: float = 0.9
     repetition_penalty: float = 1.1
     stream: bool = False
+    return_token_ids: bool = False  # NEW
+
+class DecodeRequest(BaseModel):
+    token_ids: List[int]
 
 # ============================================================================
 # Generation Functions
@@ -259,12 +254,16 @@ def generate_tokens(
     temperature: float = 0.8,
     top_k: int = 40,
     top_p: float = 0.9,
-    repetition_penalty: float = 1.1
+    repetition_penalty: float = 1.1,
+    return_token_ids: bool = False
 ):
-    """Core generation function (yields token IDs)"""
+    """
+    Core generation function
+    If return_token_ids=True, yields (token_id, None)
+    If return_token_ids=False, yields (token_id, token_text)
+    """
     global model, tokenizer, config, eos_token_id, fast_forward
 
-    # Tokenize
     input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
 
     if len(input_ids) == 0:
@@ -277,26 +276,21 @@ def generate_tokens(
     token_freq = {}
 
     for step in range(max_tokens):
-        # Get logits
         logits = fast_forward(input_tensor)
         next_token_logits = logits[0, -1, :].numpy()
 
-        # Temperature
         next_token_logits = next_token_logits / temperature
 
-        # Repetition penalty
         if repetition_penalty != 1.0:
             for token_id, freq in token_freq.items():
                 if token_id < len(next_token_logits):
                     next_token_logits[token_id] /= (repetition_penalty ** freq)
 
-        # Top-k filtering
         if top_k > 0:
             top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
             top_k_logits = next_token_logits[top_k_indices]
             top_k_probs = tf.nn.softmax(top_k_logits).numpy()
 
-        # Top-p sampling
         if top_p < 1.0:
             sorted_indices = np.argsort(top_k_probs)[::-1]
             cumsum = np.cumsum(top_k_probs[sorted_indices])
@@ -315,16 +309,18 @@ def generate_tokens(
         probs = tf.nn.softmax(next_token_logits).numpy()
         next_token_id = np.random.choice(len(probs), p=probs)
 
-        # Stop on EOS
         if next_token_id == eos_token_id:
             break
 
         token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
 
-        # Yield token
-        yield next_token_id
+        # Yield token ID and optionally decoded text
+        if return_token_ids:
+            yield (next_token_id, None)
+        else:
+            token_text = tokenizer.decode([next_token_id])
+            yield (next_token_id, token_text)
 
-        # Update input
         input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)
 
         if input_tensor.shape[1] > config['max_position_embeddings']:
@@ -350,12 +346,15 @@ def format_chat_prompt(messages: List[ChatMessage]) -> str:
 async def root():
     """Worker info"""
     return {
-        "name": "SAM-Z-1 Worker",
+        "name": "SAM-Z-1 Smart Worker",
+        "version": "3.0.0",
         "status": "ready" if model is not None else "loading",
         "model": MODEL_REPO,
+        "features": ["full_generation", "token_only_mode", "decode_only_mode"],
         "endpoints": {
             "generate": "/generate",
             "chat": "/chat",
+            "decode": "/decode",
             "health": "/health"
         }
     }
@@ -368,11 +367,27 @@ async def health():
         "model_loaded": model is not None
     }
 
+@app.post("/decode")
+async def decode(request: DecodeRequest):
+    """
+    DECODE ONLY endpoint
+    Takes token IDs and returns decoded text
+    This is the bottleneck we're parallelizing!
+    """
+    if tokenizer is None:
+        raise HTTPException(status_code=503, detail="Tokenizer not loaded")
+
+    try:
+        text = tokenizer.decode(request.token_ids)
+        return {"text": text}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Decode error: {str(e)}")
+
 @app.post("/generate")
 async def generate(request: GenerateRequest):
-    """Generate text from prompt"""
+    """Generate text - supports both full gen and token-only mode"""
     if model is None:
-        raise HTTPException(status_code=503, detail="Model not loaded yet, please wait")
+        raise HTTPException(status_code=503, detail="Model not loaded yet")
 
     start_time = time.time()
 
@@ -383,27 +398,29 @@ async def generate(request: GenerateRequest):
             token_count = 0
 
             try:
-                for token_id in generate_tokens(
+                for token_id, token_text in generate_tokens(
                     request.prompt,
                     max_tokens=request.max_tokens,
                     temperature=request.temperature,
                     top_k=request.top_k,
                     top_p=request.top_p,
-                    repetition_penalty=request.repetition_penalty
+                    repetition_penalty=request.repetition_penalty,
+                    return_token_ids=request.return_token_ids
                 ):
-                    token_text = tokenizer.decode([token_id])
-                    generated_text += token_text
                     token_count += 1
 
-                    # Send chunk
-                    yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"
+                    if request.return_token_ids:
+                        # TOKEN-ONLY mode for gen/decode split
+                        yield f"data: {json.dumps({'token_id': token_id})}\n\n"
+                    else:
+                        # FULL mode with text
+                        generated_text += token_text
+                        yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"
 
-                    # Small delay
                     await asyncio.sleep(0.001)
 
-                # Send final stats
                 elapsed = time.time() - start_time
-                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'tokens_per_sec': token_count/elapsed if elapsed > 0 else 0})}\n\n"
+                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
 
             except Exception as e:
                 yield f"data: {json.dumps({'error': str(e)})}\n\n"
@@ -411,21 +428,22 @@ async def generate(request: GenerateRequest):
         return StreamingResponse(stream_tokens(), media_type="text/event-stream")
 
     else:
-        # Non-streaming response
+        # Non-streaming
        generated_text = ""
        token_count = 0
 
        try:
-            for token_id in generate_tokens(
+            for token_id, token_text in generate_tokens(
                request.prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
-                repetition_penalty=request.repetition_penalty
+                repetition_penalty=request.repetition_penalty,
+                return_token_ids=request.return_token_ids
            ):
-                token_text = tokenizer.decode([token_id])
-                generated_text += token_text
+                if not request.return_token_ids:
+                    generated_text += token_text
                token_count += 1
 
            elapsed = time.time() - start_time
@@ -442,44 +460,45 @@
 
 @app.post("/chat")
 async def chat(request: ChatRequest):
-    """Chat completion"""
+    """Chat completion - supports both modes"""
     if model is None:
-        raise HTTPException(status_code=503, detail="Model not loaded yet, please wait")
+        raise HTTPException(status_code=503, detail="Model not loaded yet")
 
-    # Format prompt
     prompt = format_chat_prompt(request.messages)
-
     start_time = time.time()
 
     if request.stream:
-        # Streaming
        async def stream_tokens():
            generated_text = ""
            token_count = 0
 
            try:
-                for token_id in generate_tokens(
+                for token_id, token_text in generate_tokens(
                    prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
-                    repetition_penalty=request.repetition_penalty
+                    repetition_penalty=request.repetition_penalty,
+                    return_token_ids=request.return_token_ids
                ):
-                    token_text = tokenizer.decode([token_id])
-                    generated_text += token_text
                    token_count += 1
 
-                    # Stop at end tag
-                    if "<|im_end|>" in generated_text:
-                        generated_text = generated_text.split("<|im_end|>")[0]
-                        break
+                    if request.return_token_ids:
+                        yield f"data: {json.dumps({'token_id': token_id})}\n\n"
+                    else:
+                        generated_text += token_text
+
+                        if "<|im_end|>" in generated_text:
+                            generated_text = generated_text.split("<|im_end|>")[0]
+                            break
+
+                        yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n"
 
-                    yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n"
                    await asyncio.sleep(0.001)
 
                elapsed = time.time() - start_time
-                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'tokens_per_sec': token_count/elapsed if elapsed > 0 else 0})}\n\n"
+                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
 
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
@@ -487,26 +506,27 @@ async def chat(request: ChatRequest):
        return StreamingResponse(stream_tokens(), media_type="text/event-stream")
 
    else:
-        # Non-streaming
        generated_text = ""
        token_count = 0
 
        try:
-            for token_id in generate_tokens(
+            for token_id, token_text in generate_tokens(
                prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
-                repetition_penalty=request.repetition_penalty
+                repetition_penalty=request.repetition_penalty,
+                return_token_ids=request.return_token_ids
            ):
-                token_text = tokenizer.decode([token_id])
-                generated_text += token_text
-                token_count += 1
+                if not request.return_token_ids:
+                    generated_text += token_text
+
+                    if "<|im_end|>" in generated_text:
+                        generated_text = generated_text.split("<|im_end|>")[0]
+                        break
 
-                if "<|im_end|>" in generated_text:
-                    generated_text = generated_text.split("<|im_end|>")[0]
-                    break
+                token_count += 1
 
            elapsed = time.time() - start_time
 
@@ -535,10 +555,8 @@ async def load_model():
     print("🚀 Loading SAM-Z-1 Model...")
 
     try:
-        # Download model files
        config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 
-        # Try checkpoint first
        try:
            weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
            print("✅ Found checkpoint weights")
@@ -548,13 +566,11 @@ async def load_model():
            model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
            use_checkpoint = False
 
-        # Load config
        with open(config_path, 'r') as f:
            config = json.load(f)
 
        print(f"📦 Config loaded: {config['num_hidden_layers']} layers")
 
-        # Create tokenizer
        print("📦 Creating tokenizer...")
        from transformers import AutoTokenizer
 
@@ -570,11 +586,9 @@ async def load_model():
 
        print(f"✅ Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")
 
-        # Load model
        print("🔄 Loading model...")
 
        if use_checkpoint:
-            # Build from config
            model_config = {
                'vocab_size': config['vocab_size'],
                'd_model': config['hidden_size'],
@@ -587,30 +601,25 @@ async def load_model():
            }
 
            model = SAM1Model(config=model_config)
-
-            # Build with dummy input
            dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
            _ = model(dummy_input, training=False)
 
            print(f"✅ Architecture built: {model.count_params():,} parameters")
 
-            # Load weights
            model.load_weights(weights_path)
            print("✅ Weights loaded!")
 
        else:
-            # Load full model
            model = keras.models.load_model(model_path, compile=False)
            print("✅ Model loaded!")
 
-        # Create optimized inference function
        @tf.function(reduce_retracing=True)
        def optimized_forward(input_tensor):
            return model(input_tensor, training=False)
 
        fast_forward = optimized_forward
 
-        print("✅ SAM-Z-1 Worker ready for inference! 🚀")
+        print("✅ SAM-Z-1 Smart Worker ready! 🚀")
 
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
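For reference, here is a minimal client sketch (not part of the commit) showing how the new gen/decode split could be driven from the outside: it streams raw token IDs from one worker's /generate endpoint with return_token_ids=true, then hands the collected IDs to another worker's /decode endpoint. The worker URLs, the port, and the use of the requests library are assumptions for illustration only; the endpoint paths, request fields, and SSE "data:" payload keys are taken from the diff above.

# Hypothetical client for the gen/decode split mode (illustrative sketch only).
# GEN_URL and DECODE_URL are assumed addresses of two SAM-Z-1 workers.
import json
import requests

GEN_URL = "http://gen-worker:7860"        # assumed generation worker
DECODE_URL = "http://decode-worker:7860"  # assumed decode worker

def generate_and_decode(prompt: str, max_tokens: int = 64) -> str:
    token_ids = []

    # Stream token IDs only; each SSE line looks like: data: {"token_id": 123}
    with requests.post(
        f"{GEN_URL}/generate",
        json={"prompt": prompt, "max_tokens": max_tokens,
              "stream": True, "return_token_ids": True},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            event = json.loads(line[len(b"data: "):])
            if event.get("done") or event.get("error"):
                break
            if "token_id" in event:
                token_ids.append(event["token_id"])

    # Hand the raw IDs to the decode-only endpoint.
    decoded = requests.post(f"{DECODE_URL}/decode", json={"token_ids": token_ids})
    return decoded.json()["text"]

if __name__ == "__main__":
    print(generate_and_decode("Hello, SAM!"))

In this layout the generation worker never calls tokenizer.decode inside its sampling loop; that per-token decode is exactly the cost the new /decode endpoint is meant to absorb on a separate worker.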