|
|
""" |
|
|
SAM-Z-1 Distributed Worker Node v4.0 |
|
|
Optimized for distributed gen/decode pipeline |
|
|
""" |
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.responses import StreamingResponse, HTMLResponse |
|
|
from pydantic import BaseModel |
|
|
import tensorflow as tf |
|
|
import keras |
|
|
from huggingface_hub import hf_hub_download |
|
|
import json |
|
|
import os |
|
|
from tokenizers import Tokenizer |
|
|
import numpy as np |
|
|
import time |
|
|
from typing import List, Optional |
|
|
import asyncio |
|
|
|
|
|
app = FastAPI(title="SAM-Z-1 Distributed Worker", version="4.0.0") |
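
# Custom Keras layers and model. Registered with keras.saving so they can be
# deserialized when loading a saved model.keras artifact.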
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@keras.saving.register_keras_serializable() |
|
|
class RotaryEmbedding(keras.layers.Layer): |
|
|
def __init__(self, dim, max_len=2048, theta=10000, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
self.dim = dim |
|
|
self.max_len = max_len |
|
|
self.theta = theta |
|
|
self.built_cache = False |
|
|
|
|
|
def build(self, input_shape): |
|
|
super().build(input_shape) |
|
|
|
|
|
def _build_cache(self): |
|
|
if not self.built_cache: |
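            # Precompute cos/sin tables for positions 0..max_len-1 (standard RoPE frequencies).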
|
|
inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim)) |
|
|
t = tf.range(self.max_len, dtype=tf.float32) |
|
|
freqs = tf.einsum("i,j->ij", t, inv_freq) |
|
|
emb = tf.concat([freqs, freqs], axis=-1) |
|
|
|
|
|
            # Use TF ops directly so the cache can also be built under a tf.function trace.
            self.cos_cached = tf.cos(emb)
            self.sin_cached = tf.sin(emb)
|
|
self.built_cache = True |
|
|
|
|
|
def rotate_half(self, x): |
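        # Split the last dimension in half and rotate the pairs: (x1, x2) -> (-x2, x1).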
|
|
x1, x2 = tf.split(x, 2, axis=-1) |
|
|
return tf.concat([-x2, x1], axis=-1) |
|
|
|
|
|
def call(self, q, k): |
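        # q, k: [batch, heads, seq, head_dim]; returns the rotary-encoded projections.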
|
|
self._build_cache() |
|
|
seq_len = tf.shape(q)[2] |
|
|
dtype = q.dtype |
|
|
cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :] |
|
|
sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :] |
|
|
|
|
|
q_rotated = (q * cos) + (self.rotate_half(q) * sin) |
|
|
k_rotated = (k * cos) + (self.rotate_half(k) * sin) |
|
|
|
|
|
return q_rotated, k_rotated |
|
|
|
|
|
def get_config(self): |
|
|
config = super().get_config() |
|
|
config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta}) |
|
|
return config |
|
|
|
|
|
@keras.saving.register_keras_serializable() |
|
|
class RMSNorm(keras.layers.Layer): |
|
|
def __init__(self, epsilon=1e-5, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
self.epsilon = epsilon |
|
|
|
|
|
def build(self, input_shape): |
|
|
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones") |
|
|
|
|
|
def call(self, x): |
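        # RMSNorm: scale by the reciprocal root-mean-square of the features (no mean centering).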
|
|
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) |
|
|
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale |
|
|
|
|
|
def get_config(self): |
|
|
config = super().get_config() |
|
|
config.update({"epsilon": self.epsilon}) |
|
|
return config |
|
|
|
|
|
@keras.saving.register_keras_serializable() |
|
|
class TransformerBlock(keras.layers.Layer): |
|
|
def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
self.d_model = d_model |
|
|
self.n_heads = n_heads |
|
|
self.ff_dim = ff_dim |
|
|
self.dropout_rate = dropout |
|
|
self.max_len = max_len |
|
|
self.rope_theta = rope_theta |
|
|
self.head_dim = d_model // n_heads |
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
self.pre_attn_norm = RMSNorm() |
|
|
self.pre_ffn_norm = RMSNorm() |
|
|
|
|
|
self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj") |
|
|
self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj") |
|
|
self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj") |
|
|
self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj") |
|
|
|
|
|
self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta) |
|
|
|
|
|
self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj") |
|
|
self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj") |
|
|
self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj") |
|
|
|
|
|
self.dropout = keras.layers.Dropout(dropout) |
|
|
|
|
|
def call(self, x, training=None): |
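        # Pre-norm block: RMSNorm -> multi-head attention with RoPE (residual),
        # then RMSNorm -> SwiGLU feed-forward (residual).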
|
|
B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model |
|
|
dtype = x.dtype |
|
|
|
|
|
res = x |
|
|
y = self.pre_attn_norm(x) |
|
|
|
|
|
q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) |
|
|
k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) |
|
|
v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3]) |
|
|
|
|
|
q, k = self.rope(q, k) |
|
|
|
|
|
scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype)) |
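        # Additive causal mask: positions may only attend to themselves and earlier tokens.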
|
|
mask = tf.where( |
|
|
tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, |
|
|
tf.constant(-1e9, dtype=dtype), |
|
|
tf.constant(0.0, dtype=dtype) |
|
|
) |
|
|
scores += mask |
|
|
attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v) |
|
|
|
|
|
attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D]) |
|
|
x = res + self.dropout(self.out_proj(attn), training=training) |
|
|
|
|
|
res = x |
|
|
y = self.pre_ffn_norm(x) |
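        # SwiGLU feed-forward: down_proj(silu(gate_proj(y)) * up_proj(y)).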
|
|
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y)) |
|
|
|
|
|
return res + self.dropout(ffn, training=training) |
|
|
|
|
|
def get_config(self): |
|
|
config = super().get_config() |
|
|
config.update({ |
|
|
"d_model": self.d_model, |
|
|
"n_heads": self.n_heads, |
|
|
"ff_dim": self.ff_dim, |
|
|
"dropout": self.dropout_rate, |
|
|
"max_len": self.max_len, |
|
|
"rope_theta": self.rope_theta, |
|
|
"layer_idx": self.layer_idx |
|
|
}) |
|
|
return config |
|
|
|
|
|
@keras.saving.register_keras_serializable() |
|
|
class SAM1Model(keras.Model): |
|
|
def __init__(self, **kwargs): |
|
|
super().__init__() |
|
|
if 'config' in kwargs and isinstance(kwargs['config'], dict): |
|
|
self.cfg = kwargs['config'] |
|
|
elif 'vocab_size' in kwargs: |
|
|
self.cfg = kwargs |
|
|
else: |
|
|
self.cfg = kwargs.get('cfg', kwargs) |
|
|
|
|
|
self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens") |
|
|
|
|
|
ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult']) |
|
|
block_args = { |
|
|
'd_model': self.cfg['d_model'], |
|
|
'n_heads': self.cfg['n_heads'], |
|
|
'ff_dim': ff_dim, |
|
|
'dropout': self.cfg['dropout'], |
|
|
'max_len': self.cfg['max_len'], |
|
|
'rope_theta': self.cfg['rope_theta'] |
|
|
} |
|
|
|
|
|
self.blocks = [] |
|
|
for i in range(self.cfg['n_layers']): |
|
|
block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args) |
|
|
self.blocks.append(block) |
|
|
|
|
|
self.norm = RMSNorm(name="final_norm") |
|
|
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head") |
|
|
|
|
|
def call(self, input_ids, training=None): |
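        # Embed tokens -> stacked transformer blocks -> final RMSNorm -> LM head logits.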
|
|
x = self.embed(input_ids) |
|
|
for block in self.blocks: |
|
|
x = block(x, training=training) |
|
|
return self.lm_head(self.norm(x)) |
|
|
|
|
|
def get_config(self): |
|
|
base_config = super().get_config() |
|
|
base_config['config'] = self.cfg |
|
|
return base_config |
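
# Globals populated by the startup hook (load_model) further down in this file.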
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = None |
|
|
tokenizer = None |
|
|
config = None |
|
|
eos_token_id = None |
|
|
fast_forward = None |
|
|
|
|
|
MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow" |
|
|
CACHE_DIR = "./model_cache" |
|
|
|
|
|
|
|
|
worker_stats = { |
|
|
"total_requests": 0, |
|
|
"total_tokens": 0, |
|
|
"decode_requests": 0, |
|
|
"uptime_start": time.time() |
|
|
} |
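
# Request schemas for the HTTP endpoints.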
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel): |
|
|
prompt: str |
|
|
max_tokens: int = 512 |
|
|
temperature: float = 0.8 |
|
|
top_k: int = 40 |
|
|
top_p: float = 0.9 |
|
|
repetition_penalty: float = 1.1 |
|
|
stream: bool = False |
|
|
return_token_ids: bool = False |
|
|
|
|
|
class ChatMessage(BaseModel): |
|
|
role: str |
|
|
content: str |
|
|
|
|
|
class ChatRequest(BaseModel): |
|
|
messages: List[ChatMessage] |
|
|
max_tokens: int = 512 |
|
|
temperature: float = 0.8 |
|
|
top_k: int = 40 |
|
|
top_p: float = 0.9 |
|
|
repetition_penalty: float = 1.1 |
|
|
stream: bool = False |
|
|
return_token_ids: bool = False |
|
|
|
|
|
class DecodeRequest(BaseModel): |
|
|
token_ids: List[int] |
|
|
|
|
|
class BatchDecodeRequest(BaseModel): |
|
|
batches: List[List[int]] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_tokens( |
|
|
prompt: str, |
|
|
max_tokens: int = 512, |
|
|
temperature: float = 0.8, |
|
|
top_k: int = 40, |
|
|
top_p: float = 0.9, |
|
|
repetition_penalty: float = 1.1, |
|
|
return_token_ids: bool = False |
|
|
): |
|
|
"""Core generation - yields (token_id, token_text or None)""" |
|
|
global model, tokenizer, config, eos_token_id, fast_forward |
|
|
|
|
|
input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id] |
|
|
|
|
|
if len(input_ids) == 0: |
|
|
return |
|
|
|
|
|
if len(input_ids) > config['max_position_embeddings'] - max_tokens: |
|
|
input_ids = input_ids[-(config['max_position_embeddings'] - max_tokens):] |
|
|
|
|
|
input_tensor = tf.constant([input_ids], dtype=tf.int32) |
|
|
token_freq = {} |
|
|
|
|
|
for step in range(max_tokens): |
|
|
logits = fast_forward(input_tensor) |
|
|
next_token_logits = logits[0, -1, :].numpy() |
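        # Sampling pipeline: temperature scaling -> repetition penalty -> top-k -> optional top-p.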
|
|
|
|
|
next_token_logits = next_token_logits / temperature |
|
|
|
|
|
if repetition_penalty != 1.0: |
|
|
for token_id, freq in token_freq.items(): |
|
|
if token_id < len(next_token_logits): |
|
|
next_token_logits[token_id] /= (repetition_penalty ** freq) |
|
|
|
|
|
if top_k > 0: |
|
|
top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:] |
|
|
top_k_logits = next_token_logits[top_k_indices] |
|
|
top_k_probs = tf.nn.softmax(top_k_logits).numpy() |
|
|
|
|
|
if top_p < 1.0: |
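                # Nucleus (top-p) filtering applied within the top-k candidate pool.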
|
|
sorted_indices = np.argsort(top_k_probs)[::-1] |
|
|
cumsum = np.cumsum(top_k_probs[sorted_indices]) |
|
|
cutoff_idx = np.searchsorted(cumsum, top_p) |
|
|
nucleus_indices = sorted_indices[:cutoff_idx + 1] |
|
|
|
|
|
nucleus_logits = top_k_logits[nucleus_indices] |
|
|
nucleus_probs = tf.nn.softmax(nucleus_logits).numpy() |
|
|
|
|
|
sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs) |
|
|
next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]]) |
|
|
else: |
|
|
sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs) |
|
|
next_token_id = int(top_k_indices[sampled_idx]) |
|
|
else: |
|
|
probs = tf.nn.softmax(next_token_logits).numpy() |
|
|
            next_token_id = int(np.random.choice(len(probs), p=probs))
|
|
|
|
|
if next_token_id == eos_token_id: |
|
|
break |
|
|
|
|
|
token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1 |
|
|
|
|
|
if return_token_ids: |
|
|
yield (next_token_id, None) |
|
|
else: |
|
|
token_text = tokenizer.decode([next_token_id]) |
|
|
yield (next_token_id, token_text) |
|
|
|
|
|
input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1) |
|
|
|
|
|
if input_tensor.shape[1] > config['max_position_embeddings']: |
|
|
input_tensor = input_tensor[:, -config['max_position_embeddings']:] |
|
|
|
|
|
def format_chat_prompt(messages: List[ChatMessage]) -> str: |
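    # ChatML-style prompt; roles other than "user"/"assistant" (e.g. "system") are ignored here.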
|
|
prompt = "" |
|
|
for msg in messages: |
|
|
if msg.role == "user": |
|
|
prompt += f"<|im_start|>user\n{msg.content}<|im_end|>\n" |
|
|
elif msg.role == "assistant": |
|
|
prompt += f"<|im_start|>assistant\n{msg.content}<|im_end|>\n" |
|
|
|
|
|
prompt += "<|im_start|>assistant\n" |
|
|
return prompt |
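
# Example (illustrative only): consuming the token generator locally, assuming
# load_model() has already populated the model/tokenizer/fast_forward globals.
#
#   pieces = []
#   for _token_id, token_text in generate_tokens("Hello", max_tokens=32):
#       pieces.append(token_text)
#   print("".join(pieces))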
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse) |
|
|
async def status_page(): |
|
|
"""Worker status page""" |
|
|
return """ |
|
|
<!DOCTYPE html> |
|
|
<html> |
|
|
<head> |
|
|
<title>SAM-Z-1 Worker Node</title> |
|
|
<style> |
|
|
* { margin: 0; padding: 0; box-sizing: border-box; } |
|
|
body { |
|
|
font-family: 'Courier New', monospace; |
|
|
background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%); |
|
|
color: #00bfff; |
|
|
padding: 20px; |
|
|
min-height: 100vh; |
|
|
} |
|
|
.container { |
|
|
max-width: 900px; |
|
|
margin: 0 auto; |
|
|
} |
|
|
.header { |
|
|
text-align: center; |
|
|
padding: 30px; |
|
|
background: rgba(0, 191, 255, 0.1); |
|
|
border: 2px solid #00bfff; |
|
|
border-radius: 10px; |
|
|
margin-bottom: 30px; |
|
|
box-shadow: 0 0 20px rgba(0, 191, 255, 0.3); |
|
|
} |
|
|
.header h1 { |
|
|
font-size: 2.5em; |
|
|
text-transform: uppercase; |
|
|
letter-spacing: 3px; |
|
|
animation: glow 2s ease-in-out infinite alternate; |
|
|
} |
|
|
@keyframes glow { |
|
|
from { text-shadow: 0 0 10px #00bfff; } |
|
|
to { text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; } |
|
|
} |
|
|
.badge { |
|
|
display: inline-block; |
|
|
padding: 5px 15px; |
|
|
border-radius: 15px; |
|
|
font-size: 0.9em; |
|
|
margin-top: 10px; |
|
|
} |
|
|
.badge-ready { |
|
|
background: rgba(0, 255, 136, 0.2); |
|
|
border: 1px solid #00ff88; |
|
|
color: #00ff88; |
|
|
} |
|
|
.badge-loading { |
|
|
background: rgba(255, 165, 0, 0.2); |
|
|
border: 1px solid #ffa500; |
|
|
color: #ffa500; |
|
|
} |
|
|
.stats-grid { |
|
|
display: grid; |
|
|
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); |
|
|
gap: 20px; |
|
|
margin-bottom: 30px; |
|
|
} |
|
|
.stat-card { |
|
|
background: rgba(0, 191, 255, 0.05); |
|
|
border: 1px solid #00bfff; |
|
|
border-radius: 8px; |
|
|
padding: 20px; |
|
|
text-align: center; |
|
|
} |
|
|
.stat-label { |
|
|
font-size: 0.8em; |
|
|
opacity: 0.7; |
|
|
text-transform: uppercase; |
|
|
margin-bottom: 10px; |
|
|
} |
|
|
.stat-value { |
|
|
font-size: 2em; |
|
|
font-weight: bold; |
|
|
} |
|
|
.features { |
|
|
background: rgba(0, 191, 255, 0.05); |
|
|
border: 1px solid #00bfff; |
|
|
border-radius: 8px; |
|
|
padding: 20px; |
|
|
} |
|
|
.features h3 { |
|
|
margin-bottom: 15px; |
|
|
} |
|
|
.feature-list { |
|
|
list-style: none; |
|
|
padding: 0; |
|
|
} |
|
|
.feature-list li { |
|
|
padding: 10px; |
|
|
margin: 5px 0; |
|
|
background: rgba(0, 191, 255, 0.1); |
|
|
border-radius: 5px; |
|
|
} |
|
|
.feature-list li:before { |
|
|
content: "β‘ "; |
|
|
color: #00ff88; |
|
|
} |
|
|
.timestamp { |
|
|
text-align: center; |
|
|
margin-top: 20px; |
|
|
opacity: 0.5; |
|
|
} |
|
|
</style> |
|
|
</head> |
|
|
<body> |
|
|
<div class="container"> |
|
|
<div class="header"> |
|
|
                <h1>⚙️ WORKER NODE ⚙️</h1>
|
|
<div>SAM-Z-1 Distributed Worker v4.0</div> |
|
|
<div class="badge" id="status-badge">CHECKING STATUS...</div> |
|
|
</div> |
|
|
|
|
|
<div class="stats-grid" id="stats"> |
|
|
<div class="stat-card"> |
|
|
<div class="stat-label">Total Requests</div> |
|
|
<div class="stat-value" id="total-req">--</div> |
|
|
</div> |
|
|
<div class="stat-card"> |
|
|
<div class="stat-label">Total Tokens</div> |
|
|
<div class="stat-value" id="total-tokens">--</div> |
|
|
</div> |
|
|
<div class="stat-card"> |
|
|
<div class="stat-label">Decode Requests</div> |
|
|
<div class="stat-value" id="decode-req">--</div> |
|
|
</div> |
|
|
<div class="stat-card"> |
|
|
<div class="stat-label">Uptime</div> |
|
|
<div class="stat-value" id="uptime">--</div> |
|
|
</div> |
|
|
</div> |
|
|
|
|
|
<div class="features"> |
|
|
                <h3>🚀 CAPABILITIES</h3>
|
|
<ul class="feature-list"> |
|
|
<li>Full Text Generation</li> |
|
|
<li>Token-Only Mode (for distributed pipeline)</li> |
|
|
<li>High-Speed Batch Decoding</li> |
|
|
<li>Chat Completion</li> |
|
|
<li>Streaming & Non-Streaming</li> |
|
|
</ul> |
|
|
</div> |
|
|
|
|
|
<div class="timestamp" id="timestamp">Initializing...</div> |
|
|
</div> |
|
|
|
|
|
<script> |
|
|
async function updateStats() { |
|
|
try { |
|
|
const response = await fetch('/health'); |
|
|
const data = await response.json(); |
|
|
|
|
|
const badge = document.getElementById('status-badge'); |
|
|
if (data.model_loaded) { |
|
|
                    badge.textContent = '✅ READY FOR INFERENCE';
|
|
badge.className = 'badge badge-ready'; |
|
|
} else { |
|
|
                    badge.textContent = '⏳ LOADING MODEL...';
|
|
badge.className = 'badge badge-loading'; |
|
|
} |
|
|
|
|
|
// Fetch stats |
|
|
const statsRes = await fetch('/stats'); |
|
|
const stats = await statsRes.json(); |
|
|
|
|
|
document.getElementById('total-req').textContent = stats.total_requests; |
|
|
document.getElementById('total-tokens').textContent = stats.total_tokens; |
|
|
document.getElementById('decode-req').textContent = stats.decode_requests; |
|
|
|
|
|
const uptime = Math.floor(stats.uptime); |
|
|
const h = Math.floor(uptime / 3600); |
|
|
const m = Math.floor((uptime % 3600) / 60); |
|
|
const s = uptime % 60; |
|
|
document.getElementById('uptime').textContent = `${h}h ${m}m ${s}s`; |
|
|
|
|
|
document.getElementById('timestamp').textContent = |
|
|
`Last update: ${new Date().toLocaleTimeString()}`; |
|
|
} catch (e) { |
|
|
console.error('Failed to update stats:', e); |
|
|
} |
|
|
} |
|
|
|
|
|
// Update every second |
|
|
setInterval(updateStats, 1000); |
|
|
updateStats(); |
|
|
</script> |
|
|
</body> |
|
|
</html> |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health(): |
|
|
return { |
|
|
"status": "healthy" if model is not None else "loading", |
|
|
"model_loaded": model is not None |
|
|
} |
|
|
|
|
|
@app.get("/stats") |
|
|
async def stats(): |
|
|
uptime = time.time() - worker_stats["uptime_start"] |
|
|
return { |
|
|
"total_requests": worker_stats["total_requests"], |
|
|
"total_tokens": worker_stats["total_tokens"], |
|
|
"decode_requests": worker_stats["decode_requests"], |
|
|
"uptime": uptime, |
|
|
"tokens_per_second": worker_stats["total_tokens"] / uptime if uptime > 0 else 0 |
|
|
} |
|
|
|
|
|
@app.post("/decode") |
|
|
async def decode(request: DecodeRequest): |
|
|
"""Fast single decode""" |
|
|
if tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Tokenizer not loaded") |
|
|
|
|
|
try: |
|
|
worker_stats["decode_requests"] += 1 |
|
|
text = tokenizer.decode(request.token_ids) |
|
|
return {"text": text} |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Decode error: {str(e)}") |
|
|
|
|
|
@app.post("/decode/batch") |
|
|
async def batch_decode(request: BatchDecodeRequest): |
|
|
"""Optimized batch decoding for distributed pipeline""" |
|
|
if tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Tokenizer not loaded") |
|
|
|
|
|
try: |
|
|
worker_stats["decode_requests"] += len(request.batches) |
|
|
results = [tokenizer.decode(batch) for batch in request.batches] |
|
|
return {"texts": results} |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Batch decode error: {str(e)}") |
|
|
|
|
|
@app.post("/generate") |
|
|
async def generate(request: GenerateRequest): |
|
|
"""Generate text""" |
|
|
if model is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
worker_stats["total_requests"] += 1 |
|
|
start_time = time.time() |
|
|
|
|
|
if request.stream: |
|
|
async def stream_tokens(): |
|
|
generated_text = "" |
|
|
token_count = 0 |
|
|
|
|
|
try: |
|
|
for token_id, token_text in generate_tokens( |
|
|
request.prompt, |
|
|
max_tokens=request.max_tokens, |
|
|
temperature=request.temperature, |
|
|
top_k=request.top_k, |
|
|
top_p=request.top_p, |
|
|
repetition_penalty=request.repetition_penalty, |
|
|
return_token_ids=request.return_token_ids |
|
|
): |
|
|
token_count += 1 |
|
|
worker_stats["total_tokens"] += 1 |
|
|
|
|
|
if request.return_token_ids: |
|
|
yield f"data: {json.dumps({'token_id': token_id})}\n\n" |
|
|
else: |
|
|
generated_text += token_text |
|
|
yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n" |
|
|
|
|
|
await asyncio.sleep(0.001) |
|
|
|
|
|
elapsed = time.time() - start_time |
|
|
yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n" |
|
|
|
|
|
except Exception as e: |
|
|
yield f"data: {json.dumps({'error': str(e)})}\n\n" |
|
|
|
|
|
return StreamingResponse(stream_tokens(), media_type="text/event-stream") |
|
|
|
|
|
else: |
|
|
generated_text = "" |
|
|
token_count = 0 |
|
|
|
|
|
try: |
|
|
for token_id, token_text in generate_tokens( |
|
|
request.prompt, |
|
|
max_tokens=request.max_tokens, |
|
|
temperature=request.temperature, |
|
|
top_k=request.top_k, |
|
|
top_p=request.top_p, |
|
|
repetition_penalty=request.repetition_penalty, |
|
|
return_token_ids=request.return_token_ids |
|
|
): |
|
|
if not request.return_token_ids: |
|
|
generated_text += token_text |
|
|
token_count += 1 |
|
|
worker_stats["total_tokens"] += 1 |
|
|
|
|
|
elapsed = time.time() - start_time |
|
|
|
|
|
return { |
|
|
"text": generated_text, |
|
|
"tokens": token_count, |
|
|
"time": elapsed, |
|
|
"tokens_per_second": token_count / elapsed if elapsed > 0 else 0 |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") |
|
|
|
|
|
@app.post("/chat") |
|
|
async def chat(request: ChatRequest): |
|
|
"""Chat completion""" |
|
|
if model is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
worker_stats["total_requests"] += 1 |
|
|
prompt = format_chat_prompt(request.messages) |
|
|
start_time = time.time() |
|
|
|
|
|
if request.stream: |
|
|
async def stream_tokens(): |
|
|
generated_text = "" |
|
|
token_count = 0 |
|
|
|
|
|
try: |
|
|
for token_id, token_text in generate_tokens( |
|
|
prompt, |
|
|
max_tokens=request.max_tokens, |
|
|
temperature=request.temperature, |
|
|
top_k=request.top_k, |
|
|
top_p=request.top_p, |
|
|
repetition_penalty=request.repetition_penalty, |
|
|
return_token_ids=request.return_token_ids |
|
|
): |
|
|
token_count += 1 |
|
|
worker_stats["total_tokens"] += 1 |
|
|
|
|
|
if request.return_token_ids: |
|
|
yield f"data: {json.dumps({'token_id': token_id})}\n\n" |
|
|
else: |
|
|
generated_text += token_text |
|
|
|
|
|
if "<|im_end|>" in generated_text: |
|
|
generated_text = generated_text.split("<|im_end|>")[0] |
|
|
break |
|
|
|
|
|
yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n" |
|
|
|
|
|
await asyncio.sleep(0.001) |
|
|
|
|
|
elapsed = time.time() - start_time |
|
|
yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n" |
|
|
|
|
|
except Exception as e: |
|
|
yield f"data: {json.dumps({'error': str(e)})}\n\n" |
|
|
|
|
|
return StreamingResponse(stream_tokens(), media_type="text/event-stream") |
|
|
|
|
|
else: |
|
|
generated_text = "" |
|
|
token_count = 0 |
|
|
|
|
|
try: |
|
|
for token_id, token_text in generate_tokens( |
|
|
prompt, |
|
|
max_tokens=request.max_tokens, |
|
|
temperature=request.temperature, |
|
|
top_k=request.top_k, |
|
|
top_p=request.top_p, |
|
|
repetition_penalty=request.repetition_penalty, |
|
|
return_token_ids=request.return_token_ids |
|
|
): |
|
|
if not request.return_token_ids: |
|
|
generated_text += token_text |
|
|
|
|
|
if "<|im_end|>" in generated_text: |
|
|
generated_text = generated_text.split("<|im_end|>")[0] |
|
|
break |
|
|
|
|
|
token_count += 1 |
|
|
worker_stats["total_tokens"] += 1 |
|
|
|
|
|
elapsed = time.time() - start_time |
|
|
|
|
|
return { |
|
|
"message": { |
|
|
"role": "assistant", |
|
|
"content": generated_text.strip() |
|
|
}, |
|
|
"tokens": token_count, |
|
|
"time": elapsed, |
|
|
"tokens_per_second": token_count / elapsed if elapsed > 0 else 0 |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup") |
|
|
async def load_model(): |
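    # Downloads config and weights from the Hub, rebuilds the tokenizer with the chat
    # special tokens, constructs/loads the model, and compiles a graph-mode forward pass.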
|
|
global model, tokenizer, config, eos_token_id, fast_forward |
|
|
|
|
|
print("π Loading SAM-Z-1 Model...") |
|
|
|
|
|
try: |
|
|
config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR) |
|
|
|
|
|
try: |
|
|
weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR) |
|
|
print("β
Found checkpoint weights") |
|
|
use_checkpoint = True |
|
|
        except Exception:
            print("⚠️ Checkpoint not found, using model.keras")
|
|
model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR) |
|
|
use_checkpoint = False |
|
|
|
|
|
with open(config_path, 'r') as f: |
|
|
config = json.load(f) |
|
|
|
|
|
print(f"π¦ Config loaded: {config['num_hidden_layers']} layers") |
|
|
|
|
|
print("π¦ Creating tokenizer...") |
|
|
from transformers import AutoTokenizer |
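        # Build a GPT-2 tokenizer, add the chat/think special tokens, then reload the
        # saved tokenizer.json with the fast `tokenizers` backend used at inference time.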
|
|
|
|
|
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
|
custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"] |
|
|
hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens}) |
|
|
|
|
|
os.makedirs("./temp_tokenizer", exist_ok=True) |
|
|
hf_tokenizer.save_pretrained("./temp_tokenizer") |
|
|
tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json") |
|
|
|
|
|
eos_token_id = config.get('eos_token_id', 50256) |
|
|
|
|
|
print(f"β
Tokenizer ready: vocab size {tokenizer.get_vocab_size()}") |
|
|
|
|
|
print("π Loading model...") |
|
|
|
|
|
if use_checkpoint: |
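            # Map HF-style config keys onto the constructor arguments SAM1Model expects.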
|
|
model_config = { |
|
|
'vocab_size': config['vocab_size'], |
|
|
'd_model': config['hidden_size'], |
|
|
'n_layers': config['num_hidden_layers'], |
|
|
'n_heads': config['num_attention_heads'], |
|
|
'ff_mult': config['intermediate_size'] / config['hidden_size'], |
|
|
'max_len': config['max_position_embeddings'], |
|
|
'dropout': 0.1, |
|
|
'rope_theta': config['rope_theta'] |
|
|
} |
|
|
|
|
|
model = SAM1Model(config=model_config) |
|
|
dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32) |
|
|
_ = model(dummy_input, training=False) |
|
|
|
|
|
print(f"β
Architecture built: {model.count_params():,} parameters") |
|
|
|
|
|
model.load_weights(weights_path) |
|
|
print("β
Weights loaded!") |
|
|
|
|
|
else: |
|
|
model = keras.models.load_model(model_path, compile=False) |
|
|
print("β
Model loaded!") |
|
|
|
|
|
@tf.function(reduce_retracing=True) |
|
|
def optimized_forward(input_tensor): |
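            # Graph-compiled forward pass; reduce_retracing limits retraces as the
            # input sequence grows during generation.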
|
|
return model(input_tensor, training=False) |
|
|
|
|
|
fast_forward = optimized_forward |
|
|
|
|
|
print("β
SAM-Z-1 Distributed Worker ready! π") |
|
|
print("π₯ Features enabled:") |
|
|
print(" - Full text generation") |
|
|
print(" - Token-only mode (distributed pipeline)") |
|
|
print(" - Batch decoding optimization") |
|
|
print(" - Streaming support") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"β Failed to load model: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run( |
|
|
app, |
|
|
host="0.0.0.0", |
|
|
port=7860, |
|
|
log_level="info" |
|
|
) |
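
# Alternatively (illustrative, assuming this file is saved as worker.py):
#   uvicorn worker:app --host 0.0.0.0 --port 7860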