Bc-AI committed · verified
Commit 120f320 · 1 Parent(s): 1055547

Update app.py

Files changed (1)
  1. app.py +398 -191
app.py CHANGED
@@ -1,6 +1,9 @@
1
  """
2
- SAM-Z-1 Distributed Worker Node v4.0
3
- Optimized for distributed gen/decode pipeline
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException
@@ -14,10 +17,56 @@ import os
14
  from tokenizers import Tokenizer
15
  import numpy as np
16
  import time
17
- from typing import List, Optional
18
  import asyncio
19
 
20
- app = FastAPI(title="SAM-Z-1 Distributed Worker", version="4.0.0")
21
 
22
  # ============================================================================
23
  # Model Architecture
@@ -201,24 +250,19 @@ class SAM1Model(keras.Model):
201
  return base_config
202
 
203
  # ============================================================================
204
- # Global State
205
  # ============================================================================
206
 
207
- model = None
208
- tokenizer = None
209
- config = None
210
- eos_token_id = None
211
- fast_forward = None
212
-
213
- MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
214
- CACHE_DIR = "./model_cache"
215
 
216
- # Stats
217
  worker_stats = {
218
  "total_requests": 0,
219
  "total_tokens": 0,
220
  "decode_requests": 0,
221
- "uptime_start": time.time()
 
222
  }
223
 
224
  # ============================================================================
@@ -234,6 +278,7 @@ class GenerateRequest(BaseModel):
234
  repetition_penalty: float = 1.1
235
  stream: bool = False
236
  return_token_ids: bool = False
 
237
 
238
  class ChatMessage(BaseModel):
239
  role: str
@@ -248,12 +293,70 @@ class ChatRequest(BaseModel):
248
  repetition_penalty: float = 1.1
249
  stream: bool = False
250
  return_token_ids: bool = False
 
251
 
252
  class DecodeRequest(BaseModel):
253
  token_ids: List[int]
 
254
 
255
  class BatchDecodeRequest(BaseModel):
256
  batches: List[List[int]]
257
 
258
  # ============================================================================
259
  # Generation Functions
@@ -266,11 +369,22 @@ def generate_tokens(
266
  top_k: int = 40,
267
  top_p: float = 0.9,
268
  repetition_penalty: float = 1.1,
269
- return_token_ids: bool = False
 
270
  ):
271
- """Core generation - yields (token_id, token_text or None)"""
272
- global model, tokenizer, config, eos_token_id, fast_forward
273
 
274
  input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
275
 
276
  if len(input_ids) == 0:
@@ -349,26 +463,29 @@ def format_chat_prompt(messages: List[ChatMessage]) -> str:
349
 
350
  @app.get("/", response_class=HTMLResponse)
351
  async def status_page():
352
- """Worker status page"""
353
- return """
354
  <!DOCTYPE html>
355
  <html>
356
  <head>
357
- <title>SAM-Z-1 Worker Node</title>
358
  <style>
359
- * { margin: 0; padding: 0; box-sizing: border-box; }
360
- body {
361
  font-family: 'Courier New', monospace;
362
  background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
363
  color: #00bfff;
364
  padding: 20px;
365
  min-height: 100vh;
366
- }
367
- .container {
368
- max-width: 900px;
369
- margin: 0 auto;
370
- }
371
- .header {
372
  text-align: center;
373
  padding: 30px;
374
  background: rgba(0, 191, 255, 0.1);
@@ -376,93 +493,77 @@ async def status_page():
376
  border-radius: 10px;
377
  margin-bottom: 30px;
378
  box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
379
- }
380
- .header h1 {
381
  font-size: 2.5em;
382
  text-transform: uppercase;
383
  letter-spacing: 3px;
384
  animation: glow 2s ease-in-out infinite alternate;
385
- }
386
- @keyframes glow {
387
- from { text-shadow: 0 0 10px #00bfff; }
388
- to { text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }
389
- }
390
- .badge {
391
  display: inline-block;
392
  padding: 5px 15px;
393
  border-radius: 15px;
394
  font-size: 0.9em;
395
- margin-top: 10px;
396
- }
397
- .badge-ready {
398
  background: rgba(0, 255, 136, 0.2);
399
  border: 1px solid #00ff88;
400
  color: #00ff88;
401
- }
402
- .badge-loading {
403
  background: rgba(255, 165, 0, 0.2);
404
  border: 1px solid #ffa500;
405
  color: #ffa500;
406
- }
407
- .stats-grid {
408
  display: grid;
409
  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
410
  gap: 20px;
411
  margin-bottom: 30px;
412
- }
413
- .stat-card {
414
  background: rgba(0, 191, 255, 0.05);
415
  border: 1px solid #00bfff;
416
  border-radius: 8px;
417
  padding: 20px;
418
  text-align: center;
419
- }
420
- .stat-label {
421
- font-size: 0.8em;
422
- opacity: 0.7;
423
- text-transform: uppercase;
424
- margin-bottom: 10px;
425
- }
426
- .stat-value {
427
- font-size: 2em;
428
- font-weight: bold;
429
- }
430
- .features {
431
  background: rgba(0, 191, 255, 0.05);
432
  border: 1px solid #00bfff;
433
  border-radius: 8px;
434
  padding: 20px;
435
- }
436
- .features h3 {
437
- margin-bottom: 15px;
438
- }
439
- .feature-list {
440
- list-style: none;
441
- padding: 0;
442
- }
443
- .feature-list li {
444
  padding: 10px;
445
  margin: 5px 0;
446
  background: rgba(0, 191, 255, 0.1);
447
  border-radius: 5px;
448
- }
449
- .feature-list li:before {
450
- content: "⚑ ";
451
- color: #00ff88;
452
- }
453
- .timestamp {
454
- text-align: center;
455
- margin-top: 20px;
456
- opacity: 0.5;
457
- }
458
  </style>
459
  </head>
460
  <body>
461
  <div class="container">
462
  <div class="header">
463
  <h1>⚙️ WORKER NODE ⚙️</h1>
464
- <div>SAM-Z-1 Distributed Worker v4.0</div>
465
- <div class="badge" id="status-badge">CHECKING STATUS...</div>
466
  </div>
467
 
468
  <div class="stats-grid" id="stats">
@@ -484,14 +585,23 @@ async def status_page():
484
  </div>
485
  </div>
486
 
487
  <div class="features">
488
  <h3>🚀 CAPABILITIES</h3>
489
  <ul class="feature-list">
490
- <li>Full Text Generation</li>
491
- <li>Token-Only Mode (for distributed pipeline)</li>
492
- <li>High-Speed Batch Decoding</li>
493
- <li>Chat Completion</li>
494
- <li>Streaming & Non-Streaming</li>
495
  </ul>
496
  </div>
497
 
@@ -499,21 +609,8 @@ async def status_page():
499
  </div>
500
 
501
  <script>
502
- async function updateStats() {
503
- try {
504
- const response = await fetch('/health');
505
- const data = await response.json();
506
-
507
- const badge = document.getElementById('status-badge');
508
- if (data.model_loaded) {
509
- badge.textContent = 'βœ… READY FOR INFERENCE';
510
- badge.className = 'badge badge-ready';
511
- } else {
512
- badge.textContent = '⏳ LOADING MODEL...';
513
- badge.className = 'badge badge-loading';
514
- }
515
-
516
- // Fetch stats
517
  const statsRes = await fetch('/stats');
518
  const stats = await statsRes.json();
519
 
@@ -525,16 +622,15 @@ async def status_page():
525
  const h = Math.floor(uptime / 3600);
526
  const m = Math.floor((uptime % 3600) / 60);
527
  const s = uptime % 60;
528
- document.getElementById('uptime').textContent = `${h}h ${m}m ${s}s`;
529
 
530
  document.getElementById('timestamp').textContent =
531
- `Last update: ${new Date().toLocaleTimeString()}`;
532
- } catch (e) {
533
  console.error('Failed to update stats:', e);
534
- }
535
- }
536
 
537
- // Update every second
538
  setInterval(updateStats, 1000);
539
  updateStats();
540
  </script>
@@ -549,8 +645,38 @@ async def status_page():
549
  @app.get("/health")
550
  async def health():
551
  return {
552
- "status": "healthy" if model is not None else "loading",
553
- "model_loaded": model is not None
 
554
  }
555
 
556
  @app.get("/stats")
@@ -561,17 +687,16 @@ async def stats():
561
  "total_tokens": worker_stats["total_tokens"],
562
  "decode_requests": worker_stats["decode_requests"],
563
  "uptime": uptime,
564
- "tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0
 
565
  }
566
 
567
  @app.post("/decode")
568
  async def decode(request: DecodeRequest):
569
- """Fast single decode"""
570
- if tokenizer is None:
571
- raise HTTPException(status_code=503, detail="Tokenizer not loaded")
572
-
573
  try:
574
  worker_stats["decode_requests"] += 1
 
575
  text = tokenizer.decode(request.token_ids)
576
  return {"text": text}
577
  except Exception as e:
@@ -579,12 +704,10 @@ async def decode(request: DecodeRequest):
579
 
580
  @app.post("/decode/batch")
581
  async def batch_decode(request: BatchDecodeRequest):
582
- """Optimized batch decoding for distributed pipeline"""
583
- if tokenizer is None:
584
- raise HTTPException(status_code=503, detail="Tokenizer not loaded")
585
-
586
  try:
587
  worker_stats["decode_requests"] += len(request.batches)
 
588
  results = [tokenizer.decode(batch) for batch in request.batches]
589
  return {"texts": results}
590
  except Exception as e:
@@ -592,9 +715,15 @@ async def batch_decode(request: BatchDecodeRequest):
592
 
593
  @app.post("/generate")
594
  async def generate(request: GenerateRequest):
595
- """Generate text"""
596
- if model is None:
597
- raise HTTPException(status_code=503, detail="Model not loaded")
598
 
599
  worker_stats["total_requests"] += 1
600
  start_time = time.time()
@@ -612,7 +741,8 @@ async def generate(request: GenerateRequest):
612
  top_k=request.top_k,
613
  top_p=request.top_p,
614
  repetition_penalty=request.repetition_penalty,
615
- return_token_ids=request.return_token_ids
 
616
  ):
617
  token_count += 1
618
  worker_stats["total_tokens"] += 1
@@ -626,7 +756,7 @@ async def generate(request: GenerateRequest):
626
  await asyncio.sleep(0.001)
627
 
628
  elapsed = time.time() - start_time
629
- yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
630
 
631
  except Exception as e:
632
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
@@ -645,7 +775,8 @@ async def generate(request: GenerateRequest):
645
  top_k=request.top_k,
646
  top_p=request.top_p,
647
  repetition_penalty=request.repetition_penalty,
648
- return_token_ids=request.return_token_ids
 
649
  ):
650
  if not request.return_token_ids:
651
  generated_text += token_text
@@ -658,7 +789,8 @@ async def generate(request: GenerateRequest):
658
  "text": generated_text,
659
  "tokens": token_count,
660
  "time": elapsed,
661
- "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
 
662
  }
663
 
664
  except Exception as e:
@@ -666,9 +798,15 @@ async def generate(request: GenerateRequest):
666
 
667
  @app.post("/chat")
668
  async def chat(request: ChatRequest):
669
- """Chat completion"""
670
- if model is None:
671
- raise HTTPException(status_code=503, detail="Model not loaded")
672
 
673
  worker_stats["total_requests"] += 1
674
  prompt = format_chat_prompt(request.messages)
@@ -687,7 +825,8 @@ async def chat(request: ChatRequest):
687
  top_k=request.top_k,
688
  top_p=request.top_p,
689
  repetition_penalty=request.repetition_penalty,
690
- return_token_ids=request.return_token_ids
 
691
  ):
692
  token_count += 1
693
  worker_stats["total_tokens"] += 1
@@ -706,7 +845,7 @@ async def chat(request: ChatRequest):
706
  await asyncio.sleep(0.001)
707
 
708
  elapsed = time.time() - start_time
709
- yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
710
 
711
  except Exception as e:
712
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
@@ -725,7 +864,8 @@ async def chat(request: ChatRequest):
725
  top_k=request.top_k,
726
  top_p=request.top_p,
727
  repetition_penalty=request.repetition_penalty,
728
- return_token_ids=request.return_token_ids
 
729
  ):
730
  if not request.return_token_ids:
731
  generated_text += token_text
@@ -746,7 +886,8 @@ async def chat(request: ChatRequest):
746
  },
747
  "tokens": token_count,
748
  "time": elapsed,
749
- "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
 
750
  }
751
 
752
  except Exception as e:
@@ -756,86 +897,152 @@ async def chat(request: ChatRequest):
756
  # Model Loading
757
  # ============================================================================
758
 
759
- @app.on_event("startup")
760
- async def load_model():
761
- global model, tokenizer, config, eos_token_id, fast_forward
762
-
763
- print("πŸš€ Loading SAM-Z-1 Model...")
764
 
765
  try:
766
- config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
767
 
768
- try:
769
- weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
770
- print("βœ… Found checkpoint weights")
771
- use_checkpoint = True
772
- except:
773
- print("⚠️ Checkpoint not found, using model.keras")
774
- model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
775
- use_checkpoint = False
776
 
777
- with open(config_path, 'r') as f:
778
- config = json.load(f)
779
 
780
- print(f"πŸ“¦ Config loaded: {config['num_hidden_layers']} layers")
781
 
782
- print("πŸ“¦ Creating tokenizer...")
783
- from transformers import AutoTokenizer
784
 
785
- hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
786
- custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
787
- hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
788
 
789
- os.makedirs("./temp_tokenizer", exist_ok=True)
790
- hf_tokenizer.save_pretrained("./temp_tokenizer")
791
- tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
792
 
793
- eos_token_id = config.get('eos_token_id', 50256)
794
 
795
- print(f"βœ… Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")
796
 
797
- print("πŸ”„ Loading model...")
798
 
799
- if use_checkpoint:
800
- model_config = {
801
- 'vocab_size': config['vocab_size'],
802
- 'd_model': config['hidden_size'],
803
- 'n_layers': config['num_hidden_layers'],
804
- 'n_heads': config['num_attention_heads'],
805
- 'ff_mult': config['intermediate_size'] / config['hidden_size'],
806
- 'max_len': config['max_position_embeddings'],
807
- 'dropout': 0.1,
808
- 'rope_theta': config['rope_theta']
809
- }
810
-
811
- model = SAM1Model(config=model_config)
812
- dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
813
- _ = model(dummy_input, training=False)
814
-
815
- print(f"βœ… Architecture built: {model.count_params():,} parameters")
816
-
817
- model.load_weights(weights_path)
818
- print("βœ… Weights loaded!")
819
 
 
820
  else:
821
- model = keras.models.load_model(model_path, compile=False)
822
- print("βœ… Model loaded!")
823
 
824
- @tf.function(reduce_retracing=True)
825
- def optimized_forward(input_tensor):
826
- return model(input_tensor, training=False)
827
 
828
- fast_forward = optimized_forward
829
 
830
- print("βœ… SAM-Z-1 Distributed Worker ready! πŸš€")
831
- print("πŸ”₯ Features enabled:")
832
- print(" - Full text generation")
833
- print(" - Token-only mode (distributed pipeline)")
834
- print(" - Batch decoding optimization")
835
- print(" - Streaming support")
836
 
837
  except Exception as e:
838
- print(f"❌ Failed to load model: {e}")
839
  import traceback
840
  traceback.print_exc()
841
  raise
 
1
  """
2
+ SAM-Z-1 Distributed Worker Node v5.0
3
+ - Supports BOTH old SAM-Z-1 AND 4 new SAM-X-1 models
4
+ - Different tokenizers and vocabularies per model family
5
+ - Auto version detection
6
+ - Backward compatible with v4 head nodes
7
  """
8
 
9
  from fastapi import FastAPI, HTTPException
 
17
  from tokenizers import Tokenizer
18
  import numpy as np
19
  import time
20
+ from typing import List, Optional, Dict
21
  import asyncio
22
 
23
+ app = FastAPI(title="SAM-Z-1 Distributed Worker", version="5.0.0")
24
+
25
+ # ============================================================================
26
+ # Configuration - ALL 5 MODELS
27
+ # ============================================================================
28
+
29
+ MODEL_REGISTRY = {
30
+ # Original SAM-Z-1 (keep this!)
31
+ "SAM-Z-1": {
32
+ "repo": "Smilyai-labs/Sam-Z-1-tensorflow",
33
+ "weights": "ckpt.weights.h5",
34
+ "config": "config.json",
35
+ "tokenizer_repo": "Smilyai-labs/Sam-Z-1-tensorflow",
36
+ "family": "sam-z" # Different tokenizer family
37
+ },
38
+ # New SAM-X-1 family (different tokenizer!)
39
+ "SAM-X-1-Large": {
40
+ "repo": "Smilyai-labs/Sam-1x-instruct",
41
+ "weights": "ckpt.weights.h5",
42
+ "config": None,
43
+ "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
44
+ "family": "sam-x"
45
+ },
46
+ "SAM-X-1-Fast": {
47
+ "repo": "Smilyai-labs/Sam-X-1-fast",
48
+ "weights": "sam1_fast_finetuned.weights.h5",
49
+ "config": "sam1_fast_finetuned_config.json",
50
+ "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
51
+ "family": "sam-x"
52
+ },
53
+ "SAM-X-1-Mini": {
54
+ "repo": "Smilyai-labs/Sam-X-1-Mini",
55
+ "weights": "sam1_mini_finetuned.weights.h5",
56
+ "config": "sam1_mini_finetuned_config.json",
57
+ "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
58
+ "family": "sam-x"
59
+ },
60
+ "SAM-X-1-Nano": {
61
+ "repo": "Smilyai-labs/Sam-X-1-Nano",
62
+ "weights": "sam1_nano_finetuned.weights.h5",
63
+ "config": "sam1_nano_finetuned_config.json",
64
+ "tokenizer_repo": "Smilyai-labs/Sam-1-large-it-0002",
65
+ "family": "sam-x"
66
+ }
67
+ }
68
+
69
+ CACHE_DIR = "./model_cache"
70
 
71
  # ============================================================================
72
  # Model Architecture
 
250
  return base_config
251
 
252
  # ============================================================================
253
+ # Global State - Separate tokenizers per family!
254
  # ============================================================================
255
 
256
+ loaded_models = {} # Dict[model_name, (model, fast_forward, config, tokenizer, eos_token_id)]
257
+ tokenizer_cache = {} # Dict[family, (tokenizer, eos_token_id)]
258
+ current_model = None
259
 
 
260
  worker_stats = {
261
  "total_requests": 0,
262
  "total_tokens": 0,
263
  "decode_requests": 0,
264
+ "uptime_start": time.time(),
265
+ "model_usage": {}
266
  }
267
 
268
  # ============================================================================
 
278
  repetition_penalty: float = 1.1
279
  stream: bool = False
280
  return_token_ids: bool = False
281
+ model: Optional[str] = None
282
 
283
  class ChatMessage(BaseModel):
284
  role: str
 
293
  repetition_penalty: float = 1.1
294
  stream: bool = False
295
  return_token_ids: bool = False
296
+ model: Optional[str] = None
297
 
298
  class DecodeRequest(BaseModel):
299
  token_ids: List[int]
300
+ model: Optional[str] = None # Need to know which tokenizer to use!
301
 
302
  class BatchDecodeRequest(BaseModel):
303
  batches: List[List[int]]
304
+ model: Optional[str] = None
305
+
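A quick illustration of the new optional `model` field on the request schemas above (a minimal client-side sketch, not part of this commit: the worker URL, the `requests` dependency, the example model name, and the example token ids are assumptions):

import requests

WORKER_URL = "http://localhost:8000"  # assumed local worker address

# Pick a specific model; omitting "model" falls back to the worker's current default.
resp = requests.post(f"{WORKER_URL}/generate", json={
    "prompt": "Hello",          # field name assumed from the surrounding handler code
    "model": "SAM-X-1-Mini",
}).json()
print(resp.get("text"), resp.get("model"))

# Decoding must name the same model so the matching tokenizer family is used.
decoded = requests.post(f"{WORKER_URL}/decode", json={
    "token_ids": [15496, 995],  # illustrative ids only
    "model": "SAM-X-1-Mini",
}).json()
print(decoded["text"])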
306
+ # ============================================================================
307
+ # Tokenizer Management
308
+ # ============================================================================
309
+
310
+ async def load_tokenizer(family: str, repo: str) -> tuple:
311
+ """Load tokenizer for a model family"""
312
+ if family in tokenizer_cache:
313
+ return tokenizer_cache[family]
314
+
315
+ print(f" πŸ”€ Loading tokenizer for {family} family from {repo}...")
316
+
317
+ try:
318
+ from transformers import AutoTokenizer
319
+
320
+ hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
321
+ custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
322
+ hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
323
+
324
+ os.makedirs(f"./temp_tokenizer_{family}", exist_ok=True)
325
+ hf_tokenizer.save_pretrained(f"./temp_tokenizer_{family}")
326
+ tokenizer = Tokenizer.from_file(f"./temp_tokenizer_{family}/tokenizer.json")
327
+
328
+ eos_token = "<|endoftext|>"
329
+ eos_token_id = tokenizer.token_to_id(eos_token)
330
+
331
+ if eos_token_id is None:
332
+ tokenizer.add_special_tokens([eos_token])
333
+ eos_token_id = tokenizer.token_to_id(eos_token)
334
+
335
+ tokenizer_cache[family] = (tokenizer, eos_token_id)
336
+ print(f" βœ… Tokenizer ready (vocab size: {tokenizer.get_vocab_size()}, EOS: {eos_token_id})")
337
+
338
+ return tokenizer, eos_token_id
339
+
340
+ except Exception as e:
341
+ print(f" ⚠️ Tokenizer load failed: {e}")
342
+ raise
343
+
344
+ def get_tokenizer_for_model(model_name: str):
345
+ """Get the correct tokenizer for a model"""
346
+ if not model_name or model_name not in loaded_models:
347
+ model_name = current_model
348
+
349
+ if model_name in loaded_models:
350
+ _, _, _, tokenizer, eos_id = loaded_models[model_name]
351
+ return tokenizer, eos_id
352
+
353
+ # Fallback to first available
354
+ if loaded_models:
355
+ first_model = list(loaded_models.keys())[0]
356
+ _, _, _, tokenizer, eos_id = loaded_models[first_model]
357
+ return tokenizer, eos_id
358
+
359
+ raise HTTPException(status_code=503, detail="No models loaded")
360
 
361
  # ============================================================================
362
  # Generation Functions
 
369
  top_k: int = 40,
370
  top_p: float = 0.9,
371
  repetition_penalty: float = 1.1,
372
+ return_token_ids: bool = False,
373
+ model_name: Optional[str] = None
374
  ):
375
+ """Core generation with correct tokenizer per model"""
376
+ global loaded_models, current_model
377
 
378
+ # Select model
379
+ if model_name and model_name in loaded_models:
380
+ model, fast_forward, config, tokenizer, eos_token_id = loaded_models[model_name]
381
+ elif current_model:
382
+ model, fast_forward, config, tokenizer, eos_token_id = loaded_models[current_model]
383
+ else:
384
+ model_name = list(loaded_models.keys())[0]
385
+ model, fast_forward, config, tokenizer, eos_token_id = loaded_models[model_name]
386
+
387
+ # Encode with model's tokenizer
388
  input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
389
 
390
  if len(input_ids) == 0:
 
463
 
464
  @app.get("/", response_class=HTMLResponse)
465
  async def status_page():
466
+ models_html = ""
467
+ for model_name in loaded_models.keys():
468
+ usage = worker_stats["model_usage"].get(model_name, 0)
469
+ _, _, _, tokenizer, _ = loaded_models[model_name]
470
+ vocab_size = tokenizer.get_vocab_size()
471
+ models_html += f'<li><strong>{model_name}</strong> - Vocab: {vocab_size} - Used: {usage}x</li>'
472
+
473
+ return f"""
474
  <!DOCTYPE html>
475
  <html>
476
  <head>
477
+ <title>SAM Worker v5.0 - Multi-Model</title>
478
  <style>
479
+ * {{ margin: 0; padding: 0; box-sizing: border-box; }}
480
+ body {{
481
  font-family: 'Courier New', monospace;
482
  background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
483
  color: #00bfff;
484
  padding: 20px;
485
  min-height: 100vh;
486
+ }}
487
+ .container {{ max-width: 1000px; margin: 0 auto; }}
488
+ .header {{
489
  text-align: center;
490
  padding: 30px;
491
  background: rgba(0, 191, 255, 0.1);
 
493
  border-radius: 10px;
494
  margin-bottom: 30px;
495
  box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
496
+ }}
497
+ .header h1 {{
498
  font-size: 2.5em;
499
  text-transform: uppercase;
500
  letter-spacing: 3px;
501
  animation: glow 2s ease-in-out infinite alternate;
502
+ }}
503
+ @keyframes glow {{
504
+ from {{ text-shadow: 0 0 10px #00bfff; }}
505
+ to {{ text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }}
506
+ }}
507
+ .badge {{
508
  display: inline-block;
509
  padding: 5px 15px;
510
  border-radius: 15px;
511
  font-size: 0.9em;
512
+ margin: 5px;
513
+ }}
514
+ .badge-v5 {{
515
  background: rgba(0, 255, 136, 0.2);
516
  border: 1px solid #00ff88;
517
  color: #00ff88;
518
+ }}
519
+ .badge-multi {{
520
  background: rgba(255, 165, 0, 0.2);
521
  border: 1px solid #ffa500;
522
  color: #ffa500;
523
+ }}
524
+ .stats-grid {{
525
  display: grid;
526
  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
527
  gap: 20px;
528
  margin-bottom: 30px;
529
+ }}
530
+ .stat-card {{
531
  background: rgba(0, 191, 255, 0.05);
532
  border: 1px solid #00bfff;
533
  border-radius: 8px;
534
  padding: 20px;
535
  text-align: center;
536
+ }}
537
+ .stat-label {{ font-size: 0.8em; opacity: 0.7; text-transform: uppercase; margin-bottom: 10px; }}
538
+ .stat-value {{ font-size: 2em; font-weight: bold; }}
539
+ .features {{
 
540
  background: rgba(0, 191, 255, 0.05);
541
  border: 1px solid #00bfff;
542
  border-radius: 8px;
543
  padding: 20px;
544
+ margin-bottom: 20px;
545
+ }}
546
+ .features h3 {{ margin-bottom: 15px; }}
547
+ .feature-list {{ list-style: none; padding: 0; }}
548
+ .feature-list li {{
549
  padding: 10px;
550
  margin: 5px 0;
551
  background: rgba(0, 191, 255, 0.1);
552
  border-radius: 5px;
553
+ border-left: 3px solid #00ff88;
554
+ }}
555
+ .timestamp {{ text-align: center; margin-top: 20px; opacity: 0.5; }}
 
556
  </style>
557
  </head>
558
  <body>
559
  <div class="container">
560
  <div class="header">
561
  <h1>⚙️ WORKER NODE ⚙️</h1>
562
+ <div>SAM-Z-1 Distributed Worker v5.0</div>
563
+ <div>
564
+ <span class="badge badge-v5">V5 PROTOCOL</span>
565
+ <span class="badge badge-multi">{len(loaded_models)} MODELS</span>
566
+ </div>
567
  </div>
568
 
569
  <div class="stats-grid" id="stats">
 
585
  </div>
586
  </div>
587
 
588
+ <div class="features">
589
+ <h3>🤖 LOADED MODELS ({len(loaded_models)})</h3>
590
+ <ul class="feature-list">
591
+ {models_html}
592
+ </ul>
593
+ </div>
594
+
595
  <div class="features">
596
  <h3>πŸš€ CAPABILITIES</h3>
597
  <ul class="feature-list">
598
+ <li>✅ Original SAM-Z-1 (preserved)</li>
599
+ <li>✅ 4 new SAM-X-1 models</li>
600
+ <li>✅ Separate tokenizers per family</li>
601
+ <li>✅ Multi-model selection</li>
602
+ <li>✅ Token & batch decoding</li>
603
+ <li>✅ Streaming support</li>
604
+ <li>✅ Auto version detection</li>
605
  </ul>
606
  </div>
607
 
 
609
  </div>
610
 
611
  <script>
612
+ async function updateStats() {{
613
+ try {{
614
  const statsRes = await fetch('/stats');
615
  const stats = await statsRes.json();
616
 
 
622
  const h = Math.floor(uptime / 3600);
623
  const m = Math.floor((uptime % 3600) / 60);
624
  const s = uptime % 60;
625
+ document.getElementById('uptime').textContent = `${{h}}h ${{m}}m ${{s}}s`;
626
 
627
  document.getElementById('timestamp').textContent =
628
+ `Last update: ${{new Date().toLocaleTimeString()}}`;
629
+ }} catch (e) {{
630
  console.error('Failed to update stats:', e);
631
+ }}
632
+ }}
633
 
 
634
  setInterval(updateStats, 1000);
635
  updateStats();
636
  </script>
 
645
  @app.get("/health")
646
  async def health():
647
  return {
648
+ "status": "healthy" if loaded_models else "loading",
649
+ "model_loaded": len(loaded_models) > 0,
650
+ "models_count": len(loaded_models)
651
+ }
652
+
653
+ @app.get("/info")
654
+ async def worker_info():
655
+ """Worker information for version detection"""
656
+ return {
657
+ "version": "v5",
658
+ "models": list(loaded_models.keys()),
659
+ "features": [
660
+ "multi_model",
661
+ "model_selection",
662
+ "separate_tokenizers",
663
+ "token_generation",
664
+ "batch_decoding",
665
+ "streaming"
666
+ ],
667
+ "model_families": {
668
+ "sam-z": [m for m, info in MODEL_REGISTRY.items() if info["family"] == "sam-z"],
669
+ "sam-x": [m for m, info in MODEL_REGISTRY.items() if info["family"] == "sam-x"]
670
+ }
671
+ }
672
+
673
+ @app.get("/models")
674
+ async def list_models():
675
+ """List available models"""
676
+ return {
677
+ "models": list(loaded_models.keys()),
678
+ "default": current_model,
679
+ "count": len(loaded_models)
680
  }
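The /info route above is what enables the "auto version detection" mentioned in the module docstring. A hedged sketch of how a head node might use it, falling back gracefully for v4 workers that lack the route (the URL and the fallback shape are assumptions, not part of this commit):

import requests

def detect_worker(base_url: str) -> dict:
    """Return the /info payload for a v5 worker, or a v4-style stub if the route is missing."""
    try:
        r = requests.get(f"{base_url}/info", timeout=5)
        if r.ok:
            return r.json()  # v5 worker: {"version": "v5", "models": [...], "features": [...]}
    except requests.RequestException:
        pass
    # Older v4 workers have no /info route; treat them as single-model nodes.
    return {"version": "v4", "models": [], "features": []}

info = detect_worker("http://localhost:8000")
print(info["version"], info.get("models"))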
681
 
682
  @app.get("/stats")
 
687
  "total_tokens": worker_stats["total_tokens"],
688
  "decode_requests": worker_stats["decode_requests"],
689
  "uptime": uptime,
690
+ "tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0,
691
+ "model_usage": worker_stats["model_usage"]
692
  }
693
 
694
  @app.post("/decode")
695
  async def decode(request: DecodeRequest):
696
+ """Fast single decode - uses correct tokenizer"""
 
697
  try:
698
  worker_stats["decode_requests"] += 1
699
+ tokenizer, _ = get_tokenizer_for_model(request.model)
700
  text = tokenizer.decode(request.token_ids)
701
  return {"text": text}
702
  except Exception as e:
 
704
 
705
  @app.post("/decode/batch")
706
  async def batch_decode(request: BatchDecodeRequest):
707
+ """Optimized batch decoding - uses correct tokenizer"""
 
708
  try:
709
  worker_stats["decode_requests"] += len(request.batches)
710
+ tokenizer, _ = get_tokenizer_for_model(request.model)
711
  results = [tokenizer.decode(batch) for batch in request.batches]
712
  return {"texts": results}
713
  except Exception as e:
 
715
 
716
  @app.post("/generate")
717
  async def generate(request: GenerateRequest):
718
+ """Generate text with model selection"""
719
+ if not loaded_models:
720
+ raise HTTPException(status_code=503, detail="No models loaded")
721
+
722
+ # Track model usage
723
+ model_name = request.model or current_model
724
+ if model_name not in worker_stats["model_usage"]:
725
+ worker_stats["model_usage"][model_name] = 0
726
+ worker_stats["model_usage"][model_name] += 1
727
 
728
  worker_stats["total_requests"] += 1
729
  start_time = time.time()
 
741
  top_k=request.top_k,
742
  top_p=request.top_p,
743
  repetition_penalty=request.repetition_penalty,
744
+ return_token_ids=request.return_token_ids,
745
+ model_name=request.model
746
  ):
747
  token_count += 1
748
  worker_stats["total_tokens"] += 1
 
756
  await asyncio.sleep(0.001)
757
 
758
  elapsed = time.time() - start_time
759
+ yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'model': model_name})}\n\n"
760
 
761
  except Exception as e:
762
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
 
775
  top_k=request.top_k,
776
  top_p=request.top_p,
777
  repetition_penalty=request.repetition_penalty,
778
+ return_token_ids=request.return_token_ids,
779
+ model_name=request.model
780
  ):
781
  if not request.return_token_ids:
782
  generated_text += token_text
 
789
  "text": generated_text,
790
  "tokens": token_count,
791
  "time": elapsed,
792
+ "tokens_per_second": token_count / elapsed if elapsed > 0 else 0,
793
+ "model": model_name
794
  }
795
 
796
  except Exception as e:
 
798
 
799
  @app.post("/chat")
800
  async def chat(request: ChatRequest):
801
+ """Chat completion with model selection"""
802
+ if not loaded_models:
803
+ raise HTTPException(status_code=503, detail="No models loaded")
804
+
805
+ # Track model usage
806
+ model_name = request.model or current_model
807
+ if model_name not in worker_stats["model_usage"]:
808
+ worker_stats["model_usage"][model_name] = 0
809
+ worker_stats["model_usage"][model_name] += 1
810
 
811
  worker_stats["total_requests"] += 1
812
  prompt = format_chat_prompt(request.messages)
 
825
  top_k=request.top_k,
826
  top_p=request.top_p,
827
  repetition_penalty=request.repetition_penalty,
828
+ return_token_ids=request.return_token_ids,
829
+ model_name=request.model
830
  ):
831
  token_count += 1
832
  worker_stats["total_tokens"] += 1
 
845
  await asyncio.sleep(0.001)
846
 
847
  elapsed = time.time() - start_time
848
+ yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed, 'model': model_name})}\n\n"
849
 
850
  except Exception as e:
851
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
 
864
  top_k=request.top_k,
865
  top_p=request.top_p,
866
  repetition_penalty=request.repetition_penalty,
867
+ return_token_ids=request.return_token_ids,
868
+ model_name=request.model
869
  ):
870
  if not request.return_token_ids:
871
  generated_text += token_text
 
886
  },
887
  "tokens": token_count,
888
  "time": elapsed,
889
+ "tokens_per_second": token_count / elapsed if elapsed > 0 else 0,
890
+ "model": model_name
891
  }
892
 
893
  except Exception as e:
 
897
  # Model Loading
898
  # ============================================================================
899
 
900
+ async def load_single_model(model_name: str, model_info: dict) -> bool:
901
+ """Load a single model with its tokenizer"""
902
+ global loaded_models, current_model
903
 
904
  try:
905
+ print(f"\n⏳ Loading: {model_name} ({model_info['family']} family)")
906
+ print(f" Repo: {model_info['repo']}")
907
+ print(f" Weights: {model_info['weights']}")
908
 
909
+ # Load tokenizer for this family
910
+ tokenizer, eos_token_id = await load_tokenizer(
911
+ model_info['family'],
912
+ model_info['tokenizer_repo']
913
+ )
914
 
915
+ # Load config
916
+ if model_info['config']:
917
+ print(f" Config: {model_info['config']}")
918
+ config_path = hf_hub_download(
919
+ repo_id=model_info['repo'],
920
+ filename=model_info['config'],
921
+ cache_dir=CACHE_DIR
922
+ )
923
+ with open(config_path, 'r') as f:
924
+ config_raw = json.load(f)
925
+ else:
926
+ # Load base config for Large model
927
+ print(f" Loading base config from tokenizer repo...")
928
+ config_path = hf_hub_download(
929
+ repo_id=model_info['tokenizer_repo'],
930
+ filename="config.json",
931
+ cache_dir=CACHE_DIR
932
+ )
933
+ with open(config_path, 'r') as f:
934
+ config_raw = json.load(f)
935
 
936
+ # Convert to model format
937
+ model_config = {
938
+ 'vocab_size': config_raw['vocab_size'],
939
+ 'd_model': config_raw['hidden_size'],
940
+ 'n_heads': config_raw['num_attention_heads'],
941
+ 'ff_mult': config_raw['intermediate_size'] / config_raw['hidden_size'],
942
+ 'dropout': config_raw.get('dropout', 0.0),
943
+ 'max_len': config_raw['max_position_embeddings'],
944
+ 'rope_theta': config_raw['rope_theta'],
945
+ 'n_layers': config_raw['num_hidden_layers']
946
+ }
947
 
948
+ # Add for config object
949
+ model_config['max_position_embeddings'] = config_raw['max_position_embeddings']
950
 
951
+ print(f" πŸ“ Architecture: {model_config['n_layers']} layers, {model_config['n_heads']} heads")
952
 
953
+ # Load weights
954
+ weights_path = hf_hub_download(
955
+ repo_id=model_info['repo'],
956
+ filename=model_info['weights'],
957
+ cache_dir=CACHE_DIR
958
+ )
959
 
960
+ # Build model
961
+ model = SAM1Model(**model_config)
962
+ dummy_input = tf.zeros((1, 1), dtype=tf.int32)
963
+ model(dummy_input)
964
+ model.load_weights(weights_path)
965
+ model.trainable = False
966
 
967
+ # Create optimized forward pass
968
+ @tf.function(
969
+ input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)],
970
+ jit_compile=True,
971
+ reduce_retracing=True
972
+ )
973
+ def fast_predict(inputs):
974
+ return model(inputs, training=False)
975
 
976
+ # Warm up
977
+ print(f" πŸ”₯ Warming up...")
978
+ dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
979
+ _ = fast_predict(dummy)
980
 
981
+ # Store model with its tokenizer
982
+ loaded_models[model_name] = (model, fast_predict, model_config, tokenizer, eos_token_id)
983
+
984
+ # Set as default if first
985
+ if current_model is None:
986
+ current_model = model_name
987
 
988
+ # Count parameters
989
+ total_params = sum(np.prod(w.shape) for w in model.weights)
990
+ if total_params >= 1e9:
991
+ param_str = f"{total_params/1e9:.2f}B"
992
+ elif total_params >= 1e6:
993
+ param_str = f"{total_params/1e6:.2f}M"
994
  else:
995
+ param_str = f"{total_params/1e3:.2f}K"
996
+
997
+ print(f" βœ… Loaded successfully!")
998
+ print(f" πŸ“Š Parameters: {param_str}")
999
+ print(f" πŸ”€ Tokenizer vocab: {tokenizer.get_vocab_size()}")
1000
+
1001
+ return True
1002
+
1003
+ except Exception as e:
1004
+ print(f" ⚠️ Failed to load {model_name}: {e}")
1005
+ import traceback
1006
+ traceback.print_exc()
1007
+ return False
1008
+
1009
+ @app.on_event("startup")
1010
+ async def load_models():
1011
+ global loaded_models, current_model
1012
+
1013
+ print("="*80)
1014
+ print("πŸš€ SAM-Z-1 Worker Node v5.0 - Multi-Model with Separate Tokenizers".center(80))
1015
+ print("="*80)
1016
+
1017
+ try:
1018
+ # Load all models
1019
+ print("\n" + "="*80)
1020
+ print("πŸ“¦ LOADING ALL 5 MODELS".center(80))
1021
+ print("="*80)
1022
+
1023
+ loaded_count = 0
1024
+ for model_name, model_info in MODEL_REGISTRY.items():
1025
+ success = await load_single_model(model_name, model_info)
1026
+ if success:
1027
+ loaded_count += 1
1028
+
1029
+ if loaded_count == 0:
1030
+ raise RuntimeError("❌ No models loaded successfully!")
1031
 
1032
+ print(f"\n{'='*80}")
1033
+ print(f"βœ… Successfully loaded {loaded_count}/{len(MODEL_REGISTRY)} models")
1034
+ print(f"πŸ“Œ Default model: {current_model}")
1035
 
1036
+ # Show tokenizer families
1037
+ print(f"\nπŸ”€ Tokenizer Families:")
1038
+ print(f" SAM-Z family: {len([m for m, i in MODEL_REGISTRY.items() if i['family'] == 'sam-z'])} model(s)")
1039
+ print(f" SAM-X family: {len([m for m, i in MODEL_REGISTRY.items() if i['family'] == 'sam-x'])} model(s)")
1040
 
1041
+ print(f"\nπŸš€ Worker ready for inference!")
1042
+ print(f"{'='*80}\n")
1043
 
1044
  except Exception as e:
1045
+ print(f"\n❌ Failed to initialize worker: {e}")
1046
  import traceback
1047
  traceback.print_exc()
1048
  raise
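For completeness, an assumed way to launch this worker locally and smoke-test the multi-model endpoints (the module name "app", the port, and the `uvicorn` dependency are assumptions, not part of the commit):

import uvicorn  # pip install fastapi uvicorn tensorflow tokenizers transformers huggingface_hub

if __name__ == "__main__":
    # The server does not accept requests until the startup hook has finished loading
    # every MODEL_REGISTRY entry, so expect a slow first boot while weights download.
    uvicorn.run("app:app", host="0.0.0.0", port=8000)

# Quick checks once it is up (from another shell):
#   curl http://localhost:8000/health   -> {"status": "healthy", "model_loaded": true, ...}
#   curl http://localhost:8000/models   -> lists SAM-Z-1 plus whichever SAM-X-1 variants loaded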