Worker-2 / app.py
Bc-AI's picture
Update app.py
25388aa verified
"""
SAM-Z-1 Distributed Worker Node v4.0
Optimized for distributed gen/decode pipeline
"""
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
import json
import os
from tokenizers import Tokenizer
import numpy as np
import time
from typing import List, Optional
import asyncio
# The FastAPI application object; every endpoint in this module registers on it.
app = FastAPI(
    version="4.0.0",
    title="SAM-Z-1 Distributed Worker",
)
# ============================================================================
# Model Architecture
# ============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) for query/key tensors.

    Cos/sin tables for positions ``0 .. max_len - 1`` are computed once, on
    the first call, and sliced to the current sequence length afterwards.

    Args:
        dim: Size of the last axis of q/k. ``rotate_half`` splits this axis in
            two, so it is expected to be even.
        max_len: Number of positions covered by the cached tables.
        theta: Base of the rotary frequencies.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Compute the cos/sin tables once and keep them as eager constants.

        The first call may happen while a ``tf.function`` is tracing (for
        example the ``fast_forward`` wrapper created at startup), or possibly
        during a symbolic Keras build. Inside such a trace ``Tensor.numpy()``
        is unavailable, and any tensor created would belong to the temporary
        graph, so later calls could not reuse it. ``tf.init_scope`` lifts
        this one-off computation into the outer eager context.
        """
        if self.built_cache:
            return
        with tf.init_scope():
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
        self.built_cache = True

    def rotate_half(self, x):
        """Return ``concat([-x2, x1])``, where x1 and x2 are the two halves of the last axis."""
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        """Rotate ``q`` and ``k`` according to their positions.

        Expects rank-4 inputs with the sequence on axis 2; the tables are
        broadcast as ``[1, 1, seq_len, dim]``. Returns ``(q_rotated,
        k_rotated)`` in the input dtype. ``seq_len`` must not exceed
        ``max_len``, otherwise the sliced tables are too short to broadcast.
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Output is ``x / sqrt(mean(x ** 2, axis=-1) + epsilon) * scale``, where
    ``scale`` has one entry per feature and is initialised to ones.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        self.scale = self.add_weight(
            name="scale",
            shape=(feature_dim,),
            initializer="ones",
        )

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.math.rsqrt(mean_square + self.epsilon)
        return x * inv_rms * self.scale

    def get_config(self):
        return {**super().get_config(), "epsilon": self.epsilon}
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block with two residual sub-blocks.

    The first sub-block is causal multi-head self-attention with rotary
    position embeddings. The second is a gated (SwiGLU-style) feed-forward
    network. Each sub-block adds its output back onto the residual stream.

    Args:
        d_model: Model (hidden) width; also the width of the q/k/v/out projections.
        n_heads: Number of attention heads. ``head_dim = d_model // n_heads``,
            and the reshapes in ``call`` require d_model to be divisible by n_heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate applied to the attention and FFN outputs.
        max_len: Number of positions in the rotary-embedding tables.
        rope_theta: Rotary-embedding frequency base.
        layer_idx: Position of this block in the stack. It is only stored and
            serialized; the computation does not use it.

    NOTE(review): the sub-layer attribute names below are presumably what
    saved weights are matched against. Confirm before renaming any of them.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        # Attention projections (no bias).
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        # Gated FFN: down_proj(silu(gate_proj(y)) * up_proj(y)).
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        # A single Dropout layer is reused on both residual branches.
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor of the same shape. ``training`` is only forwarded to dropout.
        """
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        # --- Attention sub-block (normalise, attend, add residual) ---
        res = x
        y = self.pre_attn_norm(x)
        # Project and split heads: (B, T, D) -> (B, n_heads, T, head_dim).
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        # Scaled dot-product scores, shape (B, n_heads, T, T).
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: band_part(-1, 0) keeps the lower triangle including the
        # diagonal; entries above it (future positions) get -1e9 before softmax.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, D).
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        # --- Feed-forward sub-block (normalise, gated SiLU FFN, add residual) ---
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        # Serialise the constructor arguments so the block can be rebuilt on load.
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model.

    The layers are a token embedding, a stack of ``TransformerBlock``s, a
    final RMSNorm and a separate (untied) output projection ``lm_head``. The
    output projection produces vocabulary logits for every position.
    """

    def __init__(self, **kwargs):
        """Build the layers from a hyper-parameter dict.

        The dict is taken from the first source that applies:
          * ``config=<dict>`` (the form written by ``get_config``);
          * the keyword arguments themselves, if they include ``vocab_size``;
          * ``cfg=<dict>``, or otherwise the keyword arguments as-is.

        Keys read: vocab_size, d_model, n_heads, n_layers, ff_mult (the FFN
        width is ``int(d_model * ff_mult)``), dropout, max_len, rope_theta.

        NOTE(review): ``super().__init__()`` is called without forwarding any
        kwargs, so Keras options such as ``name`` passed here are not applied.
        """
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        # Arguments shared by every TransformerBlock in the stack.
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }
        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        """Map token ids to vocabulary logits; no softmax is applied.

        Callers in this file pass integer tensors of shape (batch, seq_len).
        """
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        # Store the hyper-parameters under 'config', the key __init__ checks first.
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
# ============================================================================
# Global State
# ============================================================================
# Runtime objects filled in by load_model() at startup; all None until then.
model = tokenizer = config = eos_token_id = fast_forward = None

# Hugging Face Hub repository holding the config/weights, and the local download cache.
MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
CACHE_DIR = "./model_cache"

# Stats
worker_stats = dict(
    total_requests=0,
    total_tokens=0,
    decode_requests=0,
    uptime_start=time.time(),
)
# ============================================================================
# Request Models
# ============================================================================
class GenerateRequest(BaseModel):
    # Body of POST /generate. The sampling fields are forwarded to generate_tokens().
    prompt: str
    max_tokens: int = 512  # upper bound on the number of newly generated tokens
    temperature: float = 0.8  # logits are divided by this before sampling
    top_k: int = 40  # values <= 0 disable top-k filtering
    top_p: float = 0.9  # nucleus threshold, applied only within top-k; >= 1.0 disables it
    repetition_penalty: float = 1.1  # 1.0 disables the penalty
    stream: bool = False  # True -> respond with a text/event-stream
    return_token_ids: bool = False  # True -> emit token ids instead of decoded text
class ChatMessage(BaseModel):
    # One conversation turn. format_chat_prompt() renders only the roles
    # "user" and "assistant"; messages with other roles are skipped.
    role: str
    content: str
class ChatRequest(BaseModel):
    # Body of POST /chat. The messages are rendered by format_chat_prompt();
    # the sampling fields mirror GenerateRequest.
    messages: List[ChatMessage]
    max_tokens: int = 512
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False  # True -> respond with a text/event-stream
    return_token_ids: bool = False  # True -> emit token ids instead of decoded text
class DecodeRequest(BaseModel):
    # Body of POST /decode: a single sequence of token ids to turn back into text.
    token_ids: List[int]
class BatchDecodeRequest(BaseModel):
    # Body of POST /decode/batch: several token-id sequences, each decoded separately.
    batches: List[List[int]]
# ============================================================================
# Generation Functions
# ============================================================================
def _softmax(logits):
    """Return the numerically stable softmax of a 1-D array, as float64."""
    shifted = np.asarray(logits, dtype=np.float64) - np.max(logits)
    weights = np.exp(shifted)
    return weights / weights.sum()


def _sample_next_token(logits, token_freq, temperature, top_k, top_p, repetition_penalty):
    """Choose the next token id from the last-position ``logits`` (a 1-D array).

    Args:
        logits: Raw scores over the vocabulary for the next position. The
            array is copied, never modified in place.
        token_freq: ``{token_id: count}`` for the tokens generated so far.
        temperature: Logit divisor. Values ``<= 0`` select greedy (argmax) decoding.
        top_k: Keep only the ``top_k`` highest-scoring tokens; ``<= 0`` disables this.
        top_p: Nucleus threshold, applied within the top-k candidates;
            ``>= 1`` disables it.
        repetition_penalty: Each previously generated token's logit is scaled
            by ``repetition_penalty ** count`` towards lower probability
            (for penalty > 1).

    Returns:
        The chosen token id as a plain Python ``int``, so that it is
        JSON-serialisable and concatenates cleanly with int32 tensors.
    """
    scores = np.array(logits, dtype=np.float64)  # private copy; mutated below

    if temperature > 0:
        scores /= temperature

    if repetition_penalty != 1.0:
        for token_id, freq in token_freq.items():
            if 0 <= token_id < scores.size:
                factor = repetition_penalty ** freq
                # Sign-aware penalty: dividing a *negative* logit would raise its
                # probability, so negative logits are multiplied instead.
                if scores[token_id] > 0:
                    scores[token_id] /= factor
                else:
                    scores[token_id] *= factor

    if temperature <= 0:
        return int(np.argmax(scores))

    if top_k > 0:
        k = min(top_k, scores.size)  # argpartition rejects k larger than the vocab
        candidates = np.argpartition(scores, -k)[-k:]
        candidate_scores = scores[candidates]
        candidate_probs = _softmax(candidate_scores)
        if top_p < 1.0:
            # Smallest set of candidates (by descending probability) whose
            # cumulative probability reaches top_p.
            order = np.argsort(candidate_probs)[::-1]
            cumulative = np.cumsum(candidate_probs[order])
            cutoff = np.searchsorted(cumulative, top_p)
            nucleus = order[:cutoff + 1]
            nucleus_probs = _softmax(candidate_scores[nucleus])
            picked = np.random.choice(len(nucleus_probs), p=nucleus_probs)
            return int(candidates[nucleus[picked]])
        picked = np.random.choice(len(candidate_probs), p=candidate_probs)
        return int(candidates[picked])

    probs = _softmax(scores)
    return int(np.random.choice(len(probs), p=probs))


def generate_tokens(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    return_token_ids: bool = False
):
    """Autoregressively sample up to ``max_tokens`` tokens continuing ``prompt``.

    Yields ``(token_id, token_text)`` pairs, one per generated token.
    ``token_text`` is the decoded text of that single token, or ``None`` when
    ``return_token_ids`` is true. Generation stops early when the EOS token is
    sampled; the EOS token itself is not yielded. Sampling is done by
    ``_sample_next_token``.

    Relies on the module globals set by ``load_model``: ``tokenizer``,
    ``config``, ``eos_token_id`` and ``fast_forward``.
    """
    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
    if not input_ids:
        return

    max_positions = config['max_position_embeddings']
    # Reserve room for the new tokens but always keep at least the last prompt
    # token. Without the clamp, ``max_positions - max_tokens`` becomes <= 0
    # when max_tokens >= max_positions. The slice ``[-(negative):]`` would then
    # drop the *start* of the prompt instead of keeping its end.
    prompt_budget = max(1, max_positions - max_tokens)
    if len(input_ids) > prompt_budget:
        input_ids = input_ids[-prompt_budget:]

    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    token_freq = {}
    for _ in range(max_tokens):
        logits = fast_forward(input_tensor)
        next_token_id = _sample_next_token(
            logits[0, -1, :].numpy(),
            token_freq,
            temperature,
            top_k,
            top_p,
            repetition_penalty,
        )
        if next_token_id == eos_token_id:
            break
        token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
        if return_token_ids:
            yield (next_token_id, None)
        else:
            yield (next_token_id, tokenizer.decode([next_token_id]))
        input_tensor = tf.concat(
            [input_tensor, tf.constant([[next_token_id]], dtype=tf.int32)], axis=1
        )
        # Sliding window: never feed the model more positions than it supports.
        if input_tensor.shape[1] > max_positions:
            input_tensor = input_tensor[:, -max_positions:]
def format_chat_prompt(messages: List[ChatMessage]) -> str:
    """Render a conversation in ChatML form and open a new assistant turn.

    Each ``user``/``assistant`` message becomes
    ``<|im_start|>{role}\\n{content}<|im_end|>\\n``. Messages with any other
    role are skipped. The result always ends with ``<|im_start|>assistant\\n``.
    """
    role_headers = {
        "user": "<|im_start|>user\n",
        "assistant": "<|im_start|>assistant\n",
    }
    parts = []
    for message in messages:
        header = role_headers.get(message.role)
        if header is None:
            continue
        parts.append(f"{header}{message.content}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)
# ============================================================================
# Status Page
# ============================================================================
@app.get("/", response_class=HTMLResponse)
async def status_page():
    """Serve the worker's HTML status dashboard.

    The page itself is static. Its script polls ``/health`` and ``/stats``
    once per second to fill in the readiness badge and the counters. The
    emoji in the markup and script were mis-encoded (mojibake) and are
    restored here.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>SAM-Z-1 Worker Node</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: 'Courier New', monospace;
background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
color: #00bfff;
padding: 20px;
min-height: 100vh;
}
.container {
max-width: 900px;
margin: 0 auto;
}
.header {
text-align: center;
padding: 30px;
background: rgba(0, 191, 255, 0.1);
border: 2px solid #00bfff;
border-radius: 10px;
margin-bottom: 30px;
box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
}
.header h1 {
font-size: 2.5em;
text-transform: uppercase;
letter-spacing: 3px;
animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
from { text-shadow: 0 0 10px #00bfff; }
to { text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }
}
.badge {
display: inline-block;
padding: 5px 15px;
border-radius: 15px;
font-size: 0.9em;
margin-top: 10px;
}
.badge-ready {
background: rgba(0, 255, 136, 0.2);
border: 1px solid #00ff88;
color: #00ff88;
}
.badge-loading {
background: rgba(255, 165, 0, 0.2);
border: 1px solid #ffa500;
color: #ffa500;
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.stat-card {
background: rgba(0, 191, 255, 0.05);
border: 1px solid #00bfff;
border-radius: 8px;
padding: 20px;
text-align: center;
}
.stat-label {
font-size: 0.8em;
opacity: 0.7;
text-transform: uppercase;
margin-bottom: 10px;
}
.stat-value {
font-size: 2em;
font-weight: bold;
}
.features {
background: rgba(0, 191, 255, 0.05);
border: 1px solid #00bfff;
border-radius: 8px;
padding: 20px;
}
.features h3 {
margin-bottom: 15px;
}
.feature-list {
list-style: none;
padding: 0;
}
.feature-list li {
padding: 10px;
margin: 5px 0;
background: rgba(0, 191, 255, 0.1);
border-radius: 5px;
}
.feature-list li:before {
content: "⚡ ";
color: #00ff88;
}
.timestamp {
text-align: center;
margin-top: 20px;
opacity: 0.5;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>⚙️ WORKER NODE ⚙️</h1>
<div>SAM-Z-1 Distributed Worker v4.0</div>
<div class="badge" id="status-badge">CHECKING STATUS...</div>
</div>
<div class="stats-grid" id="stats">
<div class="stat-card">
<div class="stat-label">Total Requests</div>
<div class="stat-value" id="total-req">--</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Tokens</div>
<div class="stat-value" id="total-tokens">--</div>
</div>
<div class="stat-card">
<div class="stat-label">Decode Requests</div>
<div class="stat-value" id="decode-req">--</div>
</div>
<div class="stat-card">
<div class="stat-label">Uptime</div>
<div class="stat-value" id="uptime">--</div>
</div>
</div>
<div class="features">
<h3>🚀 CAPABILITIES</h3>
<ul class="feature-list">
<li>Full Text Generation</li>
<li>Token-Only Mode (for distributed pipeline)</li>
<li>High-Speed Batch Decoding</li>
<li>Chat Completion</li>
<li>Streaming & Non-Streaming</li>
</ul>
</div>
<div class="timestamp" id="timestamp">Initializing...</div>
</div>
<script>
async function updateStats() {
try {
const response = await fetch('/health');
const data = await response.json();
const badge = document.getElementById('status-badge');
if (data.model_loaded) {
badge.textContent = '✅ READY FOR INFERENCE';
badge.className = 'badge badge-ready';
} else {
badge.textContent = '⏳ LOADING MODEL...';
badge.className = 'badge badge-loading';
}
// Fetch stats
const statsRes = await fetch('/stats');
const stats = await statsRes.json();
document.getElementById('total-req').textContent = stats.total_requests;
document.getElementById('total-tokens').textContent = stats.total_tokens;
document.getElementById('decode-req').textContent = stats.decode_requests;
const uptime = Math.floor(stats.uptime);
const h = Math.floor(uptime / 3600);
const m = Math.floor((uptime % 3600) / 60);
const s = uptime % 60;
document.getElementById('uptime').textContent = `${h}h ${m}m ${s}s`;
document.getElementById('timestamp').textContent =
`Last update: ${new Date().toLocaleTimeString()}`;
} catch (e) {
console.error('Failed to update stats:', e);
}
}
// Update every second
setInterval(updateStats, 1000);
updateStats();
</script>
</body>
</html>
"""
# ============================================================================
# API Endpoints
# ============================================================================
@app.get("/health")
async def health():
    """Liveness probe: report whether the model has finished loading."""
    loaded = model is not None
    return {
        "status": "healthy" if loaded else "loading",
        "model_loaded": loaded,
    }
@app.get("/stats")
async def stats():
    """Return the cumulative counters, uptime in seconds and average tokens/second."""
    elapsed = time.time() - worker_stats["uptime_start"]
    total_tokens = worker_stats["total_tokens"]
    throughput = total_tokens / elapsed if elapsed > 0 else 0
    return {
        "total_requests": worker_stats["total_requests"],
        "total_tokens": total_tokens,
        "decode_requests": worker_stats["decode_requests"],
        "uptime": elapsed,
        "tokens_per_second": throughput,
    }
@app.post("/decode")
async def decode(request: DecodeRequest):
    """Decode a single token-id sequence into text.

    Responds 503 until the tokenizer is loaded, and 500 if decoding raises.
    """
    if tokenizer is None:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")
    try:
        worker_stats["decode_requests"] += 1
        return {"text": tokenizer.decode(request.token_ids)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Decode error: {str(exc)}")
@app.post("/decode/batch")
async def batch_decode(request: BatchDecodeRequest):
    """Decode several token-id sequences in one call, keeping the input order.

    Every sequence counts as one decode request in the stats. Responds 503
    until the tokenizer is loaded, and 500 if any decode raises.
    """
    if tokenizer is None:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")
    sequences = request.batches
    try:
        worker_stats["decode_requests"] += len(sequences)
        texts = list(map(tokenizer.decode, sequences))
        return {"texts": texts}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Batch decode error: {str(exc)}")
@app.post("/generate")
async def generate(request: GenerateRequest):
    """Generate a completion for ``request.prompt``.

    Streaming (``request.stream``): returns a server-sent-event stream. It
    emits one ``data:`` event per token: ``{"token_id": ...}`` in token-id
    mode, otherwise ``{"text": <token text>, "total": <text so far>}``. A
    final ``{"done": true, "tokens": n, "time": seconds}`` event follows, or
    an ``{"error": ...}`` event if generation fails mid-stream.

    Non-streaming: returns ``{"text", "tokens", "time", "tokens_per_second"}``.
    In token-id mode ``text`` is empty and the generated ids are returned
    under ``"token_ids"``.

    Model inference is synchronous, so it is kept off the event loop. The
    stream is a plain generator, which StreamingResponse iterates in a worker
    thread; the non-streaming path runs in the default executor. This keeps
    ``/health`` and other requests responsive while a generation is running.

    Raises:
        HTTPException(503): the model has not finished loading.
        HTTPException(500): generation failed (non-streaming only).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    worker_stats["total_requests"] += 1
    start_time = time.time()
    gen_kwargs = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "return_token_ids": request.return_token_ids,
    }

    if request.stream:
        def stream_tokens():
            generated_text = ""
            token_count = 0
            try:
                for token_id, token_text in generate_tokens(request.prompt, **gen_kwargs):
                    token_count += 1
                    worker_stats["total_tokens"] += 1
                    if request.return_token_ids:
                        # int(): sampled ids may be NumPy integers, which json cannot encode.
                        yield f"data: {json.dumps({'token_id': int(token_id)})}\n\n"
                    else:
                        generated_text += token_text
                        yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"
                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")

    def collect():
        """Run the whole generation synchronously and return (text, token ids)."""
        pieces = []
        token_ids = []
        for token_id, token_text in generate_tokens(request.prompt, **gen_kwargs):
            token_ids.append(int(token_id))
            if not request.return_token_ids:
                pieces.append(token_text)
        return "".join(pieces), token_ids

    try:
        loop = asyncio.get_running_loop()
        generated_text, token_ids = await loop.run_in_executor(None, collect)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    token_count = len(token_ids)
    worker_stats["total_tokens"] += token_count
    elapsed = time.time() - start_time
    response = {
        "text": generated_text,
        "tokens": token_count,
        "time": elapsed,
        "tokens_per_second": token_count / elapsed if elapsed > 0 else 0,
    }
    if request.return_token_ids:
        # Previously the ids were generated and then discarded in this mode.
        response["token_ids"] = token_ids
    return response
@app.post("/chat")
async def chat(request: ChatRequest):
    """Chat completion over ``request.messages``, rendered with format_chat_prompt.

    Generation stops at EOS, at ``max_tokens``, or at the end-of-turn marker
    ``<|im_end|>``. The marker is detected by token id as well as in the
    decoded text. It is registered as a *special* token, and
    ``Tokenizer.decode`` skips special tokens by default, so it decodes to ""
    and a text-only check never fires.

    Streaming: one SSE ``data:`` event per token, either ``{"token_id": ...}``
    in token-id mode or ``{"delta": <token text>, "content": <reply so far>}``.
    These are followed by ``{"done": true, "tokens": n, "time": seconds}``, or
    ``{"error": ...}`` on failure.

    Non-streaming: returns ``{"message": {"role": "assistant", "content": ...},
    "tokens", "time", "tokens_per_second"}``. In token-id mode the generated
    ids are also returned under ``"token_ids"``.

    Model inference runs off the event loop: in a worker thread for the stream
    and in the default executor otherwise.

    Raises:
        HTTPException(503): the model has not finished loading.
        HTTPException(500): generation failed (non-streaming only).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    worker_stats["total_requests"] += 1
    prompt = format_chat_prompt(request.messages)
    start_time = time.time()
    # None when the token is unknown, in which case the id check never matches.
    im_end_id = tokenizer.token_to_id("<|im_end|>") if tokenizer is not None else None
    gen_kwargs = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "return_token_ids": request.return_token_ids,
    }

    if request.stream:
        def stream_tokens():
            generated_text = ""
            token_count = 0
            try:
                for token_id, token_text in generate_tokens(prompt, **gen_kwargs):
                    if im_end_id is not None and token_id == im_end_id:
                        break
                    token_count += 1
                    worker_stats["total_tokens"] += 1
                    if request.return_token_ids:
                        yield f"data: {json.dumps({'token_id': int(token_id)})}\n\n"
                    else:
                        generated_text += token_text
                        if "<|im_end|>" in generated_text:
                            generated_text = generated_text.split("<|im_end|>")[0]
                            break
                        yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n"
                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")

    def collect():
        """Run the whole generation synchronously and return (text, token ids)."""
        generated_text = ""
        token_ids = []
        for token_id, token_text in generate_tokens(prompt, **gen_kwargs):
            if im_end_id is not None and token_id == im_end_id:
                break
            if not request.return_token_ids:
                generated_text += token_text
                if "<|im_end|>" in generated_text:
                    generated_text = generated_text.split("<|im_end|>")[0]
                    break
            token_ids.append(int(token_id))
        return generated_text, token_ids

    try:
        loop = asyncio.get_running_loop()
        generated_text, token_ids = await loop.run_in_executor(None, collect)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    token_count = len(token_ids)
    worker_stats["total_tokens"] += token_count
    elapsed = time.time() - start_time
    response = {
        "message": {
            "role": "assistant",
            "content": generated_text.strip()
        },
        "tokens": token_count,
        "time": elapsed,
        "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
    }
    if request.return_token_ids:
        response["token_ids"] = token_ids
    return response
# ============================================================================
# Model Loading
# ============================================================================
@app.on_event("startup")
async def load_model():
    """Download the artifacts, build the tokenizer and model, and publish the globals.

    Sets the module globals ``config``, ``tokenizer``, ``eos_token_id``,
    ``model`` and ``fast_forward``. The ``ckpt.weights.h5`` checkpoint is
    preferred, with the architecture rebuilt from ``config.json``. If that
    download fails, the full ``model.keras`` archive is loaded instead. Any
    failure is printed with a traceback and re-raised, which aborts startup.
    """
    global model, tokenizer, config, eos_token_id, fast_forward
    print("🚀 Loading SAM-Z-1 Model...")
    try:
        config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
        try:
            weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
            print("✅ Found checkpoint weights")
            use_checkpoint = True
        except Exception:
            # Any download failure falls back to the full model archive.
            # ``Exception`` rather than a bare ``except:``, so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            print("⚠️ Checkpoint not found, using model.keras")
            model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
            use_checkpoint = False
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        print(f"📦 Config loaded: {config['num_hidden_layers']} layers")

        print("📦 Creating tokenizer...")
        from transformers import AutoTokenizer
        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})
        # Round-trip through disk to get a standalone `tokenizers.Tokenizer`.
        os.makedirs("./temp_tokenizer", exist_ok=True)
        hf_tokenizer.save_pretrained("./temp_tokenizer")
        tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
        eos_token_id = config.get('eos_token_id', 50256)
        print(f"✅ Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")

        print("🔄 Loading model...")
        if use_checkpoint:
            # Translate the HF-style config keys into SAM1Model's hyper-parameters.
            model_config = {
                'vocab_size': config['vocab_size'],
                'd_model': config['hidden_size'],
                'n_layers': config['num_hidden_layers'],
                'n_heads': config['num_attention_heads'],
                'ff_mult': config['intermediate_size'] / config['hidden_size'],
                'max_len': config['max_position_embeddings'],
                'dropout': 0.1,
                'rope_theta': config['rope_theta']
            }
            model = SAM1Model(config=model_config)
            # One eager forward pass creates all variables before loading weights.
            dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
            _ = model(dummy_input, training=False)
            print(f"✅ Architecture built: {model.count_params():,} parameters")
            model.load_weights(weights_path)
            print("✅ Weights loaded!")
        else:
            model = keras.models.load_model(model_path, compile=False)
            # Mirror the checkpoint branch: run one eager forward pass so that
            # lazily created state is built outside the tf.function below.
            # For example, RotaryEmbedding builds its cos/sin cache on the first
            # call, using Tensor.numpy().
            _ = model(tf.zeros((1, 1), dtype=tf.int32), training=False)
            print("✅ Model loaded!")

        @tf.function(reduce_retracing=True)
        def optimized_forward(input_tensor):
            return model(input_tensor, training=False)

        fast_forward = optimized_forward
        print("✅ SAM-Z-1 Distributed Worker ready! 🚀")
        print("🔥 Features enabled:")
        print(" - Full text generation")
        print(" - Token-only mode (distributed pipeline)")
        print(" - Batch decoding optimization")
        print(" - Streaming support")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        raise
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # Local entry point: serve the worker on all interfaces at port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)