# Keeby-smilyai's picture
# Update app.py
# 7f368dd verified
# raw
# history blame
# 32.1 kB
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
import time
import threading
import hashlib
import sqlite3
from datetime import datetime, timedelta
import pytz
# ==============================================================================
# Performance Optimizations for CPU
# ==============================================================================
# Limit TensorFlow's thread pools: one inter-op thread, two intra-op threads.
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(2)
# Enable XLA JIT in the graph optimizer and keep tf.function in graph mode.
tf.config.optimizer.set_jit(True)
tf.config.run_functions_eagerly(False)
# NOTE(review): these two variables are GPU allocator settings and are assigned
# after `import tensorflow` — confirm they are wanted (and honoured) on this host.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Australian timezone
# Used when stamping and comparing the start of each user's rate-limit window.
AUSTRALIA_TZ = pytz.timezone('Australia/Sydney')
# ==============================================================================
# Database Setup
# ==============================================================================
def init_database(db_path='sam_users.db'):
    """Open the SQLite database, create the schema, and seed the admin user.

    Creates the ``users``, ``upgrade_requests`` and ``usage_logs`` tables when
    they do not exist yet, then inserts an ``admin`` account on the ``pro``
    plan unless that username is already taken.

    Args:
        db_path: Path handed to ``sqlite3.connect``. Defaults to the
            historical ``'sam_users.db'``; ``':memory:'`` gives a throwaway
            database.

    Returns:
        An open ``sqlite3.Connection`` created with ``check_same_thread=False``.
        Callers sharing it between threads must serialise access themselves.
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    c = conn.cursor()
    # Users table
    c.execute('''CREATE TABLE IF NOT EXISTS users
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  username TEXT UNIQUE NOT NULL,
                  password_hash TEXT NOT NULL,
                  email TEXT,
                  plan TEXT DEFAULT 'free',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  is_admin BOOLEAN DEFAULT 0,
                  rate_limit_start TIMESTAMP,
                  messages_used_nano INTEGER DEFAULT 0,
                  messages_used_mini INTEGER DEFAULT 0,
                  messages_used_fast INTEGER DEFAULT 0,
                  messages_used_large INTEGER DEFAULT 0)''')
    # Upgrade requests table
    c.execute('''CREATE TABLE IF NOT EXISTS upgrade_requests
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  requested_plan TEXT,
                  reason TEXT,
                  status TEXT DEFAULT 'pending',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Usage tracking
    c.execute('''CREATE TABLE IF NOT EXISTS usage_logs
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  tokens_used INTEGER,
                  model_used TEXT,
                  timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Create admin account if not exists.
    # The password can be overridden through SAM_ADMIN_PASSWORD; the legacy
    # default is kept so existing deployments behave the same.
    admin_password = os.environ.get("SAM_ADMIN_PASSWORD", "admin123")
    # Hashed the same way as hash_password() so the admin can log in normally.
    admin_pass = hashlib.sha256(admin_password.encode()).hexdigest()
    try:
        c.execute(
            "INSERT INTO users (username, password_hash, email, plan, is_admin) VALUES (?, ?, ?, ?, ?)",
            ("admin", admin_pass, "[email protected]", "pro", 1),
        )
        conn.commit()
        # Deliberately does not echo the password: stdout usually ends up in logs.
        print("✅ Admin account created (username: admin)")
        if admin_password == "admin123":
            print("⚠️ Admin account uses the default password; set SAM_ADMIN_PASSWORD to change it")
    except sqlite3.IntegrityError:
        # UNIQUE(username) rejected the insert: the admin row is already there.
        print("✅ Admin account already exists")
    conn.commit()
    return conn
# Global database connection
# Opened once at import time and shared by the helper functions below;
# each helper takes db_lock before using it.
db_conn = init_database()
db_lock = threading.Lock()
# Per-plan quotas, counted inside a rolling window of `reset_hours` hours.
# A message limit of -1 is treated as "unlimited" by the quota checks.
PLAN_LIMITS = {
    'free': dict(
        nano_messages=-1,
        mini_messages=-1,
        fast_messages=10,
        large_messages=8,
        can_choose_model=False,
        max_tokens=256,
        reset_hours=3,
    ),
    'plus': dict(
        nano_messages=-1,
        mini_messages=-1,
        fast_messages=-1,
        large_messages=20,
        can_choose_model=True,
        max_tokens=384,
        reset_hours=3,
    ),
    'pro': dict(
        nano_messages=-1,
        mini_messages=-1,
        fast_messages=-1,
        large_messages=-1,
        can_choose_model=True,
        max_tokens=512,
        reset_hours=3,
    ),
}
def get_model_type(model_name):
    """Map a model display name to its tier: 'nano', 'mini', 'fast' or 'large'.

    The match is a case-sensitive substring test, tried in the order
    Nano, Mini, Fast, Large. Names matching none of them count as 'nano'.
    """
    tier_markers = (
        ('Nano', 'nano'),
        ('Mini', 'mini'),
        ('Fast', 'fast'),
        ('Large', 'large'),
    )
    for marker, tier in tier_markers:
        if marker in model_name:
            return tier
    return 'nano'
# ==============================================================================
# User Management Functions
# ==============================================================================
def hash_password(password):
    """Return the hex SHA-256 digest of *password* (encoded with str.encode())."""
    # NOTE(review): unsalted SHA-256 is weak for password storage. Changing the
    # scheme needs a migration of the hashes already stored in the users table.
    digest = hashlib.sha256()
    digest.update(password.encode())
    return digest.hexdigest()
def create_user(username, password, email=""):
    """Insert a new user row and start its rate-limit window now.

    Returns:
        ``(True, message)`` on success, or ``(False, message)`` when the
        insert hits an integrity error (the username is already taken).
    """
    insert_sql = "INSERT INTO users (username, password_hash, email, rate_limit_start) VALUES (?, ?, ?, ?)"
    with db_lock:
        try:
            cursor = db_conn.cursor()
            window_start = datetime.now(AUSTRALIA_TZ).isoformat()
            cursor.execute(insert_sql, (username, hash_password(password), email, window_start))
            db_conn.commit()
        except sqlite3.IntegrityError:
            return False, "Username already exists!"
        return True, "Account created successfully!"
def authenticate_user(username, password):
    """Check a username/password pair against the users table.

    Returns:
        ``(True, user)`` where ``user`` is a dict with ``id``, ``username``,
        ``plan`` and ``is_admin`` when the credentials match, otherwise
        ``(False, None)``.
    """
    import hmac  # local import: only this function needs it

    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT id, password_hash, plan, is_admin FROM users WHERE username = ?", (username,))
        row = c.fetchone()
    if not row:
        return False, None
    user_id, stored_hash, plan, is_admin = row
    # Constant-time comparison so response timing does not reveal how much of
    # the stored hash matched.
    if not hmac.compare_digest(str(stored_hash).encode(), hash_password(password).encode()):
        return False, None
    return True, {"id": user_id, "username": username, "plan": plan, "is_admin": bool(is_admin)}
def check_and_reset_limits(user_id):
    """Zero the user's message counters once their rate-limit window has expired.

    Does nothing when the user does not exist, has no recorded window start,
    or the window (``reset_hours`` of their plan) has not elapsed yet.
    On expiry a new window starts at the current time.
    """
    reset_sql = (
        "UPDATE users\n"
        "SET rate_limit_start = ?,\n"
        "messages_used_nano = 0,\n"
        "messages_used_mini = 0,\n"
        "messages_used_fast = 0,\n"
        "messages_used_large = 0\n"
        "WHERE id = ?"
    )
    with db_lock:
        cursor = db_conn.cursor()
        cursor.execute("SELECT rate_limit_start, plan FROM users WHERE id = ?", (user_id,))
        row = cursor.fetchone()
        if not row:
            return
        window_started_at, plan = row
        window_hours = PLAN_LIMITS[plan]['reset_hours']
        if not window_started_at:
            return
        window_start = datetime.fromisoformat(window_started_at)
        current_time = datetime.now(AUSTRALIA_TZ)
        if current_time - window_start < timedelta(hours=window_hours):
            return
        # Window elapsed: open a fresh one and clear every per-model counter.
        cursor.execute(reset_sql, (current_time.isoformat(), user_id))
        db_conn.commit()
def get_user_limits_info(user_id):
    """Return the user's plan, per-model usage, limits and time until reset.

    Expired windows are reset first via ``check_and_reset_limits``.

    Returns:
        ``None`` when the user does not exist, otherwise a dict with the keys
        ``plan``, ``<tier>_used`` and ``<tier>_limit`` for nano/mini/fast/large,
        ``can_choose_model``, ``max_tokens`` and ``reset_in`` (``"<h>h <m>m"``,
        or ``"N/A"`` when no window start is recorded).
    """
    check_and_reset_limits(user_id)
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT plan, rate_limit_start,
                            messages_used_nano, messages_used_mini,
                            messages_used_fast, messages_used_large
                     FROM users WHERE id = ?""", (user_id,))
        result = c.fetchone()
    if not result:
        return None
    plan, rate_limit_start_str, nano_used, mini_used, fast_used, large_used = result
    limits = PLAN_LIMITS[plan]
    if rate_limit_start_str:
        rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
        reset_time = rate_limit_start + timedelta(hours=limits['reset_hours'])
        now = datetime.now(AUSTRALIA_TZ)
        # The window can lapse between the reset check above and this read;
        # clamp so the countdown never renders as a negative duration.
        remaining_seconds = max(0, int((reset_time - now).total_seconds()))
        hours, remainder = divmod(remaining_seconds, 3600)
        minutes = remainder // 60
        reset_str = f"{hours}h {minutes}m"
    else:
        reset_str = "N/A"
    return {
        'plan': plan,
        'nano_used': nano_used,
        'mini_used': mini_used,
        'fast_used': fast_used,
        'large_used': large_used,
        'nano_limit': limits['nano_messages'],
        'mini_limit': limits['mini_messages'],
        'fast_limit': limits['fast_messages'],
        'large_limit': limits['large_messages'],
        'can_choose_model': limits['can_choose_model'],
        'max_tokens': limits['max_tokens'],
        'reset_in': reset_str
    }
def can_use_model(user_id, model_name):
    """Tell whether the user still has quota for the tier of *model_name*.

    Returns:
        ``(True, "OK")`` when the tier is unlimited (-1) or under its limit,
        otherwise ``(False, reason)``.
    """
    info = get_user_limits_info(user_id)
    if not info:
        return False, "User not found"
    tier = get_model_type(model_name)
    used = info[f"{tier}_used"]
    limit = info[f"{tier}_limit"]
    if limit != -1 and used >= limit:
        return False, f"Limit reached for {tier.upper()} model ({used}/{limit}). Resets in {info['reset_in']}"
    return True, "OK"
def increment_model_usage(user_id, model_name):
    """Add one to the user's message counter for the tier of *model_name*."""
    # The column name comes from get_model_type(), not from caller-supplied
    # text, so interpolating it into the statement is safe here.
    column = f"messages_used_{get_model_type(model_name)}"
    statement = f"UPDATE users SET {column} = {column} + 1 WHERE id = ?"
    with db_lock:
        cursor = db_conn.cursor()
        cursor.execute(statement, (user_id,))
        db_conn.commit()
def get_available_models_for_user(user_id):
    """List the loaded models the user may use right now.

    For each tier (nano, mini, fast, large) with remaining or unlimited quota,
    the first loaded model of that tier is included. Returns an empty list
    when the user is unknown.
    """
    info = get_user_limits_info(user_id)
    if not info:
        return []
    usable = []
    for tier in ('nano', 'mini', 'fast', 'large'):
        used = info[f'{tier}_used']
        limit = info[f'{tier}_limit']
        if limit != -1 and used >= limit:
            continue
        # Only the first loaded model of each tier is offered.
        first_match = next(
            (name for name in available_models if get_model_type(name) == tier),
            None,
        )
        if first_match is not None:
            usable.append(first_match)
    return usable
def log_usage(user_id, tokens, model):
    """Append one row to usage_logs recording tokens used with a model."""
    record = (user_id, tokens, model)
    with db_lock:
        cursor = db_conn.cursor()
        cursor.execute(
            "INSERT INTO usage_logs (user_id, tokens_used, model_used) VALUES (?, ?, ?)",
            record,
        )
        db_conn.commit()
def request_upgrade(user_id, plan, reason):
    """Store a plan-upgrade request for later review.

    Returns:
        ``(True, message)`` when stored, ``(False, "Error: ...")`` when the
        insert raises.
    """
    insert_sql = "INSERT INTO upgrade_requests (user_id, requested_plan, reason) VALUES (?, ?, ?)"
    with db_lock:
        try:
            cursor = db_conn.cursor()
            cursor.execute(insert_sql, (user_id, plan, reason))
            db_conn.commit()
        except Exception as e:
            return False, f"Error: {str(e)}"
        return True, "Upgrade request submitted! Admin will review soon."
def get_all_users():
    """Return every user row (account fields plus usage counters), newest first."""
    query = (
        "SELECT id, username, email, plan, created_at, is_admin,\n"
        "messages_used_nano, messages_used_mini,\n"
        "messages_used_fast, messages_used_large,\n"
        "rate_limit_start\n"
        "FROM users ORDER BY created_at DESC"
    )
    with db_lock:
        cursor = db_conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
    return rows
def get_pending_requests():
    """Return pending upgrade requests joined with the requester's username, newest first."""
    query = (
        "SELECT r.id, u.username, r.requested_plan, r.reason, r.created_at\n"
        "FROM upgrade_requests r\n"
        "JOIN users u ON r.user_id = u.id\n"
        "WHERE r.status = 'pending'\n"
        "ORDER BY r.created_at DESC"
    )
    with db_lock:
        cursor = db_conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
    return rows
def update_user_plan(username, new_plan):
    """Move a user to *new_plan*, restart their window and clear their counters.

    Args:
        username: Account to change.
        new_plan: Must be a key of ``PLAN_LIMITS``. Other values are rejected
            because the limit checks index ``PLAN_LIMITS`` by the stored plan.

    Returns:
        ``(True, message)`` when a row was updated, otherwise
        ``(False, message)``: unknown plan, unknown user, or a database error.
    """
    if new_plan not in PLAN_LIMITS:
        return False, f"Error: unknown plan '{new_plan}'"
    with db_lock:
        try:
            c = db_conn.cursor()
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("""UPDATE users
                         SET plan = ?,
                             rate_limit_start = ?,
                             messages_used_nano = 0,
                             messages_used_mini = 0,
                             messages_used_fast = 0,
                             messages_used_large = 0
                         WHERE username = ?""", (new_plan, now, username))
            # rowcount 0 means no such username; do not report a phantom upgrade.
            if c.rowcount == 0:
                return False, f"User {username} not found"
            db_conn.commit()
            return True, f"User {username} upgraded to {new_plan}!"
        except Exception as e:
            # Drop any partial work so it is not committed by a later caller.
            db_conn.rollback()
            return False, f"Error: {str(e)}"
def approve_request(request_id):
    """Approve an upgrade request and apply the requested plan to its user.

    The user's plan is changed, their window restarted and counters cleared,
    and the request is marked ``approved``; both updates are committed together.

    Returns:
        ``(True, "Request approved!")`` on success, otherwise
        ``(False, message)``: request not found, unknown plan, or a database
        error.
    """
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("SELECT user_id, requested_plan FROM upgrade_requests WHERE id = ?", (request_id,))
            result = c.fetchone()
            if not result:
                return False, "Request not found"
            user_id, plan = result
            # The limit checks index PLAN_LIMITS by the stored plan, so an
            # unknown plan must never be written to the users table.
            if plan not in PLAN_LIMITS:
                return False, f"Error: unknown plan '{plan}'"
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("""UPDATE users
                         SET plan = ?,
                             rate_limit_start = ?,
                             messages_used_nano = 0,
                             messages_used_mini = 0,
                             messages_used_fast = 0,
                             messages_used_large = 0
                         WHERE id = ?""", (plan, now, user_id))
            c.execute("UPDATE upgrade_requests SET status = 'approved' WHERE id = ?", (request_id,))
            db_conn.commit()
            return True, "Request approved!"
        except Exception as e:
            # Undo a half-applied approval; otherwise the pending UPDATE would
            # be committed by the next function that commits on this connection.
            db_conn.rollback()
            return False, f"Error: {str(e)}"
def deny_request(request_id):
    """Mark an upgrade request as denied.

    Returns:
        ``(True, "Request denied")`` when a row was updated,
        ``(False, "Request not found")`` when no request has that id, or
        ``(False, "Error: ...")`` on a database error.
    """
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("UPDATE upgrade_requests SET status = 'denied' WHERE id = ?", (request_id,))
            # Same contract as approve_request: an unknown id is not a success.
            if c.rowcount == 0:
                return False, "Request not found"
            db_conn.commit()
            return True, "Request denied"
        except Exception as e:
            db_conn.rollback()
            return False, f"Error: {str(e)}"
# ==============================================================================
# Model Architecture
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotates query/key tensors with cached cos/sin tables.

    The tables cover ``max_len`` positions and ``dim`` features and are built
    once in ``build``. ``call`` reads the sequence length from axis 2 of ``q``.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        # Guards against recomputing the tables if build() runs again.
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            exponents = tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim
            inv_freq = 1.0 / (self.theta ** exponents)
            positions = tf.range(self.max_len, dtype=tf.float32)
            angles = tf.einsum("i,j->ij", positions, inv_freq)
            # Duplicate the angles so the table spans the full `dim` features.
            table = tf.concat([angles, angles], axis=-1)
            self.cos_cached = tf.constant(tf.cos(table), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(table), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        """Split the last axis in two halves (a, b) and return (-b, a)."""
        first_half, second_half = tf.split(x, 2, axis=-1)
        return tf.concat([-second_half, first_half], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        target_dtype = q.dtype
        # Slice the tables to the current length and add two leading
        # broadcast axes.
        cos = tf.cast(self.cos_cached[:seq_len, :], target_dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], target_dtype)[None, None, :, :]

        def rotate(tensor):
            return (tensor * cos) + (self.rotate_half(tensor) * sin)

        return rotate(q), rotate(k)

    def get_config(self):
        config = super().get_config()
        config.update(dim=self.dim, max_len=self.max_len, theta=self.theta)
        return config
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Scales inputs by the reciprocal root-mean-square of the last axis,
    then by a learned per-feature ``scale`` weight (initialised to ones)."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        self.scale = self.add_weight(name="scale", shape=(feature_dim,), initializer="ones")

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.math.rsqrt(mean_square + self.epsilon)
        return x * inv_rms * self.scale

    def get_config(self):
        config = super().get_config()
        config.update(epsilon=self.epsilon)
        return config
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm transformer block: causal multi-head self-attention with
    rotary embeddings, followed by a SiLU-gated feed-forward layer. Both
    sub-blocks are wrapped in a residual connection with dropout."""

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        # Sublayers: attribute names and creation order are kept as-is because
        # saved weights are matched against them.
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        batch, seq_len = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            per_head = tf.reshape(projected, [batch, seq_len, self.n_heads, self.head_dim])
            return tf.transpose(per_head, [0, 2, 1, 3])

        # --- self-attention sub-block ---
        normed = self.pre_attn_norm(x)
        q = split_heads(self.q_proj(normed))
        k = split_heads(self.k_proj(normed))
        v = split_heads(self.v_proj(normed))
        q, k = self.rope(q, k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Lower-triangular ones mark allowed positions; everything above the
        # diagonal gets a large negative bias before the softmax.
        allowed = tf.linalg.band_part(tf.ones([seq_len, seq_len], dtype=dtype), -1, 0)
        mask = tf.where(allowed == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores += mask
        weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(weights, v)
        context = tf.reshape(tf.transpose(context, [0, 2, 1, 3]), [batch, seq_len, self.d_model])
        x = x + self.dropout(self.out_proj(context), training=training)

        # --- feed-forward sub-block ---
        normed = self.pre_ffn_norm(x)
        gated = keras.activations.silu(self.gate_proj(normed)) * self.up_proj(normed)
        ffn_out = self.down_proj(gated)
        return x + self.dropout(ffn_out, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx,
        })
        return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model: token embedding, a stack of
    TransformerBlock layers, a final RMSNorm and a bias-free output head.

    The hyper-parameters may be passed as a ``config`` dict, directly as
    keyword arguments (detected by ``vocab_size``), or as a ``cfg`` keyword.
    """

    def __init__(self, **kwargs):
        super().__init__()
        if isinstance(kwargs.get('config'), dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        d_model = self.cfg['d_model']
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], d_model, name="embed_tokens")
        block_args = {
            'd_model': d_model,
            'n_heads': self.cfg['n_heads'],
            'ff_dim': int(d_model * self.cfg['ff_mult']),
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }
        # Assigned empty and then appended to, exactly as before, so the
        # blocks are tracked by Keras in the same way.
        self.blocks = []
        for layer_idx in range(self.cfg['n_layers']):
            self.blocks.append(TransformerBlock(name=f"block_{layer_idx}", layer_idx=layer_idx, **block_args))
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        hidden = self.embed(input_ids)
        for block in self.blocks:
            hidden = block(hidden, training=training)
        return self.lm_head(self.norm(hidden))

    def get_config(self):
        config = super().get_config()
        config['config'] = self.cfg
        return config
def count_parameters(model):
    """Return ``(total, non_zero)`` element counts over ``model.weights``.

    Each weight is read through its ``numpy()`` method, one at a time.
    """
    total = 0
    non_zero = 0
    for variable in model.weights:
        values = variable.numpy()
        total, non_zero = total + values.size, non_zero + np.count_nonzero(values)
    return total, non_zero
def format_param_count(count):
    """Format a count with a B/M/K suffix and two decimals; values below
    1000 are returned as ``str(count)``."""
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if count >= threshold:
            return f"{count/threshold:.2f}{suffix}"
    return str(count)
class ModelBackend(ABC):
    """Abstract interface every model backend has to implement."""

    @abstractmethod
    def predict(self, input_ids):
        """Run the model on *input_ids*; implemented by subclasses."""

    @abstractmethod
    def get_name(self):
        """Return the backend's name; implemented by subclasses."""

    @abstractmethod
    def get_info(self):
        """Return a description of the backend; implemented by subclasses."""
class KerasBackend(ModelBackend):
    """Backend wrapping a Keras model behind an XLA-compiled ``tf.function``.

    Construction runs one warm-up call and records parameter statistics.
    """

    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name

        # Fixed batch of 1, variable sequence length, inference mode.
        @tf.function(input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)], jit_compile=True)
        def fast_predict(inputs):
            return model(inputs, training=False)

        self.fast_predict = fast_predict

        print(f" 🔥 Warming up {display_name}...")
        warmup_ids = tf.constant([[1, 2, 3]], dtype=tf.int32)
        self.fast_predict(warmup_ids)
        print(f" ✅ Compilation complete!")

        self.total_params, self.non_zero_params = count_parameters(model)
        if self.total_params > 0:
            self.sparsity = (1 - self.non_zero_params / self.total_params) * 100
        else:
            self.sparsity = 0
        cfg = model.cfg
        self.n_heads = cfg.get('n_heads', 0)
        self.ff_dim = int(cfg.get('d_model', 0) * cfg.get('ff_mult', 0))

    def predict(self, input_ids):
        """Return the logits of the last position as a NumPy array."""
        batch = tf.constant([input_ids], dtype=tf.int32)
        return self.fast_predict(batch)[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        """Return a multi-line summary; sparsity is listed only above 1%."""
        lines = [
            f"{self.display_name}\n",
            f" Total params: {format_param_count(self.total_params)}\n",
            f" Attention heads: {self.n_heads}\n",
            f" FFN dimension: {self.ff_dim}\n",
        ]
        if self.sparsity > 1:
            lines.append(f" Sparsity: {self.sparsity:.1f}%\n")
        return "".join(lines)
# Models to load at startup. Each entry is
# (display name, Hub repo id, weights filename, config filename).
# A config filename of None makes the loader use the base model config.
MODEL_REGISTRY = [
("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
("SAM-X-1-Mini 🚀 (ADVANCED!)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini_finetuned.weights.h5", "sam1_mini_finetuned_config.json"),
("SAM-X-1-Nano ⚡⚡", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano_finetuned.weights.h5", "sam1_nano_finetuned_config.json"),
]
def estimate_prompt_complexity(prompt):
    """Score a prompt with additive heuristics; higher means more complex.

    Points come from word count (up to 3), each "hard" keyword found (2),
    each "medium" keyword found (1), any coding term (2), any sequencing
    word (1) and more than one question mark (1). Keywords are matched as
    lower-cased substrings.
    """
    hard_keywords = (
        'analyze', 'explain', 'compare', 'evaluate', 'prove', 'derive',
        'calculate', 'solve', 'reason', 'why', 'how does', 'complex',
        'algorithm', 'mathematics', 'philosophy', 'theory', 'logic',
        'detailed', 'comprehensive', 'thorough', 'in-depth',
    )
    medium_keywords = (
        'write', 'create', 'generate', 'summarize', 'describe', 'list',
        'what is', 'tell me', 'explain briefly',
    )
    coding_terms = ('code', 'function', 'program', 'debug', 'implement')
    sequencing_terms = ('first', 'then', 'next', 'finally', 'step')

    text = prompt.lower()
    n_words = len(prompt.split())

    if n_words > 100:
        score = 3
    elif n_words > 50:
        score = 2
    elif n_words > 20:
        score = 1
    else:
        score = 0

    score += 2 * sum(1 for keyword in hard_keywords if keyword in text)
    score += sum(1 for keyword in medium_keywords if keyword in text)
    if any(term in text for term in coding_terms):
        score += 2
    if any(term in text for term in sequencing_terms):
        score += 1
    if prompt.count('?') > 1:
        score += 1
    return score
def select_model_auto(prompt, available_models_dict, user_available_models):
    """Pick a backend for *prompt* from the models the user may access.

    The prompt's complexity score selects a preference order over the four
    known model names; the first accessible one wins. If none of the known
    names is accessible, the first accessible backend is returned. Returns
    ``None`` when the user can access no model at all.
    """
    complexity = estimate_prompt_complexity(prompt)
    accessible = {
        name: backend
        for name, backend in available_models_dict.items()
        if name in user_available_models
    }
    if not accessible:
        return None

    nano = "SAM-X-1-Nano ⚡⚡"
    mini = "SAM-X-1-Mini 🚀 (ADVANCED!)"
    fast = "SAM-X-1-Fast ⚡ (BETA)"
    large = "SAM-X-1-Large"

    # Preferred model first, then its fallbacks in order.
    if complexity <= 2:
        ranking = (nano, mini, fast, large)
    elif complexity <= 5:
        ranking = (mini, nano, fast, large)
    elif complexity <= 8:
        ranking = (fast, mini, large, nano)
    else:
        ranking = (large, fast, mini, nano)

    for candidate in ranking:
        if candidate in accessible:
            return accessible[candidate]
    return next(iter(accessible.values()))
# Repo that provides the base config.json and tokenizer.json.
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"

print("="*80)
print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
print("="*80)
print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
with open(config_path, 'r') as f:
    base_config = json.load(f)
print(f"✅ Base config loaded")

# Translate the Hub config's field names into the keys SAM1Model reads.
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers'],
}

print("\n🔤 Recreating tokenizer...")
# The tokenizer is rebuilt from the "gpt2" vocabulary; the downloaded
# tokenizer.json path is not read in this section.
tokenizer = Tokenizer.from_pretrained("gpt2")
eos_token = "<|endoftext|>"
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)

# Register the thinking markers as special tokens when they are missing.
custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])

tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
print(f" EOS token: '{eos_token}' (ID: {eos_token_id})")
if eos_token_id is None:
    raise ValueError("❌ Failed to set EOS token ID!")

print("\n" + "="*80)
print("📦 LOADING MODELS".center(80))
print("="*80)
# display name -> KerasBackend, for every registry entry that loaded.
available_models = {}
# One-token input used to build each model before its weights are loaded.
dummy_input = tf.zeros((1, 1), dtype=tf.int32)

for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\n⏳ Loading: {display_name}")
        print(f" Repo: {repo_id}")
        print(f" Weights: {weights_filename}")
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)

        if not config_filename:
            # No per-model config: fall back to the base architecture.
            model_config = base_model_config.copy()
        else:
            print(f" Config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config = json.load(f)
            print(f" 📐 Custom architecture: {model_config['n_heads']} heads")

        model = SAM1Model(**model_config)
        model(dummy_input)  # creates the variables so load_weights has targets
        model.load_weights(weights_path)
        model.trainable = False

        backend = KerasBackend(model, display_name, display_name)
        available_models[display_name] = backend
        print(f" ✅ Loaded successfully!")
        print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
    except Exception as e:
        # Best effort: a model that fails to load is skipped, not fatal.
        print(f" ⚠️ Failed to load: {e}")

if not available_models:
    raise RuntimeError("❌ No models loaded!")
print(f"\n✅ Successfully loaded {len(available_models)} model(s)")

# Default backend is the first one that loaded.
current_backend = next(iter(available_models.values()))
# Set from elsewhere to ask the generator to stop early.
stop_generation = threading.Event()
def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=256):
    """Generate a reply token by token and stream it as text chunks.

    Args:
        prompt: Text to continue. EOS ids are stripped from its encoding.
        temperature: Values above 0 sample from the 5 most likely tokens
            after dividing the logits by it; 0 or below picks the arg-max.
        backend: Model backend to use; defaults to ``current_backend``.
        max_tokens: Upper bound on the number of generated tokens.

    Yields:
        Tuples ``(chunk, in_thinking, tokens_per_sec, tokens_per_sec, stopped)``.
        Text chunks carry ``stopped=False``. When ``stop_generation`` is set
        the generator yields ``("", False, -1, speed, True)`` and ends.
        Otherwise the last item is ``("", False, speed, speed, False)``.
    """
    stop_generation.clear()
    if backend is None:
        backend = current_backend

    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    prompt_len = len(input_ids)
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False
    max_len = backend.model.cfg['max_len']
    start_time = time.time()
    tokens_generated = 0
    decode_buffer = []
    decode_every = 2
    last_speed_check = start_time

    for step in range(max_tokens):
        if stop_generation.is_set():
            elapsed = time.time() - start_time
            final_speed = tokens_generated / elapsed if elapsed > 0 else 0
            yield "", False, -1, final_speed, True
            return

        # Feed at most the last max_len tokens to the model.
        next_token_logits = backend.predict(generated[-max_len:])

        # Every 10 tokens, adapt how many tokens are batched per decode:
        # faster generation decodes less often.
        if tokens_generated > 5 and tokens_generated % 10 == 0:
            current_time = time.time()
            elapsed_since_check = current_time - last_speed_check
            if elapsed_since_check > 0:
                recent_speed = 10 / elapsed_since_check
                if recent_speed > 25:
                    decode_every = 8
                elif recent_speed > 15:
                    decode_every = 5
                elif recent_speed > 8:
                    decode_every = 3
                else:
                    decode_every = 2
            last_speed_check = current_time

        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            # Never ask for more candidates than the vocabulary holds.
            top_k = min(5, next_token_logits.shape[-1])
            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
            top_k_logits = next_token_logits[top_k_indices]
            # Softmax over the candidates, shifted by the max for stability.
            exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
            probs = exp_logits / np.sum(exp_logits)
            next_token = top_k_indices[np.random.choice(top_k, p=probs)]
        else:
            next_token = np.argmax(next_token_logits)

        hit_eos = next_token == eos_token_id
        if not hit_eos:
            generated.append(int(next_token))
            decode_buffer.append(int(next_token))
            tokens_generated += 1

        # Flush on EOS too: previously tokens still waiting in decode_buffer
        # were dropped when EOS arrived, cutting off the end of the reply.
        flush_due = (
            hit_eos
            or len(decode_buffer) >= decode_every
            or step == max_tokens - 1
        )
        if decode_buffer and flush_due:
            new_text = tokenizer.decode(generated[prompt_len:])
            if len(new_text) > len(current_text):
                new_chunk = new_text[len(current_text):]
                current_text = new_text
                if "<think>" in new_chunk:
                    in_thinking = True
                elif "</think>" in new_chunk or "<think/>" in new_chunk:
                    in_thinking = False
                elapsed = time.time() - start_time
                tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
                yield new_chunk, in_thinking, tokens_per_sec, tokens_per_sec, False
            decode_buffer = []

        if hit_eos:
            break

    elapsed = time.time() - start_time
    final_tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
    yield "", False, final_tokens_per_sec, final_tokens_per_sec, False