""" SAM-Z-1 Distributed Worker Node v4.0 Optimized for distributed gen/decode pipeline """ from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse, HTMLResponse from pydantic import BaseModel import tensorflow as tf import keras from huggingface_hub import hf_hub_download import json import os from tokenizers import Tokenizer import numpy as np import time from typing import List, Optional import asyncio app = FastAPI(title="SAM-Z-1 Distributed Worker", version="4.0.0") # ============================================================================ # Model Architecture # ============================================================================ @keras.saving.register_keras_serializable() class RotaryEmbedding(keras.layers.Layer): def __init__(self, dim, max_len=2048, theta=10000, **kwargs): super().__init__(**kwargs) self.dim = dim self.max_len = max_len self.theta = theta self.built_cache = False def build(self, input_shape): super().build(input_shape) def _build_cache(self): if not self.built_cache: inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim)) t = tf.range(self.max_len, dtype=tf.float32) freqs = tf.einsum("i,j->ij", t, inv_freq) emb = tf.concat([freqs, freqs], axis=-1) self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32) self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32) self.built_cache = True def rotate_half(self, x): x1, x2 = tf.split(x, 2, axis=-1) return tf.concat([-x2, x1], axis=-1) def call(self, q, k): self._build_cache() seq_len = tf.shape(q)[2] dtype = q.dtype cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :] sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :] q_rotated = (q * cos) + (self.rotate_half(q) * sin) k_rotated = (k * cos) + (self.rotate_half(k) * sin) return q_rotated, k_rotated def get_config(self): config = super().get_config() config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta}) return config @keras.saving.register_keras_serializable() class RMSNorm(keras.layers.Layer): def __init__(self, epsilon=1e-5, **kwargs): super().__init__(**kwargs) self.epsilon = epsilon def build(self, input_shape): self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones") def call(self, x): variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) return x * tf.math.rsqrt(variance + self.epsilon) * self.scale def get_config(self): config = super().get_config() config.update({"epsilon": self.epsilon}) return config @keras.saving.register_keras_serializable() class TransformerBlock(keras.layers.Layer): def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs): super().__init__(**kwargs) self.d_model = d_model self.n_heads = n_heads self.ff_dim = ff_dim self.dropout_rate = dropout self.max_len = max_len self.rope_theta = rope_theta self.head_dim = d_model // n_heads self.layer_idx = layer_idx self.pre_attn_norm = RMSNorm() self.pre_ffn_norm = RMSNorm() self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj") self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj") self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj") self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj") self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta) self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj") self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj") self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj") self.dropout = keras.layers.Dropout(dropout) def call(self, x, training=None): B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model dtype = x.dtype res = x y = self.pre_attn_norm(x) q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) q, k = self.rope(q, k) scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype)) mask = tf.where( tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype) ) scores += mask attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v) attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D]) x = res + self.dropout(self.out_proj(attn), training=training) res = x y = self.pre_ffn_norm(x) ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y)) return res + self.dropout(ffn, training=training) def get_config(self): config = super().get_config() config.update({ "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len, "rope_theta": self.rope_theta, "layer_idx": self.layer_idx }) return config @keras.saving.register_keras_serializable() class SAM1Model(keras.Model): def __init__(self, **kwargs): super().__init__() if 'config' in kwargs and isinstance(kwargs['config'], dict): self.cfg = kwargs['config'] elif 'vocab_size' in kwargs: self.cfg = kwargs else: self.cfg = kwargs.get('cfg', kwargs) self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens") ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult']) block_args = { 'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta'] } self.blocks = [] for i in range(self.cfg['n_layers']): block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args) self.blocks.append(block) self.norm = RMSNorm(name="final_norm") self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head") def call(self, input_ids, training=None): x = self.embed(input_ids) for block in self.blocks: x = block(x, training=training) return self.lm_head(self.norm(x)) def get_config(self): base_config = super().get_config() base_config['config'] = self.cfg return base_config # ============================================================================ # Global State # ============================================================================ model = None tokenizer = None config = None eos_token_id = None fast_forward = None MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow" CACHE_DIR = "./model_cache" # Stats worker_stats = { "total_requests": 0, "total_tokens": 0, "decode_requests": 0, "uptime_start": time.time() } # ============================================================================ # Request Models # ============================================================================ class GenerateRequest(BaseModel): prompt: str max_tokens: int = 512 temperature: float = 0.8 top_k: int = 40 top_p: float = 0.9 repetition_penalty: float = 1.1 stream: bool = False return_token_ids: bool = False class ChatMessage(BaseModel): role: str content: str class ChatRequest(BaseModel): messages: List[ChatMessage] max_tokens: int = 512 temperature: float = 0.8 top_k: int = 40 top_p: float = 0.9 repetition_penalty: float = 1.1 stream: bool = False return_token_ids: bool = False class DecodeRequest(BaseModel): token_ids: List[int] class BatchDecodeRequest(BaseModel): batches: List[List[int]] # ============================================================================ # Generation Functions # ============================================================================ def generate_tokens( prompt: str, max_tokens: int = 512, temperature: float = 0.8, top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1, return_token_ids: bool = False ): """Core generation - yields (token_id, token_text or None)""" global model, tokenizer, config, eos_token_id, fast_forward input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id] if len(input_ids) == 0: return if len(input_ids) > config['max_position_embeddings'] - max_tokens: input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):] input_tensor = tf.constant([input_ids], dtype=tf.int32) token_freq = {} for step in range(max_tokens): logits = fast_forward(input_tensor) next_token_logits = logits[0, -1, :].numpy() next_token_logits = next_token_logits / temperature if repetition_penalty != 1.0: for token_id, freq in token_freq.items(): if token_id < len(next_token_logits): next_token_logits[token_id] /= (repetition_penalty ** freq) if top_k > 0: top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:] top_k_logits = next_token_logits[top_k_indices] top_k_probs = tf.nn.softmax(top_k_logits).numpy() if top_p < 1.0: sorted_indices = np.argsort(top_k_probs)[::-1] cumsum = np.cumsum(top_k_probs[sorted_indices]) cutoff_idx = np.searchsorted(cumsum, top_p) nucleus_indices = sorted_indices[:cutoff_idx + 1] nucleus_logits = top_k_logits[nucleus_indices] nucleus_probs = tf.nn.softmax(nucleus_logits).numpy() sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs) next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]]) else: sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs) next_token_id = int(top_k_indices[sampled_idx]) else: probs = tf.nn.softmax(next_token_logits).numpy() next_token_id = np.random.choice(len(probs), p=probs) if next_token_id == eos_token_id: break token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1 if return_token_ids: yield (next_token_id, None) else: token_text = tokenizer.decode([next_token_id]) yield (next_token_id, token_text) input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1) if input_tensor.shape[1] > config['max_position_embeddings']: input_tensor = input_tensor[:, -config['max_position_embeddings']:] def format_chat_prompt(messages: List[ChatMessage]) -> str: prompt = "" for msg in messages: if msg.role == "user": prompt += f"<|im_start|>user\n{msg.content}<|im_end|>\n" elif msg.role == "assistant": prompt += f"<|im_start|>assistant\n{msg.content}<|im_end|>\n" prompt += "<|im_start|>assistant\n" return prompt # ============================================================================ # Status Page # ============================================================================ @app.get("/", response_class=HTMLResponse) async def status_page(): """Worker status page""" return """