import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
import time
import threading
import hashlib
import sqlite3
from datetime import datetime, timedelta
import pytz
# ==============================================================================
# Performance Optimizations for CPU
# ==============================================================================
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(2)
tf.config.optimizer.set_jit(True)
tf.config.run_functions_eagerly(False)
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
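# Note: the two GPU env vars above are no-ops on a CPU-only Space; they are
# presumably kept so the same code runs unchanged if a GPU is ever attached.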
# Australian timezone
AUSTRALIA_TZ = pytz.timezone('Australia/Sydney')
# ==============================================================================
# Database Setup
# ==============================================================================
def init_database():
    """Initialize SQLite database for users and subscriptions."""
    conn = sqlite3.connect('sam_users.db', check_same_thread=False)
    c = conn.cursor()
    # Users table
    c.execute('''CREATE TABLE IF NOT EXISTS users
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  username TEXT UNIQUE NOT NULL,
                  password_hash TEXT NOT NULL,
                  email TEXT,
                  plan TEXT DEFAULT 'free',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  is_admin BOOLEAN DEFAULT 0,
                  rate_limit_start TIMESTAMP,
                  messages_used_nano INTEGER DEFAULT 0,
                  messages_used_mini INTEGER DEFAULT 0,
                  messages_used_fast INTEGER DEFAULT 0,
                  messages_used_large INTEGER DEFAULT 0)''')
    # Upgrade requests table
    c.execute('''CREATE TABLE IF NOT EXISTS upgrade_requests
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  requested_plan TEXT,
                  reason TEXT,
                  status TEXT DEFAULT 'pending',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Usage tracking
    c.execute('''CREATE TABLE IF NOT EXISTS usage_logs
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  tokens_used INTEGER,
                  model_used TEXT,
                  timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Create admin account if it does not exist
    admin_pass = hashlib.sha256("admin123".encode()).hexdigest()
    try:
        c.execute("INSERT INTO users (username, password_hash, email, plan, is_admin) VALUES (?, ?, ?, ?, ?)",
                  ("admin", admin_pass, "[email protected]", "pro", 1))
        conn.commit()
        print("✅ Admin account created (username: admin, password: admin123)")
    except sqlite3.IntegrityError:
        print("✅ Admin account already exists")
    conn.commit()
    return conn
# Global database connection
db_conn = init_database()
db_lock = threading.Lock()
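# A single connection is shared across all request threads, which is why it is
# opened with check_same_thread=False and every access below is serialized
# through db_lock.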
# Plan limits with 3-hour rolling window
PLAN_LIMITS = {
    'free': {
        'nano_messages': -1,
        'mini_messages': -1,
        'fast_messages': 10,
        'large_messages': 8,
        'can_choose_model': False,
        'max_tokens': 256,
        'reset_hours': 3
    },
    'plus': {
        'nano_messages': -1,
        'mini_messages': -1,
        'fast_messages': -1,
        'large_messages': 20,
        'can_choose_model': True,
        'max_tokens': 384,
        'reset_hours': 3
    },
    'pro': {
        'nano_messages': -1,
        'mini_messages': -1,
        'fast_messages': -1,
        'large_messages': -1,
        'can_choose_model': True,
        'max_tokens': 512,
        'reset_hours': 3
    }
}
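# A value of -1 means "unlimited". Example: a 'free' user gets unlimited
# Nano/Mini messages but only 10 Fast and 8 Large messages per rolling 3-hour
# window, with replies capped at 256 new tokens.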
def get_model_type(model_name):
    """Get model type from model name."""
    if 'Nano' in model_name:
        return 'nano'
    elif 'Mini' in model_name:
        return 'mini'
    elif 'Fast' in model_name:
        return 'fast'
    elif 'Large' in model_name:
        return 'large'
    return 'nano'
# ==============================================================================
# User Management Functions
# ==============================================================================
def hash_password(password):
    return hashlib.sha256(password.encode()).hexdigest()
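# Note: unsalted SHA-256 is fine for a demo Space but is not a password-grade
# hash; a real deployment would use a salted KDF (e.g. bcrypt or argon2).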
def create_user(username, password, email=""):
    with db_lock:
        try:
            c = db_conn.cursor()
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("INSERT INTO users (username, password_hash, email, rate_limit_start) VALUES (?, ?, ?, ?)",
                      (username, hash_password(password), email, now))
            db_conn.commit()
            return True, "Account created successfully!"
        except sqlite3.IntegrityError:
            return False, "Username already exists!"
def authenticate_user(username, password):
    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT id, password_hash, plan, is_admin FROM users WHERE username = ?", (username,))
        result = c.fetchone()
        if result and result[1] == hash_password(password):
            return True, {"id": result[0], "username": username, "plan": result[2], "is_admin": bool(result[3])}
        return False, None
def check_and_reset_limits(user_id):
    """Check if 3-hour window has passed and reset limits if needed."""
    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT rate_limit_start, plan FROM users WHERE id = ?", (user_id,))
        result = c.fetchone()
        if not result:
            return
        rate_limit_start_str, plan = result
        reset_hours = PLAN_LIMITS[plan]['reset_hours']
        if rate_limit_start_str:
            rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
            now = datetime.now(AUSTRALIA_TZ)
            if now - rate_limit_start >= timedelta(hours=reset_hours):
                new_start = now.isoformat()
                c.execute("""UPDATE users
                             SET rate_limit_start = ?,
                                 messages_used_nano = 0,
                                 messages_used_mini = 0,
                                 messages_used_fast = 0,
                                 messages_used_large = 0
                             WHERE id = ?""", (new_start, user_id))
                db_conn.commit()
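# Rolling-window semantics: counters only reset once the full window has
# elapsed. A window opened at 09:00 resets at 12:00; a message sent at 11:59
# still counts against the 09:00 window.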
def get_user_limits_info(user_id):
    """Get user's current usage and limits with reset time."""
    check_and_reset_limits(user_id)
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT plan, rate_limit_start,
                            messages_used_nano, messages_used_mini,
                            messages_used_fast, messages_used_large
                     FROM users WHERE id = ?""", (user_id,))
        result = c.fetchone()
        if not result:
            return None
        plan, rate_limit_start_str, nano_used, mini_used, fast_used, large_used = result
        limits = PLAN_LIMITS[plan]
        if rate_limit_start_str:
            rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
            reset_time = rate_limit_start + timedelta(hours=limits['reset_hours'])
            now = datetime.now(AUSTRALIA_TZ)
            time_until_reset = reset_time - now
            hours, remainder = divmod(int(time_until_reset.total_seconds()), 3600)
            minutes, seconds = divmod(remainder, 60)
            reset_str = f"{hours}h {minutes}m"
        else:
            reset_str = "N/A"
        return {
            'plan': plan,
            'nano_used': nano_used,
            'mini_used': mini_used,
            'fast_used': fast_used,
            'large_used': large_used,
            'nano_limit': limits['nano_messages'],
            'mini_limit': limits['mini_messages'],
            'fast_limit': limits['fast_messages'],
            'large_limit': limits['large_messages'],
            'can_choose_model': limits['can_choose_model'],
            'max_tokens': limits['max_tokens'],
            'reset_in': reset_str
        }
def can_use_model(user_id, model_name):
    """Check if user can use a specific model."""
    info = get_user_limits_info(user_id)
    if not info:
        return False, "User not found"
    model_type = get_model_type(model_name)
    used_key = f"{model_type}_used"
    limit_key = f"{model_type}_limit"
    used = info[used_key]
    limit = info[limit_key]
    if limit == -1:
        return True, "OK"
    if used >= limit:
        return False, f"Limit reached for {model_type.upper()} model ({used}/{limit}). Resets in {info['reset_in']}"
    return True, "OK"
def increment_model_usage(user_id, model_name):
    """Increment usage counter for a model."""
    model_type = get_model_type(model_name)
    column = f"messages_used_{model_type}"
    with db_lock:
        c = db_conn.cursor()
        c.execute(f"UPDATE users SET {column} = {column} + 1 WHERE id = ?", (user_id,))
        db_conn.commit()
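# The f-string interpolation above is safe: `column` can only be one of the
# four fixed messages_used_* names produced by get_model_type(), never raw
# user input.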
def get_available_models_for_user(user_id):
    """Get list of models user can currently use."""
    info = get_user_limits_info(user_id)
    if not info:
        return []
    available = []
    for model_type in ['nano', 'mini', 'fast', 'large']:
        used = info[f'{model_type}_used']
        limit = info[f'{model_type}_limit']
        if limit == -1 or used < limit:
            for model_name in available_models.keys():
                if get_model_type(model_name) == model_type:
                    available.append(model_name)
                    break
    return available
def log_usage(user_id, tokens, model):
    with db_lock:
        c = db_conn.cursor()
        c.execute("INSERT INTO usage_logs (user_id, tokens_used, model_used) VALUES (?, ?, ?)",
                  (user_id, tokens, model))
        db_conn.commit()
def request_upgrade(user_id, plan, reason):
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("INSERT INTO upgrade_requests (user_id, requested_plan, reason) VALUES (?, ?, ?)",
                      (user_id, plan, reason))
            db_conn.commit()
            return True, "Upgrade request submitted! Admin will review soon."
        except Exception as e:
            return False, f"Error: {str(e)}"
def get_all_users():
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT id, username, email, plan, created_at, is_admin,
                            messages_used_nano, messages_used_mini,
                            messages_used_fast, messages_used_large,
                            rate_limit_start
                     FROM users ORDER BY created_at DESC""")
        return c.fetchall()

def get_pending_requests():
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT r.id, u.username, r.requested_plan, r.reason, r.created_at
                     FROM upgrade_requests r
                     JOIN users u ON r.user_id = u.id
                     WHERE r.status = 'pending'
                     ORDER BY r.created_at DESC""")
        return c.fetchall()
def update_user_plan(username, new_plan):
    with db_lock:
        try:
            c = db_conn.cursor()
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("""UPDATE users
                         SET plan = ?,
                             rate_limit_start = ?,
                             messages_used_nano = 0,
                             messages_used_mini = 0,
                             messages_used_fast = 0,
                             messages_used_large = 0
                         WHERE username = ?""", (new_plan, now, username))
            db_conn.commit()
            return True, f"User {username} upgraded to {new_plan}!"
        except Exception as e:
            return False, f"Error: {str(e)}"

def approve_request(request_id):
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("SELECT user_id, requested_plan FROM upgrade_requests WHERE id = ?", (request_id,))
            result = c.fetchone()
            if result:
                user_id, plan = result
                now = datetime.now(AUSTRALIA_TZ).isoformat()
                c.execute("""UPDATE users
                             SET plan = ?,
                                 rate_limit_start = ?,
                                 messages_used_nano = 0,
                                 messages_used_mini = 0,
                                 messages_used_fast = 0,
                                 messages_used_large = 0
                             WHERE id = ?""", (plan, now, user_id))
                c.execute("UPDATE upgrade_requests SET status = 'approved' WHERE id = ?", (request_id,))
                db_conn.commit()
                return True, "Request approved!"
            return False, "Request not found"
        except Exception as e:
            return False, f"Error: {str(e)}"

def deny_request(request_id):
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("UPDATE upgrade_requests SET status = 'denied' WHERE id = ?", (request_id,))
            db_conn.commit()
            return True, "Request denied"
        except Exception as e:
            return False, f"Error: {str(e)}"
# ==============================================================================
# Model Architecture
# ==============================================================================
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
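# RoPE background: for each paired coordinate (x1, x2) at position t with
# frequency f, the rotate-half form above computes the 2-D rotation
#   (x1*cos(t*f) - x2*sin(t*f), x2*cos(t*f) + x1*sin(t*f)),
# which injects relative position into the q·k dot products without any
# learned positional embedding.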
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
                        tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
                       "dropout": self.dropout_rate, "max_len": self.max_len,
                       "rope_theta": self.rope_theta, "layer_idx": self.layer_idx})
        return config
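# The block follows the Llama-style recipe: pre-RMSNorm before both the
# attention and FFN sub-layers, rotary position embeddings on q/k, a causal
# additive mask, and a SwiGLU feed-forward
# (down_proj(silu(gate_proj(x)) * up_proj(x))).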
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        # Accept a nested {'config': {...}} dict (as produced by get_config),
        # flat hyperparameter kwargs, or a 'cfg' key.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }
        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
def count_parameters(model):
    total_params = 0
    non_zero_params = 0
    for weight in model.weights:
        w = weight.numpy()
        total_params += w.size
        non_zero_params += np.count_nonzero(w)
    return total_params, non_zero_params

def format_param_count(count):
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    else:
        return str(count)
class ModelBackend(ABC):
    # Marked abstract so incomplete backends fail at construction time
    # (abstractmethod is imported above but was never applied).
    @abstractmethod
    def predict(self, input_ids):
        pass

    @abstractmethod
    def get_name(self):
        pass

    @abstractmethod
    def get_info(self):
        pass
class KerasBackend(ModelBackend):
    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name

        def fast_predict(inputs):
            return model(inputs, training=False)

        # Wrap the forward pass in a tf.function with a length-agnostic input
        # signature, which is presumably what the "warm up"/"compilation"
        # messages below intend; without the wrapper the model runs eagerly
        # and nothing is compiled.
        self.fast_predict = tf.function(
            fast_predict,
            input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)],
        )
        print(f" 🔥 Warming up {display_name}...")
        dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
        _ = self.fast_predict(dummy)
        print(f" ✅ Compilation complete!")
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))

    def predict(self, input_ids):
        inputs = tf.constant([input_ids], dtype=tf.int32)
        logits = self.fast_predict(inputs)
        return logits[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        info = f"{self.display_name}\n"
        info += f" Total params: {format_param_count(self.total_params)}\n"
        info += f" Attention heads: {self.n_heads}\n"
        info += f" FFN dimension: {self.ff_dim}\n"
        if self.sparsity > 1:
            info += f" Sparsity: {self.sparsity:.1f}%\n"
        return info
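# Registry schema: (display name, HF repo id, weights filename, optional
# per-model config filename; None falls back to the base config).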
MODEL_REGISTRY = [
    ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
    ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
    ("SAM-X-1-Mini 🚀 (ADVANCED!)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini_finetuned.weights.h5", "sam1_mini_finetuned_config.json"),
    ("SAM-X-1-Nano ⚡⚡", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano_finetuned.weights.h5", "sam1_nano_finetuned_config.json"),
]
def estimate_prompt_complexity(prompt):
    prompt_lower = prompt.lower()
    complexity_score = 0
    word_count = len(prompt.split())
    if word_count > 100:
        complexity_score += 3
    elif word_count > 50:
        complexity_score += 2
    elif word_count > 20:
        complexity_score += 1
    hard_keywords = ['analyze', 'explain', 'compare', 'evaluate', 'prove', 'derive', 'calculate',
                     'solve', 'reason', 'why', 'how does', 'complex', 'algorithm', 'mathematics',
                     'philosophy', 'theory', 'logic', 'detailed', 'comprehensive', 'thorough', 'in-depth']
    for keyword in hard_keywords:
        if keyword in prompt_lower:
            complexity_score += 2
    medium_keywords = ['write', 'create', 'generate', 'summarize', 'describe', 'list',
                       'what is', 'tell me', 'explain briefly']
    for keyword in medium_keywords:
        if keyword in prompt_lower:
            complexity_score += 1
    if any(word in prompt_lower for word in ['code', 'function', 'program', 'debug', 'implement']):
        complexity_score += 2
    if any(word in prompt_lower for word in ['first', 'then', 'next', 'finally', 'step']):
        complexity_score += 1
    question_marks = prompt.count('?')
    if question_marks > 1:
        complexity_score += 1
    return complexity_score
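# Worked example: "explain how does this algorithm work?" scores 6
# (+2 'explain', +2 'how does', +2 'algorithm'; 7 words and a single '?'
# add nothing), which select_model_auto below routes to the Fast tier.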
def select_model_auto(prompt, available_models_dict, user_available_models):
    complexity = estimate_prompt_complexity(prompt)
    accessible = {k: v for k, v in available_models_dict.items() if k in user_available_models}
    if not accessible:
        return None
    if complexity <= 2:
        preferred = "SAM-X-1-Nano ⚡⚡"
        fallback_order = ["SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Large"]
    elif complexity <= 5:
        preferred = "SAM-X-1-Mini 🚀 (ADVANCED!)"
        fallback_order = ["SAM-X-1-Nano ⚡⚡", "SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Large"]
    elif complexity <= 8:
        preferred = "SAM-X-1-Fast ⚡ (BETA)"
        fallback_order = ["SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Large", "SAM-X-1-Nano ⚡⚡"]
    else:
        preferred = "SAM-X-1-Large"
        fallback_order = ["SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Nano ⚡⚡"]
    if preferred in accessible:
        return accessible[preferred]
    for model_name in fallback_order:
        if model_name in accessible:
            return accessible[model_name]
    return list(accessible.values())[0]
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"

print("="*80)
print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
print("="*80)
print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
with open(config_path, 'r') as f:
    base_config = json.load(f)
print(f"✅ Base config loaded")
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers'],
}
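# Note: tokenizer.json is downloaded above but, as written, the tokenizer is
# rebuilt from the stock GPT-2 vocabulary below rather than loaded from that
# file; the special tokens are then re-added by hand.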
| print("\n🔤 Recreating tokenizer...") | |
| tokenizer = Tokenizer.from_pretrained("gpt2") | |
| eos_token = "<|endoftext|>" | |
| eos_token_id = tokenizer.token_to_id(eos_token) | |
| if eos_token_id is None: | |
| tokenizer.add_special_tokens([eos_token]) | |
| eos_token_id = tokenizer.token_to_id(eos_token) | |
| custom_tokens = ["<think>", "<think/>"] | |
| for token in custom_tokens: | |
| if tokenizer.token_to_id(token) is None: | |
| tokenizer.add_special_tokens([token]) | |
| tokenizer.no_padding() | |
| tokenizer.enable_truncation(max_length=base_config['max_position_embeddings']) | |
| print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})") | |
| print(f" EOS token: '{eos_token}' (ID: {eos_token_id})") | |
| if eos_token_id is None: | |
| raise ValueError("❌ Failed to set EOS token ID!") | |
| print("\n" + "="*80) | |
| print("📦 LOADING MODELS".center(80)) | |
| print("="*80) | |
| available_models = {} | |
| dummy_input = tf.zeros((1, 1), dtype=tf.int32) | |
| for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY: | |
| try: | |
| print(f"\n⏳ Loading: {display_name}") | |
| print(f" Repo: {repo_id}") | |
| print(f" Weights: {weights_filename}") | |
| weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename) | |
| if config_filename: | |
| print(f" Config: {config_filename}") | |
| custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename) | |
| with open(custom_config_path, 'r') as f: | |
| model_config = json.load(f) | |
| print(f" 📐 Custom architecture: {model_config['n_heads']} heads") | |
| else: | |
| model_config = base_model_config.copy() | |
| model = SAM1Model(**model_config) | |
| model(dummy_input) | |
| model.load_weights(weights_path) | |
| model.trainable = False | |
| backend = KerasBackend(model, display_name, display_name) | |
| available_models[display_name] = backend | |
| print(f" ✅ Loaded successfully!") | |
| print(f" 📊 Parameters: {format_param_count(backend.total_params)}") | |
| except Exception as e: | |
| print(f" ⚠️ Failed to load: {e}") | |
| if not available_models: | |
| raise RuntimeError("❌ No models loaded!") | |
| print(f"\n✅ Successfully loaded {len(available_models)} model(s)") | |
| current_backend = list(available_models.values())[0] | |
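# Default chat backend: the first registry entry that loaded successfully
# (SAM-X-1-Large when all four downloads succeed).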
stop_generation = threading.Event()

def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=256):
    global stop_generation
    stop_generation.clear()
    if backend is None:
        backend = current_backend
    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False
    max_len = backend.model.cfg['max_len']
    start_time = time.time()
    tokens_generated = 0
    decode_buffer = []
    decode_every = 2
    last_speed_check = start_time
    for step in range(max_tokens):
        if stop_generation.is_set():
            elapsed = time.time() - start_time
            final_speed = tokens_generated / elapsed if elapsed > 0 else 0
            yield "", False, -1, final_speed, True
            return
        current_input = generated[-max_len:]
        next_token_logits = backend.predict(current_input)
        if tokens_generated > 5 and tokens_generated % 10 == 0:
            current_time = time.time()
            elapsed_since_check = current_time - last_speed_check
            if elapsed_since_check > 0:
                recent_speed = 10 / elapsed_since_check
                if recent_speed > 25:
                    decode_every = 8
                elif recent_speed > 15:
                    decode_every = 5
                elif recent_speed > 8:
                    decode_every = 3
                else:
                    decode_every = 2
                last_speed_check = current_time
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            top_k = 5
            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
            top_k_logits = next_token_logits[top_k_indices]
            max_logit = np.max(top_k_logits)
            exp_logits = np.exp(top_k_logits - max_logit)
            probs = exp_logits / np.sum(exp_logits)
            next_token = top_k_indices[np.random.choice(top_k, p=probs)]
        else:
            next_token = np.argmax(next_token_logits)
        if next_token == eos_token_id:
            break
        generated.append(int(next_token))
        decode_buffer.append(int(next_token))
        tokens_generated += 1
        should_decode = (len(decode_buffer) >= decode_every or step == max_tokens - 1)
        if should_decode:
            new_text = tokenizer.decode(generated[len(input_ids):])
            if len(new_text) > len(current_text):
                new_chunk = new_text[len(current_text):]
                current_text = new_text
                if "<think>" in new_chunk:
                    in_thinking = True
                elif "</think>" in new_chunk or "<think/>" in new_chunk:
                    in_thinking = False
                elapsed = time.time() - start_time
                tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
                yield new_chunk, in_thinking, tokens_per_sec, tokens_per_sec, False
            decode_buffer = []
    elapsed = time.time() - start_time
    final_tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
    yield "", False, final_tokens_per_sec, final_tokens_per_sec, False