""" SAM-Z-1 Distributed Worker Node v4.0 Optimized for distributed gen/decode pipeline """ from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse, HTMLResponse from pydantic import BaseModel import tensorflow as tf import keras from huggingface_hub import hf_hub_download import json import os from tokenizers import Tokenizer import numpy as np import time from typing import List, Optional import asyncio app = FastAPI(title="SAM-Z-1 Distributed Worker", version="4.0.0") # ============================================================================ # Model Architecture # ============================================================================ @keras.saving.register_keras_serializable() class RotaryEmbedding(keras.layers.Layer): def __init__(self, dim, max_len=2048, theta=10000, **kwargs): super().__init__(**kwargs) self.dim = dim self.max_len = max_len self.theta = theta self.built_cache = False def build(self, input_shape): super().build(input_shape) def _build_cache(self): if not self.built_cache: inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim)) t = tf.range(self.max_len, dtype=tf.float32) freqs = tf.einsum("i,j->ij", t, inv_freq) emb = tf.concat([freqs, freqs], axis=-1) self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32) self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32) self.built_cache = True def rotate_half(self, x): x1, x2 = tf.split(x, 2, axis=-1) return tf.concat([-x2, x1], axis=-1) def call(self, q, k): self._build_cache() seq_len = tf.shape(q)[2] dtype = q.dtype cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :] sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :] q_rotated = (q * cos) + (self.rotate_half(q) * sin) k_rotated = (k * cos) + (self.rotate_half(k) * sin) return q_rotated, k_rotated def get_config(self): config = super().get_config() config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta}) return config @keras.saving.register_keras_serializable() class RMSNorm(keras.layers.Layer): def __init__(self, epsilon=1e-5, **kwargs): super().__init__(**kwargs) self.epsilon = epsilon def build(self, input_shape): self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones") def call(self, x): variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) return x * tf.math.rsqrt(variance + self.epsilon) * self.scale def get_config(self): config = super().get_config() config.update({"epsilon": self.epsilon}) return config @keras.saving.register_keras_serializable() class TransformerBlock(keras.layers.Layer): def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs): super().__init__(**kwargs) self.d_model = d_model self.n_heads = n_heads self.ff_dim = ff_dim self.dropout_rate = dropout self.max_len = max_len self.rope_theta = rope_theta self.head_dim = d_model // n_heads self.layer_idx = layer_idx self.pre_attn_norm = RMSNorm() self.pre_ffn_norm = RMSNorm() self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj") self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj") self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj") self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj") self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta) self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj") self.up_proj = keras.layers.Dense(ff_dim, 
use_bias=False, name="up_proj") self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj") self.dropout = keras.layers.Dropout(dropout) def call(self, x, training=None): B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model dtype = x.dtype res = x y = self.pre_attn_norm(x) q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) q, k = self.rope(q, k) scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype)) mask = tf.where( tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype) ) scores += mask attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v) attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D]) x = res + self.dropout(self.out_proj(attn), training=training) res = x y = self.pre_ffn_norm(x) ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y)) return res + self.dropout(ffn, training=training) def get_config(self): config = super().get_config() config.update({ "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len, "rope_theta": self.rope_theta, "layer_idx": self.layer_idx }) return config @keras.saving.register_keras_serializable() class SAM1Model(keras.Model): def __init__(self, **kwargs): super().__init__() if 'config' in kwargs and isinstance(kwargs['config'], dict): self.cfg = kwargs['config'] elif 'vocab_size' in kwargs: self.cfg = kwargs else: self.cfg = kwargs.get('cfg', kwargs) self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens") ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult']) block_args = { 'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta'] } self.blocks = [] for i in range(self.cfg['n_layers']): block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args) self.blocks.append(block) self.norm = RMSNorm(name="final_norm") self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head") def call(self, input_ids, training=None): x = self.embed(input_ids) for block in self.blocks: x = block(x, training=training) return self.lm_head(self.norm(x)) def get_config(self): base_config = super().get_config() base_config['config'] = self.cfg return base_config # ============================================================================ # Global State # ============================================================================ model = None tokenizer = None config = None eos_token_id = None fast_forward = None MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow" CACHE_DIR = "./model_cache" # Stats worker_stats = { "total_requests": 0, "total_tokens": 0, "decode_requests": 0, "uptime_start": time.time() } # ============================================================================ # Request Models # ============================================================================ class GenerateRequest(BaseModel): prompt: str max_tokens: int = 512 temperature: float = 0.8 top_k: int = 40 top_p: float = 0.9 repetition_penalty: float = 1.1 stream: bool = False return_token_ids: bool = False class ChatMessage(BaseModel): role: str 
class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False
    return_token_ids: bool = False


class DecodeRequest(BaseModel):
    token_ids: List[int]


class BatchDecodeRequest(BaseModel):
    batches: List[List[int]]

# ============================================================================
# Generation Functions
# ============================================================================
def generate_tokens(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    return_token_ids: bool = False
):
    """Core generation - yields (token_id, token_text or None)."""
    global model, tokenizer, config, eos_token_id, fast_forward

    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
    if len(input_ids) == 0:
        return
    # Keep room for the requested number of new tokens within the context window.
    if len(input_ids) > config['max_position_embeddings'] - max_tokens:
        input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):]

    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    token_freq = {}

    for step in range(max_tokens):
        logits = fast_forward(input_tensor)
        next_token_logits = logits[0, -1, :].numpy()
        next_token_logits = next_token_logits / temperature

        # Frequency-based repetition penalty
        if repetition_penalty != 1.0:
            for token_id, freq in token_freq.items():
                if token_id < len(next_token_logits):
                    next_token_logits[token_id] /= (repetition_penalty ** freq)

        if top_k > 0:
            # Top-k filtering, optionally followed by nucleus (top-p) sampling
            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
            top_k_logits = next_token_logits[top_k_indices]
            top_k_probs = tf.nn.softmax(top_k_logits).numpy()
            if top_p < 1.0:
                sorted_indices = np.argsort(top_k_probs)[::-1]
                cumsum = np.cumsum(top_k_probs[sorted_indices])
                cutoff_idx = np.searchsorted(cumsum, top_p)
                nucleus_indices = sorted_indices[:cutoff_idx + 1]
                nucleus_logits = top_k_logits[nucleus_indices]
                nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
                sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
                next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
            else:
                sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs)
                next_token_id = int(top_k_indices[sampled_idx])
        else:
            probs = tf.nn.softmax(next_token_logits).numpy()
            next_token_id = int(np.random.choice(len(probs), p=probs))

        if next_token_id == eos_token_id:
            break

        token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1

        if return_token_ids:
            yield (next_token_id, None)
        else:
            token_text = tokenizer.decode([next_token_id])
            yield (next_token_id, token_text)

        input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)
        if input_tensor.shape[1] > config['max_position_embeddings']:
            input_tensor = input_tensor[:, -config['max_position_embeddings']:]


def format_chat_prompt(messages: List[ChatMessage]) -> str:
    prompt = ""
    for msg in messages:
        if msg.role == "user":
            prompt += f"<|im_start|>user\n{msg.content}<|im_end|>\n"
        elif msg.role == "assistant":
            prompt += f"<|im_start|>assistant\n{msg.content}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"
    return prompt
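# For reference, format_chat_prompt renders a single user turn as:
#
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant
#
# i.e. the prompt always ends with an open assistant turn for the model to complete.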

# ============================================================================
# Status Page
# ============================================================================
@app.get("/", response_class=HTMLResponse)
async def status_page():
    """Worker status page"""
    return """<!DOCTYPE html>
<html>
<head><title>SAM-Z-1 Worker Node</title></head>
<body>
  <h1>⚙️ WORKER NODE ⚙️</h1>
  <h2>SAM-Z-1 Distributed Worker v4.0</h2>
  <p id="status">CHECKING STATUS...</p>
  <ul>
    <li>Total Requests: <span id="total-requests">--</span></li>
    <li>Total Tokens: <span id="total-tokens">--</span></li>
    <li>Decode Requests: <span id="decode-requests">--</span></li>
    <li>Uptime: <span id="uptime">--</span></li>
  </ul>
  <h3>🚀 CAPABILITIES</h3>
  <p id="capabilities">Initializing...</p>
</body>
</html>
"""
""" # ============================================================================ # API Endpoints # ============================================================================ @app.get("/health") async def health(): return { "status": "healthy" if model is not None else "loading", "model_loaded": model is not None } @app.get("/stats") async def stats(): uptime = time.time() - worker_stats["uptime_start"] return { "total_requests": worker_stats["total_requests"], "total_tokens": worker_stats["total_tokens"], "decode_requests": worker_stats["decode_requests"], "uptime": uptime, "tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0 } @app.post("/decode") async def decode(request: DecodeRequest): """Fast single decode""" if tokenizer is None: raise HTTPException(status_code=503, detail="Tokenizer not loaded") try: worker_stats["decode_requests"] += 1 text = tokenizer.decode(request.token_ids) return {"text": text} except Exception as e: raise HTTPException(status_code=500, detail=f"Decode error: {str(e)}") @app.post("/decode/batch") async def batch_decode(request: BatchDecodeRequest): """Optimized batch decoding for distributed pipeline""" if tokenizer is None: raise HTTPException(status_code=503, detail="Tokenizer not loaded") try: worker_stats["decode_requests"] += len(request.batches) results = [tokenizer.decode(batch) for batch in request.batches] return {"texts": results} except Exception as e: raise HTTPException(status_code=500, detail=f"Batch decode error: {str(e)}") @app.post("/generate") async def generate(request: GenerateRequest): """Generate text""" if model is None: raise HTTPException(status_code=503, detail="Model not loaded") worker_stats["total_requests"] += 1 start_time = time.time() if request.stream: async def stream_tokens(): generated_text = "" token_count = 0 try: for token_id, token_text in generate_tokens( request.prompt, max_tokens=request.max_tokens, temperature=request.temperature, top_k=request.top_k, top_p=request.top_p, repetition_penalty=request.repetition_penalty, return_token_ids=request.return_token_ids ): token_count += 1 worker_stats["total_tokens"] += 1 if request.return_token_ids: yield f"data: {json.dumps({'token_id': token_id})}\n\n" else: generated_text += token_text yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n" await asyncio.sleep(0.001) elapsed = time.time() - start_time yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n" except Exception as e: yield f"data: {json.dumps({'error': str(e)})}\n\n" return StreamingResponse(stream_tokens(), media_type="text/event-stream") else: generated_text = "" token_count = 0 try: for token_id, token_text in generate_tokens( request.prompt, max_tokens=request.max_tokens, temperature=request.temperature, top_k=request.top_k, top_p=request.top_p, repetition_penalty=request.repetition_penalty, return_token_ids=request.return_token_ids ): if not request.return_token_ids: generated_text += token_text token_count += 1 worker_stats["total_tokens"] += 1 elapsed = time.time() - start_time return { "text": generated_text, "tokens": token_count, "time": elapsed, "tokens_per_second": token_count / elapsed if elapsed > 0 else 0 } except Exception as e: raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") @app.post("/chat") async def chat(request: ChatRequest): """Chat completion""" if model is None: raise HTTPException(status_code=503, detail="Model not loaded") worker_stats["total_requests"] += 1 prompt = 
@app.post("/chat")
async def chat(request: ChatRequest):
    """Chat completion"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    worker_stats["total_requests"] += 1
    prompt = format_chat_prompt(request.messages)
    start_time = time.time()

    if request.stream:
        async def stream_tokens():
            generated_text = ""
            token_count = 0
            try:
                for token_id, token_text in generate_tokens(
                    prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,
                    return_token_ids=request.return_token_ids
                ):
                    token_count += 1
                    worker_stats["total_tokens"] += 1
                    if request.return_token_ids:
                        yield f"data: {json.dumps({'token_id': token_id})}\n\n"
                    else:
                        generated_text += token_text
                        # Stop at the end-of-turn marker and trim it from the output.
                        if "<|im_end|>" in generated_text:
                            generated_text = generated_text.split("<|im_end|>")[0]
                            break
                        yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n"
                    await asyncio.sleep(0.001)
                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")
    else:
        generated_text = ""
        token_count = 0
        try:
            for token_id, token_text in generate_tokens(
                prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                repetition_penalty=request.repetition_penalty,
                return_token_ids=request.return_token_ids
            ):
                if not request.return_token_ids:
                    generated_text += token_text
                    if "<|im_end|>" in generated_text:
                        generated_text = generated_text.split("<|im_end|>")[0]
                        break
                token_count += 1
                worker_stats["total_tokens"] += 1
            elapsed = time.time() - start_time
            return {
                "message": {
                    "role": "assistant",
                    "content": generated_text.strip()
                },
                "tokens": token_count,
                "time": elapsed,
                "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
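# Example (illustrative): a minimal /chat request body accepted by ChatRequest,
# and the shape of the non-streaming response. Values shown are placeholders.
#
#   POST /chat
#   {
#     "messages": [{"role": "user", "content": "Hello!"}],
#     "max_tokens": 128,
#     "stream": false
#   }
#
#   -> {"message": {"role": "assistant", "content": "..."},
#       "tokens": <count>, "time": <seconds>, "tokens_per_second": <rate>}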
# ============================================================================
# Model Loading
# ============================================================================
@app.on_event("startup")
async def load_model():
    global model, tokenizer, config, eos_token_id, fast_forward

    print("🚀 Loading SAM-Z-1 Model...")
    try:
        config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
        try:
            weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
            print("✅ Found checkpoint weights")
            use_checkpoint = True
        except Exception:
            print("⚠️ Checkpoint not found, using model.keras")
            model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
            use_checkpoint = False

        with open(config_path, 'r') as f:
            config = json.load(f)
        print(f"📦 Config loaded: {config['num_hidden_layers']} layers")

        print("📦 Creating tokenizer...")
        from transformers import AutoTokenizer
        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        custom_tokens = ["<|im_start|>", "<|im_end|>", "", ""]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
        os.makedirs("./temp_tokenizer", exist_ok=True)
        hf_tokenizer.save_pretrained("./temp_tokenizer")
        tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
        eos_token_id = config.get('eos_token_id', 50256)
        print(f"✅ Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")

        print("🔄 Loading model...")
        if use_checkpoint:
            model_config = {
                'vocab_size': config['vocab_size'],
                'd_model': config['hidden_size'],
                'n_layers': config['num_hidden_layers'],
                'n_heads': config['num_attention_heads'],
                'ff_mult': config['intermediate_size'] / config['hidden_size'],
                'max_len': config['max_position_embeddings'],
                'dropout': 0.1,
                'rope_theta': config['rope_theta']
            }
            model = SAM1Model(config=model_config)
            dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
            _ = model(dummy_input, training=False)
            print(f"✅ Architecture built: {model.count_params():,} parameters")
            model.load_weights(weights_path)
            print("✅ Weights loaded!")
        else:
            model = keras.models.load_model(model_path, compile=False)
            print("✅ Model loaded!")

        @tf.function(reduce_retracing=True)
        def optimized_forward(input_tensor):
            return model(input_tensor, training=False)

        fast_forward = optimized_forward

        print("✅ SAM-Z-1 Distributed Worker ready! 🚀")
        print("🔥 Features enabled:")
        print("   - Full text generation")
        print("   - Token-only mode (distributed pipeline)")
        print("   - Batch decoding optimization")
        print("   - Streaming support")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        raise

# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )
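# ----------------------------------------------------------------------------
# Example client usage (illustrative sketch, not part of the worker).
# Assumes the worker is reachable at http://localhost:7860 and that the
# `requests` package is available on the client side.
#
#   import requests
#
#   # Plain text generation
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Once upon a time", "max_tokens": 64, "stream": False},
#       timeout=300,
#   )
#   print(resp.json()["text"])
#
#   # Token-only mode + remote decode (distributed pipeline)
#   ids = [15496, 11, 995]  # placeholder token ids
#   print(requests.post("http://localhost:7860/decode", json={"token_ids": ids}).json()["text"])
# ----------------------------------------------------------------------------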