Keeby-smilyai committed on
Commit 7f368dd · verified · 1 Parent(s): 4e366fc

Update app.py

Files changed (1)
  1. app.py +384 -739
app.py CHANGED
@@ -11,26 +11,371 @@ import json
11
  from abc import ABC, abstractmethod
12
  import time
13
  import threading
 
 
 
 
14
 
15
  # ==============================================================================
16
  # Performance Optimizations for CPU
17
  # ==============================================================================
18
- # Set TensorFlow to use fewer threads (better for 2vCPU)
19
  tf.config.threading.set_inter_op_parallelism_threads(1)
20
  tf.config.threading.set_intra_op_parallelism_threads(2)
21
-
22
- # Enable XLA compilation for faster execution
23
  tf.config.optimizer.set_jit(True)
24
-
25
- # Disable eager execution for better performance
26
  tf.config.run_functions_eagerly(False)
27
-
28
- # Memory optimization
29
  os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
30
  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
31

32
  # ==============================================================================
33
- # Model Architecture (Must match training code)

34
  # ==============================================================================
35
  @keras.saving.register_keras_serializable()
36
  class RotaryEmbedding(keras.layers.Layer):
@@ -47,7 +392,6 @@ class RotaryEmbedding(keras.layers.Layer):
47
  t = tf.range(self.max_len, dtype=tf.float32)
48
  freqs = tf.einsum("i,j->ij", t, inv_freq)
49
  emb = tf.concat([freqs, freqs], axis=-1)
50
-
51
  self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
52
  self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
53
  self.built_cache = True
@@ -62,10 +406,8 @@ class RotaryEmbedding(keras.layers.Layer):
62
  dtype = q.dtype
63
  cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
64
  sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
65
-
66
  q_rotated = (q * cos) + (self.rotate_half(q) * sin)
67
  k_rotated = (k * cos) + (self.rotate_half(k) * sin)
68
-
69
  return q_rotated, k_rotated
70
 
71
  def get_config(self):
@@ -73,7 +415,6 @@ class RotaryEmbedding(keras.layers.Layer):
73
  config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
74
  return config
75
 
76
-
77
  @keras.saving.register_keras_serializable()
78
  class RMSNorm(keras.layers.Layer):
79
  def __init__(self, epsilon=1e-5, **kwargs):
@@ -92,7 +433,6 @@ class RMSNorm(keras.layers.Layer):
92
  config.update({"epsilon": self.epsilon})
93
  return config
94
 
95
-
96
  @keras.saving.register_keras_serializable()
97
  class TransformerBlock(keras.layers.Layer):
98
  def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
@@ -105,68 +445,47 @@ class TransformerBlock(keras.layers.Layer):
105
  self.rope_theta = rope_theta
106
  self.head_dim = d_model // n_heads
107
  self.layer_idx = layer_idx
108
-
109
  self.pre_attn_norm = RMSNorm()
110
  self.pre_ffn_norm = RMSNorm()
111
-
112
  self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
113
  self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
114
  self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
115
  self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
116
-
117
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
118
-
119
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
120
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
121
  self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
122
-
123
  self.dropout = keras.layers.Dropout(dropout)
124
 
125
  def call(self, x, training=None):
126
  B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
127
  dtype = x.dtype
128
-
129
  res = x
130
  y = self.pre_attn_norm(x)
131
-
132
  q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
133
  k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
134
  v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
135
-
136
  q, k = self.rope(q, k)
137
-
138
  scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
139
-
140
- mask = tf.where(
141
- tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
142
- tf.constant(-1e9, dtype=dtype),
143
- tf.constant(0.0, dtype=dtype)
144
- )
145
  scores += mask
146
  attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
147
-
148
  attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
149
  x = res + self.dropout(self.out_proj(attn), training=training)
150
-
151
  res = x
152
  y = self.pre_ffn_norm(x)
153
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
154
-
155
  return res + self.dropout(ffn, training=training)
156
 
157
  def get_config(self):
158
  config = super().get_config()
159
- config.update({
160
- "d_model": self.d_model,
161
- "n_heads": self.n_heads,
162
- "ff_dim": self.ff_dim,
163
- "dropout": self.dropout_rate,
164
- "max_len": self.max_len,
165
- "rope_theta": self.rope_theta,
166
- "layer_idx": self.layer_idx
167
- })
168
- return config
169

170
 
171
  @keras.saving.register_keras_serializable()
172
  class SAM1Model(keras.Model):
@@ -178,33 +497,20 @@ class SAM1Model(keras.Model):
178
  self.cfg = kwargs
179
  else:
180
  self.cfg = kwargs.get('cfg', kwargs)
181
-
182
  self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
183
-
184
  ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
185
- block_args = {
186
- 'd_model': self.cfg['d_model'],
187
- 'n_heads': self.cfg['n_heads'],
188
- 'ff_dim': ff_dim,
189
- 'dropout': self.cfg['dropout'],
190
- 'max_len': self.cfg['max_len'],
191
- 'rope_theta': self.cfg['rope_theta']
192
- }
193
-
194
  self.blocks = []
195
  for i in range(self.cfg['n_layers']):
196
  block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
197
  self.blocks.append(block)
198
-
199
  self.norm = RMSNorm(name="final_norm")
200
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
201
 
202
  def call(self, input_ids, training=None):
203
  x = self.embed(input_ids)
204
-
205
  for block in self.blocks:
206
  x = block(x, training=training)
207
-
208
  return self.lm_head(self.norm(x))
209
 
210
  def get_config(self):
@@ -212,25 +518,16 @@ class SAM1Model(keras.Model):
212
  base_config['config'] = self.cfg
213
  return base_config
214
 
215
-
216
- # ==============================================================================
217
- # Helper Functions
218
- # ==============================================================================
219
  def count_parameters(model):
220
- """Count total and non-zero parameters in model."""
221
  total_params = 0
222
  non_zero_params = 0
223
-
224
  for weight in model.weights:
225
  w = weight.numpy()
226
  total_params += w.size
227
  non_zero_params += np.count_nonzero(w)
228
-
229
  return total_params, non_zero_params
230
 
231
-
232
  def format_param_count(count):
233
- """Format parameter count in human readable format."""
234
  if count >= 1e9:
235
  return f"{count/1e9:.2f}B"
236
  elif count >= 1e6:
@@ -240,53 +537,34 @@ def format_param_count(count):
240
  else:
241
  return str(count)
242
 
243
-
244
- # ==============================================================================
245
- # Model Backend Interface
246
- # ==============================================================================
247
  class ModelBackend(ABC):
248
  @abstractmethod
249
  def predict(self, input_ids):
250
  pass
251
-
252
  @abstractmethod
253
  def get_name(self):
254
  pass
255
-
256
  @abstractmethod
257
  def get_info(self):
258
  pass
259
 
260
-
261
  class KerasBackend(ModelBackend):
262
  def __init__(self, model, name, display_name):
263
  self.model = model
264
  self.name = name
265
  self.display_name = display_name
266
-
267
- # Pre-compile predict function for faster inference
268
- @tf.function(
269
- input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)],
270
- jit_compile=True
271
- )
272
  def fast_predict(inputs):
273
  return model(inputs, training=False)
274
-
275
  self.fast_predict = fast_predict
276
-
277
- # Warm up compilation with dummy input
278
  print(f" 🔥 Warming up {display_name}...")
279
  dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
280
  _ = self.fast_predict(dummy)
281
  print(f" ✅ Compilation complete!")
282
-
283
- # Count parameters
284
  total, non_zero = count_parameters(model)
285
  self.total_params = total
286
  self.non_zero_params = non_zero
287
  self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
288
-
289
- # Calculate actual model config for speed estimation
290
  self.n_heads = model.cfg.get('n_heads', 0)
291
  self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
292
 
@@ -307,10 +585,6 @@ class KerasBackend(ModelBackend):
307
  info += f" Sparsity: {self.sparsity:.1f}%\n"
308
  return info
309
 
310
-
311
- # ==============================================================================
312
- # EASY MODEL REGISTRY - ADD YOUR MODELS HERE!
313
- # ==============================================================================
314
  MODEL_REGISTRY = [
315
  ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
316
  ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
@@ -318,22 +592,9 @@ MODEL_REGISTRY = [
318
  ("SAM-X-1-Nano ⚡⚡", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano_finetuned.weights.h5", "sam1_nano_finetuned_config.json"),
319
  ]
320
 
321
- # Model complexity scores for auto-selection (higher = more capable)
322
- MODEL_COMPLEXITY = {
323
- "SAM-X-1-Nano ⚡⚡": 1,
324
- "SAM-X-1-Mini 🚀 (ADVANCED!)": 2,
325
- "SAM-X-1-Fast ⚡ (BETA)": 3,
326
- "SAM-X-1-Large": 4
327
- }
328
-
329
  def estimate_prompt_complexity(prompt):
330
- """Estimate prompt complexity to choose appropriate model."""
331
  prompt_lower = prompt.lower()
332
-
333
- # Count complexity indicators
334
  complexity_score = 0
335
-
336
- # Length-based complexity
337
  word_count = len(prompt.split())
338
  if word_count > 100:
339
  complexity_score += 3
@@ -341,52 +602,28 @@ def estimate_prompt_complexity(prompt):
341
  complexity_score += 2
342
  elif word_count > 20:
343
  complexity_score += 1
344
-
345
- # Hard reasoning keywords (need Large/Fast)
346
- hard_keywords = [
347
- 'analyze', 'explain', 'compare', 'evaluate', 'prove', 'derive',
348
- 'calculate', 'solve', 'reason', 'why', 'how does', 'complex',
349
- 'algorithm', 'mathematics', 'philosophy', 'theory', 'logic',
350
- 'detailed', 'comprehensive', 'thorough', 'in-depth'
351
- ]
352
  for keyword in hard_keywords:
353
  if keyword in prompt_lower:
354
  complexity_score += 2
355
-
356
- # Medium complexity keywords (need Mini/Fast)
357
- medium_keywords = [
358
- 'write', 'create', 'generate', 'summarize', 'describe',
359
- 'list', 'what is', 'tell me', 'explain briefly'
360
- ]
361
  for keyword in medium_keywords:
362
  if keyword in prompt_lower:
363
  complexity_score += 1
364
-
365
- # Code-related (usually complex)
366
  if any(word in prompt_lower for word in ['code', 'function', 'program', 'debug', 'implement']):
367
  complexity_score += 2
368
-
369
- # Multi-step or multi-part questions
370
  if any(word in prompt_lower for word in ['first', 'then', 'next', 'finally', 'step']):
371
  complexity_score += 1
372
-
373
- # Questions with multiple parts
374
  question_marks = prompt.count('?')
375
  if question_marks > 1:
376
  complexity_score += 1
377
-
378
  return complexity_score
379
 
380
- def select_model_auto(prompt, available_models):
381
- """Automatically select best model based on prompt complexity."""
382
  complexity = estimate_prompt_complexity(prompt)
383
-
384
- # Map complexity to model choice
385
- # 0-2: Simple questions -> Nano (fastest)
386
- # 3-5: Medium questions -> Mini (balanced)
387
- # 6-8: Complex questions -> Fast (capable)
388
- # 9+: Very complex -> Large (most capable)
389
-
390
  if complexity <= 2:
391
  preferred = "SAM-X-1-Nano ⚡⚡"
392
  fallback_order = ["SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Large"]
@@ -399,208 +636,114 @@ def select_model_auto(prompt, available_models):
399
  else:
400
  preferred = "SAM-X-1-Large"
401
  fallback_order = ["SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Nano ⚡⚡"]
402
-
403
- # Try preferred model first
404
- if preferred in available_models:
405
- print(f" 🎯 Auto-selected {preferred} (complexity: {complexity})")
406
- return available_models[preferred]
407
-
408
- # Fallback to next best available
409
  for model_name in fallback_order:
410
- if model_name in available_models:
411
- print(f" 🎯 Auto-selected {model_name} (fallback, complexity: {complexity})")
412
- return available_models[model_name]
413
-
414
- # Last resort: return any available model
415
- return list(available_models.values())[0]
416
 
417
- # ==============================================================================
418
- # Load Models
419
- # ==============================================================================
420
  CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
421
-
422
  print("="*80)
423
  print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
424
  print("="*80)
425
-
426
- # Download config and tokenizer
427
  print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
428
  config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
429
  tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
430
-
431
- # Load config
432
  with open(config_path, 'r') as f:
433
  base_config = json.load(f)
434
-
435
  print(f"✅ Base config loaded")
436
-
437
- # Build base model config
438
- base_model_config = {
439
- 'vocab_size': base_config['vocab_size'],
440
- 'd_model': base_config['hidden_size'],
441
- 'n_heads': base_config['num_attention_heads'],
442
- 'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
443
- 'dropout': base_config.get('dropout', 0.0),
444
- 'max_len': base_config['max_position_embeddings'],
445
- 'rope_theta': base_config['rope_theta'],
446
- 'n_layers': base_config['num_hidden_layers']
447
- }
448
-
449
- # ==============================================================================
450
- # FIX: Proper EOS token handling
451
- # ==============================================================================
452
  print("\n🔤 Recreating tokenizer...")
453
  tokenizer = Tokenizer.from_pretrained("gpt2")
454
-
455
- # GPT-2's actual EOS token is "<|endoftext|>"
456
  eos_token = "<|endoftext|>"
457
  eos_token_id = tokenizer.token_to_id(eos_token)
458
-
459
  if eos_token_id is None:
460
- # Fallback to adding it
461
  tokenizer.add_special_tokens([eos_token])
462
  eos_token_id = tokenizer.token_to_id(eos_token)
463
-
464
- # Add custom tokens
465
  custom_tokens = ["<think>", "<think/>"]
466
  for token in custom_tokens:
467
  if tokenizer.token_to_id(token) is None:
468
  tokenizer.add_special_tokens([token])
469
-
470
  tokenizer.no_padding()
471
  tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
472
-
473
  print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
474
  print(f" EOS token: '{eos_token}' (ID: {eos_token_id})")
475
-
476
- # Verify EOS token is valid
477
  if eos_token_id is None:
478
- raise ValueError("❌ Failed to set EOS token ID! Check tokenizer setup.")
479
-
480
- # Load all models from registry
481
  print("\n" + "="*80)
482
  print("📦 LOADING MODELS".center(80))
483
  print("="*80)
484
-
485
  available_models = {}
486
  dummy_input = tf.zeros((1, 1), dtype=tf.int32)
487
-
488
  for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
489
  try:
490
  print(f"\n⏳ Loading: {display_name}")
491
  print(f" Repo: {repo_id}")
492
  print(f" Weights: {weights_filename}")
493
-
494
- # Download weights
495
  weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
496
-
497
- # Load custom config if specified (for pruned models)
498
  if config_filename:
499
  print(f" Config: {config_filename}")
500
  custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
501
  with open(custom_config_path, 'r') as f:
502
  model_config = json.load(f)
503
- print(f" 📐 Custom architecture: {model_config['n_heads']} heads, {int(model_config['d_model'] * model_config['ff_mult'])} FFN dim")
504
  else:
505
  model_config = base_model_config.copy()
506
-
507
- # Create model with appropriate config
508
  model = SAM1Model(**model_config)
509
  model(dummy_input)
510
  model.load_weights(weights_path)
511
  model.trainable = False
512
-
513
- # Create backend
514
  backend = KerasBackend(model, display_name, display_name)
515
  available_models[display_name] = backend
516
-
517
- # Print stats
518
  print(f" ✅ Loaded successfully!")
519
  print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
520
- print(f" 📊 Attention heads: {backend.n_heads}")
521
- print(f" 📊 FFN dimension: {backend.ff_dim}")
522
-
523
  except Exception as e:
524
  print(f" ⚠️ Failed to load: {e}")
525
- print(f" Skipping {display_name}...")
526
-
527
  if not available_models:
528
- raise RuntimeError("❌ No models loaded! Check your MODEL_REGISTRY configuration.")
529
-
530
  print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
531
- print(f" Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")
532
-
533
  current_backend = list(available_models.values())[0]
534
-
535
- # Global stop flag
536
  stop_generation = threading.Event()
537
 
538
-
539
- # ==============================================================================
540
- # FIX: Improved generation function with better stop handling
541
- # ==============================================================================
542
  def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=256):
543
- """Generate response and yield tokens one by one for streaming."""
544
  global stop_generation
545
  stop_generation.clear()
546
-
547
  if backend is None:
548
  backend = current_backend
549
-
550
- # Encode prompt
551
  encoded_prompt = tokenizer.encode(prompt)
552
  input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
553
  generated = input_ids.copy()
554
-
555
  current_text = ""
556
  in_thinking = False
557
-
558
- # Get max_len from the backend's model config
559
  max_len = backend.model.cfg['max_len']
560
-
561
- # Track timing
562
  start_time = time.time()
563
  tokens_generated = 0
564
-
565
- # *** DYNAMIC DECODE BATCHING: Adjust based on generation speed ***
566
  decode_buffer = []
567
- decode_every = 2 # Start conservative
568
  last_speed_check = start_time
569
-
570
- # Generate tokens
571
  for step in range(max_tokens):
572
- # *** FIX: Check stop flag FIRST before any processing ***
573
  if stop_generation.is_set():
574
- print(f" 🛑 Stop requested at token {tokens_generated}")
575
- # Calculate final speed
576
  elapsed = time.time() - start_time
577
  final_speed = tokens_generated / elapsed if elapsed > 0 else 0
578
- yield "", False, -1, final_speed, True # Added stopped flag
579
  return
580
-
581
  current_input = generated[-max_len:]
582
-
583
- # Get logits from selected backend
584
  next_token_logits = backend.predict(current_input)
585
-
586
- # *** DYNAMIC BATCHING: Adjust decode_every based on speed ***
587
- # Check speed every 10 tokens after warmup
588
  if tokens_generated > 5 and tokens_generated % 10 == 0:
589
  current_time = time.time()
590
  elapsed_since_check = current_time - last_speed_check
591
  if elapsed_since_check > 0:
592
  recent_speed = 10 / elapsed_since_check
593
- # Adaptive batching: faster models can batch more
594
  if recent_speed > 25:
595
- decode_every = 8 # Very fast (Nano)
596
  elif recent_speed > 15:
597
- decode_every = 5 # Fast (Mini)
598
  elif recent_speed > 8:
599
- decode_every = 3 # Medium (Fast)
600
  else:
601
- decode_every = 2 # Slow (Large)
602
  last_speed_check = current_time
603
-
604
  if temperature > 0:
605
  next_token_logits = next_token_logits / temperature
606
  top_k = 5
@@ -612,524 +755,26 @@ def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=2
612
  next_token = top_k_indices[np.random.choice(top_k, p=probs)]
613
  else:
614
  next_token = np.argmax(next_token_logits)
615
-
616
- # *** FIX: Check for EOS token IMMEDIATELY and break ***
617
  if next_token == eos_token_id:
618
- print(f" 🛑 EOS token detected at position {tokens_generated}")
619
  break
620
-
621
  generated.append(int(next_token))
622
  decode_buffer.append(int(next_token))
623
  tokens_generated += 1
624
-
625
- # Decode in batches for better performance
626
- should_decode = (len(decode_buffer) >= decode_every or
627
- step == max_tokens - 1)
628
-
629
  if should_decode:
630
  new_text = tokenizer.decode(generated[len(input_ids):])
631
  if len(new_text) > len(current_text):
632
  new_chunk = new_text[len(current_text):]
633
  current_text = new_text
634
-
635
  if "<think>" in new_chunk:
636
  in_thinking = True
637
  elif "</think>" in new_chunk or "<think/>" in new_chunk:
638
  in_thinking = False
639
-
640
- # Calculate tokens/sec
641
  elapsed = time.time() - start_time
642
  tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
643
-
644
  yield new_chunk, in_thinking, tokens_per_sec, tokens_per_sec, False
645
  decode_buffer = []
646
-
647
- # Final stats
648
  elapsed = time.time() - start_time
649
  final_tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
650
  yield "", False, final_tokens_per_sec, final_tokens_per_sec, False
651
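The temperature / top-k step inside the loop above, pulled out as a standalone helper for clarity; a sketch assuming next_token_logits is the 1-D vocab-sized array returned by backend.predict (sample_top_k is a hypothetical name, not defined in app.py):

import numpy as np

def sample_top_k(next_token_logits, temperature=0.7, top_k=5):
    # Greedy decoding when temperature is 0; otherwise sample from the renormalised top-k softmax.
    if temperature <= 0:
        return int(np.argmax(next_token_logits))
    logits = next_token_logits / temperature
    top_k_indices = np.argpartition(logits, -top_k)[-top_k:]   # indices of the k largest logits
    top_logits = logits[top_k_indices]
    probs = np.exp(top_logits - np.max(top_logits))
    probs = probs / probs.sum()
    return int(top_k_indices[np.random.choice(top_k, p=probs)])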
 
652
-
653
- # ==============================================================================
654
- # Gradio Interface
655
- # ==============================================================================
656
- if __name__ == "__main__":
657
- import gradio as gr
658
-
659
- custom_css = """
660
- .chat-container {
661
- height: 600px;
662
- overflow-y: auto;
663
- padding: 20px;
664
- background: #ffffff;
665
- }
666
-
667
- .user-message {
668
- background: #f7f7f8;
669
- padding: 16px;
670
- margin: 12px 0;
671
- border-radius: 8px;
672
- }
673
-
674
- .assistant-message {
675
- background: #ffffff;
676
- padding: 16px;
677
- margin: 12px 0;
678
- border-radius: 8px;
679
- border-left: 3px solid #10a37f;
680
- }
681
-
682
- .message-content {
683
- color: #353740;
684
- line-height: 1.6;
685
- font-size: 15px;
686
- }
687
-
688
- .message-header {
689
- font-weight: 600;
690
- margin-bottom: 8px;
691
- color: #353740;
692
- font-size: 14px;
693
- }
694
-
695
- .thinking-content {
696
- color: #6b7280;
697
- font-style: italic;
698
- border-left: 3px solid #d1d5db;
699
- padding-left: 12px;
700
- margin: 8px 0;
701
- background: #f9fafb;
702
- padding: 8px 12px;
703
- border-radius: 4px;
704
- }
705
-
706
- .input-row {
707
- background: #ffffff;
708
- padding: 12px;
709
- border-radius: 8px;
710
- margin-top: 12px;
711
- border: 1px solid #e5e7eb;
712
- }
713
-
714
- .gradio-container {
715
- max-width: 900px !important;
716
- margin: auto !important;
717
- }
718
-
719
- .announcement-banner {
720
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
721
- color: white;
722
- padding: 20px 28px;
723
- border-radius: 12px;
724
- margin-bottom: 20px;
725
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
726
- text-align: center;
727
- font-size: 16px;
728
- font-weight: 500;
729
- animation: slideIn 0.5s ease-out;
730
- line-height: 1.6;
731
- }
732
-
733
- @keyframes slideIn {
734
- from {
735
- opacity: 0;
736
- transform: translateY(-20px);
737
- }
738
- to {
739
- opacity: 1;
740
- transform: translateY(0);
741
- }
742
- }
743
-
744
- .announcement-banner strong {
745
- font-weight: 700;
746
- font-size: 18px;
747
- }
748
-
749
- .settings-panel {
750
- background: #f9fafb;
751
- padding: 16px;
752
- border-radius: 8px;
753
- margin-bottom: 12px;
754
- border: 1px solid #e5e7eb;
755
- }
756
-
757
- .model-info {
758
- background: #f0f9ff;
759
- border: 1px solid #bae6fd;
760
- padding: 12px;
761
- border-radius: 8px;
762
- margin-top: 8px;
763
- font-size: 13px;
764
- font-family: monospace;
765
- white-space: pre-line;
766
- }
767
-
768
- .speed-indicator {
769
- background: #dcfce7;
770
- border: 1px solid #86efac;
771
- padding: 8px 12px;
772
- border-radius: 6px;
773
- margin-top: 8px;
774
- font-size: 14px;
775
- font-weight: 600;
776
- color: #166534;
777
- text-align: center;
778
- }
779
-
780
- /* Circular Send Button */
781
- .send-btn-wrapper {
782
- display: flex;
783
- gap: 8px;
784
- align-items: center;
785
- }
786
-
787
- .circular-btn {
788
- width: 48px !important;
789
- height: 48px !important;
790
- min-width: 48px !important;
791
- border-radius: 50% !important;
792
- padding: 0 !important;
793
- display: flex !important;
794
- align-items: center !important;
795
- justify-content: center !important;
796
- font-size: 20px !important;
797
- box-shadow: 0 2px 8px rgba(0,0,0,0.15) !important;
798
- transition: all 0.2s ease !important;
799
- }
800
-
801
- .circular-btn:hover:not(:disabled) {
802
- transform: scale(1.05) !important;
803
- box-shadow: 0 4px 12px rgba(0,0,0,0.2) !important;
804
- }
805
-
806
- .circular-btn:active:not(:disabled) {
807
- transform: scale(0.95) !important;
808
- }
809
-
810
- .send-btn {
811
- background: linear-gradient(135deg, #10a37f 0%, #0d8c6c 100%) !important;
812
- border: none !important;
813
- }
814
-
815
- .stop-btn {
816
- background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important;
817
- border: none !important;
818
- }
819
-
820
- .circular-btn:disabled {
821
- opacity: 0.4 !important;
822
- cursor: not-allowed !important;
823
- transform: none !important;
824
- }
825
- """
826
-
827
- def format_message_html(role, content, show_thinking=True, show_raw=False):
828
- """Format a single message as HTML."""
829
- role_class = "user-message" if role == "user" else "assistant-message"
830
- role_name = "You" if role == "user" else "SAM-X-1"
831
-
832
- thinking = ""
833
- answer = ""
834
-
835
- if "<think>" in content:
836
- parts = content.split("<think>", 1)
837
- before_think = parts[0].strip()
838
-
839
- if len(parts) > 1:
840
- after_think = parts[1]
841
-
842
- if "</think>" in after_think:
843
- think_parts = after_think.split("</think>", 1)
844
- thinking = think_parts[0].strip()
845
- answer = (before_think + " " + think_parts[1]).strip()
846
- elif "<think/>" in after_think:
847
- think_parts = after_think.split("<think/>", 1)
848
- thinking = think_parts[0].strip()
849
- answer = (before_think + " " + think_parts[1]).strip()
850
- else:
851
- thinking = after_think.strip()
852
- answer = before_think
853
- else:
854
- answer = before_think
855
- else:
856
- answer = content
857
-
858
- html = f'<div class="{role_class}">'
859
- html += f'<div class="message-header">{role_name}</div>'
860
- html += f'<div class="message-content">'
861
-
862
- if thinking and show_thinking:
863
- html += f'<div class="thinking-content">💭 {thinking}</div>'
864
-
865
- if answer:
866
- html += f'<div>{answer}</div>'
867
-
868
- # Add raw response debug view
869
- if show_raw and role == "assistant":
870
- # Escape HTML and show special tokens
871
- raw_content = content.replace("<", "&lt;").replace(">", "&gt;")
872
- raw_content = raw_content.replace("&lt;endoftext&gt;", '<span style="background: #fef3c7; color: #92400e; padding: 2px 6px; border-radius: 3px; font-weight: bold;">⚠️ &lt;endoftext&gt;</span>')
873
- raw_content = raw_content.replace("&lt;think&gt;", '<span style="background: #dbeafe; color: #1e40af; padding: 2px 6px; border-radius: 3px;">🤔 &lt;think&gt;</span>')
874
- raw_content = raw_content.replace("&lt;/think&gt;", '<span style="background: #dbeafe; color: #1e40af; padding: 2px 6px; border-radius: 3px;">✅ &lt;/think&gt;</span>')
875
- raw_content = raw_content.replace("&lt;think/&gt;", '<span style="background: #dbeafe; color: #1e40af; padding: 2px 6px; border-radius: 3px;">✅ &lt;think/&gt;</span>')
876
-
877
- html += f'''
878
- <div style="margin-top: 12px; padding: 12px; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 6px; font-family: monospace; font-size: 12px; color: #374151;">
879
- <div style="font-weight: 600; margin-bottom: 6px; color: #6b7280;">🔍 Raw Response (Debug):</div>
880
- <div style="white-space: pre-wrap; word-break: break-all;">{raw_content}</div>
881
- </div>
882
- '''
883
-
884
- html += '</div></div>'
885
- return html
886
-
887
- def render_history(history, show_thinking, show_raw=False):
888
- """Render chat history as HTML."""
889
- html = ""
890
- for msg in history:
891
- html += format_message_html(msg["role"], msg["content"], show_thinking, show_raw)
892
- return html
893
-
894
- # ==============================================================================
895
- # Simplified send_message handler with separate buttons
896
- # ==============================================================================
897
- def send_message(message, show_thinking, temperature, model_choice, max_tokens, show_raw):
898
- global stop_generation
899
- stop_generation.clear()
900
-
901
- if not message.strip():
902
- return "", "", "⚡ 0.0 tok/s", gr.update(interactive=True), gr.update(interactive=False)
903
-
904
- # Disable send button, enable stop button
905
- yield "", "", "⚡ Generating...", gr.update(interactive=False), gr.update(interactive=True)
906
-
907
- # Switch backend based on selection (or auto-select)
908
- if model_choice == "🤖 Auto (Smart Selection)":
909
- backend = select_model_auto(message, available_models)
910
- model_name = backend.get_name()
911
- yield "", f"<div style='background: #dbeafe; padding: 12px; border-radius: 8px; margin: 8px 0; border-left: 3px solid #3b82f6;'><strong>🤖 Auto-selected:</strong> {model_name}</div>", "⚡ Generating...", gr.update(interactive=False), gr.update(interactive=True)
912
- else:
913
- backend = available_models[model_choice]
914
-
915
- # Create single-turn history
916
- history = [{"role": "user", "content": message}]
917
-
918
- # Show user message immediately
919
- yield "", render_history(history, show_thinking, show_raw), "⚡ Generating...", gr.update(interactive=False), gr.update(interactive=True)
920
-
921
- # Generate prompt (single turn, no history)
922
- prompt = f"User: {message}\nSam: <think>"
923
-
924
- # Start assistant message
925
- history.append({"role": "assistant", "content": "<think>"})
926
-
927
- # Stream response
928
- last_tokens_per_sec = 0
929
- was_stopped = False
930
-
931
- for chunk_data in generate_response_stream(prompt, temperature, backend, max_tokens):
932
- if len(chunk_data) == 5: # New format with stopped flag
933
- new_chunk, in_thinking, tokens_per_sec, avg_tokens_per_sec, stopped = chunk_data
934
-
935
- if stopped:
936
- was_stopped = True
937
- print(" ✅ Generation stopped successfully")
938
- break
939
-
940
- if new_chunk: # Only update if there's actual content
941
- history[-1]["content"] += new_chunk
942
-
943
- last_tokens_per_sec = avg_tokens_per_sec
944
-
945
- # Update UI on every chunk - keep stop button enabled
946
- speed_text = f"⚡ {tokens_per_sec:.1f} tok/s"
947
- yield "", render_history(history, show_thinking, show_raw), speed_text, gr.update(interactive=False), gr.update(interactive=True)
948
-
949
- # Final yield - enable send button, disable stop button
950
- if was_stopped:
951
- final_speed = f"🛑 Stopped at {last_tokens_per_sec:.1f} tok/s"
952
- else:
953
- final_speed = f"✅ {last_tokens_per_sec:.1f} tok/s (avg)"
954
-
955
- print(f" 📊 Final speed: {final_speed}")
956
- yield "", render_history(history, show_thinking, show_raw), final_speed, gr.update(interactive=True), gr.update(interactive=False)
957
-
958
- def stop_generation_handler():
959
- """Handle stop button click."""
960
- global stop_generation
961
- print(" 🛑 Stop button clicked - setting stop flag")
962
- stop_generation.set()
963
- return "🛑 Stopping...", gr.update(interactive=False), gr.update(interactive=False)
964
-
965
- def clear_chat():
966
- """Clear chat and reset UI."""
967
- return "", "⚡ Ready", gr.update(interactive=True), gr.update(interactive=False)
968
-
969
- def update_raw_view(history, show_thinking, show_raw):
970
- """Update the chat display when raw checkbox is toggled."""
971
- return render_history(history, show_thinking, show_raw)
972
-
973
- # Create Gradio interface
974
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
975
- # Announcement Banner
976
- gr.HTML("""
977
- <div class="announcement-banner">
978
- 🎉 <strong>SAM-X-1 V2.2 IS HERE!</strong> 🚀<br>
979
- ✨ <strong>NEW:</strong> Auto Model Selection - Let AI pick the perfect model for your task!<br>
980
- ⚡ <strong>NEW:</strong> Dynamic Batching - Up to 4x faster UI updates on Nano & Mini!<br>
981
- 🔥 <strong>TRY IT NOW:</strong> Use "Auto" mode and watch it intelligently choose Nano for speed or Large for complexity!<br>
982
- 💎 <strong>Nano & Mini models are BLAZING fast</strong> - Perfect for quick questions and coding tasks!
983
- </div>
984
- """)
985
-
986
- gr.Markdown("# 🤖 SAM-X-1 Fast Chat (No History)")
987
-
988
- # Settings panel
989
- with gr.Accordion("⚙️ Settings", open=False):
990
- with gr.Row():
991
- model_selector = gr.Dropdown(
992
- choices=["🤖 Auto (Smart Selection)"] + list(available_models.keys()),
993
- value="🤖 Auto (Smart Selection)",
994
- label="Model Selection",
995
- info="Auto picks the best model for your prompt"
996
- )
997
-
998
- max_tokens_slider = gr.Slider(
999
- minimum=64,
1000
- maximum=512,
1001
- value=256,
1002
- step=64,
1003
- label="Max Tokens",
1004
- info="Lower = Faster generation"
1005
- )
1006
-
1007
- with gr.Row():
1008
- temperature_slider = gr.Slider(
1009
- minimum=0.0,
1010
- maximum=2.0,
1011
- value=0.7,
1012
- step=0.1,
1013
- label="Temperature",
1014
- info="Higher = more creative, Lower = more focused"
1015
- )
1016
-
1017
- with gr.Row():
1018
- show_thinking_checkbox = gr.Checkbox(
1019
- label="Show Thinking Process",
1020
- value=True,
1021
- info="Display model's reasoning"
1022
- )
1023
- show_raw_checkbox = gr.Checkbox(
1024
- label="Show Raw Response (Debug)",
1025
- value=False,
1026
- info="See all special tokens including <|endoftext|>"
1027
- )
1028
-
1029
- # Speed indicator
1030
- speed_display = gr.Textbox(
1031
- label="Generation Speed",
1032
- value="⚡ Ready",
1033
- interactive=False,
1034
- elem_classes=["speed-indicator"]
1035
- )
1036
-
1037
- # Chat display
1038
- chat_html = gr.HTML(value="", elem_classes=["chat-container"])
1039
-
1040
- # Input area with separate send and stop buttons
1041
- with gr.Row(elem_classes=["input-row"]):
1042
- msg_input = gr.Textbox(
1043
- placeholder="Ask me anything...",
1044
- show_label=False,
1045
- container=False,
1046
- scale=8
1047
- )
1048
- with gr.Column(scale=1, min_width=120):
1049
- with gr.Row():
1050
- send_btn = gr.Button("▶", variant="primary", elem_classes=["circular-btn", "send-btn"], interactive=True)
1051
- stop_btn = gr.Button("⏹", variant="stop", elem_classes=["circular-btn", "stop-btn"], interactive=False)
1052
-
1053
- with gr.Row():
1054
- clear_btn = gr.Button("🗑️ Clear", size="sm")
1055
-
1056
- gr.Markdown("""
1057
- ### 🎯 Try These Examples with Auto Mode:
1058
-
1059
- **Simple (→ Nano):**
1060
- - "Hi, how are you?"
1061
- - "What is Python?"
1062
- - "Tell me a joke"
1063
-
1064
- **Medium (→ Mini):**
1065
- - "Write a short story about a robot"
1066
- - "Summarize the benefits of exercise"
1067
- - "Create a simple Python function to sort a list"
1068
-
1069
- **Complex (→ Fast):**
1070
- - "Analyze the differences between procedural and object-oriented programming"
1071
- - "Compare and contrast democracy and authoritarianism"
1072
- - "Explain how neural networks learn with backpropagation"
1073
-
1074
- **Very Hard (→ Large):**
1075
- - "Prove why the Pythagorean theorem works using geometric reasoning"
1076
- - "Derive the formula for compound interest step by step"
1077
- - "Explain the philosophical implications of Gödel's incompleteness theorems"
1078
-
1079
- ### 💡 Speed Optimization Tips:
1080
- - **Auto mode (Default)**: Balances speed and quality automatically
1081
- - **Manual Nano**: 30-40 tok/s - Best for simple questions
1082
- - **Manual Mini**: 20-30 tok/s - Great for most tasks
1083
- - **Manual Fast**: 15-20 tok/s - Good for complex reasoning
1084
- - **Manual Large**: 10-15 tok/s - Use only for hardest problems
1085
- - **Temperature = 0**: Greedy decoding (fastest, deterministic)
1086
- - **Lower max tokens**: Stop generation earlier
1087
-
1088
- ### ⚡ V2.2 Features:
1089
- - ✅ **Smart Auto-Selection** - AI picks the right model for your prompt
1090
- - ✅ **Dynamic Decode Batching** - Adjusts from 2-8 tokens based on speed
1091
- - ✅ **Faster UI Updates** - Nano batches 8 tokens = 4x smoother experience
1092
- - ✅ **Complexity Analysis** - Examines length, keywords, code, multi-step questions
1093
- - ✅ **Instant Stop Button** - Interrupt generation with no delay
1094
- - ✅ **Debug Mode** - See all special tokens in raw view
1095
-
1096
- ### 🎯 Expected Speed (2vCPU):
1097
- - **Nano**: 30-40 tok/s (batch: 8) ⚡⚡
1098
- - **Mini**: 20-30 tok/s (batch: 5) 🚀
1099
- - **Fast**: 15-20 tok/s (batch: 3) ⚡
1100
- - **Large**: 10-15 tok/s (batch: 2) 💎
1101
-
1102
- ### 🚀 What's New:
1103
- - **V2.2**: Auto model selection + Dynamic batching
1104
- - **V2.1**: Separate Send/Stop buttons + EOS fixes + Debug view
1105
- - **V2.0**: Multi-model support + Speed optimizations
1106
- """)
1107
-
1108
- # Event handlers
1109
- send_outputs = [msg_input, chat_html, speed_display, send_btn, stop_btn]
1110
-
1111
- # Send button
1112
- send_btn.click(
1113
- send_message,
1114
- inputs=[msg_input, show_thinking_checkbox, temperature_slider, model_selector, max_tokens_slider, show_raw_checkbox],
1115
- outputs=send_outputs
1116
- )
1117
-
1118
- msg_input.submit(
1119
- send_message,
1120
- inputs=[msg_input, show_thinking_checkbox, temperature_slider, model_selector, max_tokens_slider, show_raw_checkbox],
1121
- outputs=send_outputs
1122
- )
1123
-
1124
- # Stop button
1125
- stop_btn.click(
1126
- stop_generation_handler,
1127
- outputs=[speed_display, send_btn, stop_btn]
1128
- )
1129
-
1130
- clear_btn.click(
1131
- clear_chat,
1132
- outputs=[chat_html, speed_display, send_btn, stop_btn]
1133
- )
1134
-
1135
- demo.launch(debug=True, share=True)
 
11
  from abc import ABC, abstractmethod
12
  import time
13
  import threading
14
+ import hashlib
15
+ import sqlite3
16
+ from datetime import datetime, timedelta
17
+ import pytz
18
 
19
  # ==============================================================================
20
  # Performance Optimizations for CPU
21
  # ==============================================================================
 
22
  tf.config.threading.set_inter_op_parallelism_threads(1)
23
  tf.config.threading.set_intra_op_parallelism_threads(2)
 
 
24
  tf.config.optimizer.set_jit(True)
 
 
25
  tf.config.run_functions_eagerly(False)
 
 
26
  os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
27
  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
28
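A quick way to confirm the thread-pool settings above took effect; a sketch, not part of the commit (both setters must run before TensorFlow executes its first op):

import tensorflow as tf

tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(2)
print(tf.config.threading.get_inter_op_parallelism_threads())  # -> 1
print(tf.config.threading.get_intra_op_parallelism_threads())  # -> 2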
 
29
+ # Australian timezone
30
+ AUSTRALIA_TZ = pytz.timezone('Australia/Sydney')
31
+
32
+ # ==============================================================================
33
+ # Database Setup
34
+ # ==============================================================================
35
+ def init_database():
36
+ """Initialize SQLite database for users and subscriptions."""
37
+ conn = sqlite3.connect('sam_users.db', check_same_thread=False)
38
+ c = conn.cursor()
39
+
40
+ # Users table
41
+ c.execute('''CREATE TABLE IF NOT EXISTS users
42
+ (id INTEGER PRIMARY KEY AUTOINCREMENT,
43
+ username TEXT UNIQUE NOT NULL,
44
+ password_hash TEXT NOT NULL,
45
+ email TEXT,
46
+ plan TEXT DEFAULT 'free',
47
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
48
+ is_admin BOOLEAN DEFAULT 0,
49
+ rate_limit_start TIMESTAMP,
50
+ messages_used_nano INTEGER DEFAULT 0,
51
+ messages_used_mini INTEGER DEFAULT 0,
52
+ messages_used_fast INTEGER DEFAULT 0,
53
+ messages_used_large INTEGER DEFAULT 0)''')
54
+
55
+ # Upgrade requests table
56
+ c.execute('''CREATE TABLE IF NOT EXISTS upgrade_requests
57
+ (id INTEGER PRIMARY KEY AUTOINCREMENT,
58
+ user_id INTEGER,
59
+ requested_plan TEXT,
60
+ reason TEXT,
61
+ status TEXT DEFAULT 'pending',
62
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
63
+ FOREIGN KEY (user_id) REFERENCES users(id))''')
64
+
65
+ # Usage tracking
66
+ c.execute('''CREATE TABLE IF NOT EXISTS usage_logs
67
+ (id INTEGER PRIMARY KEY AUTOINCREMENT,
68
+ user_id INTEGER,
69
+ tokens_used INTEGER,
70
+ model_used TEXT,
71
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
72
+ FOREIGN KEY (user_id) REFERENCES users(id))''')
73
+
74
+ # Create admin account if not exists
75
+ admin_pass = hashlib.sha256("admin123".encode()).hexdigest()
76
+ try:
77
+ c.execute("INSERT INTO users (username, password_hash, email, plan, is_admin) VALUES (?, ?, ?, ?, ?)",
78
+ ("admin", admin_pass, "[email protected]", "pro", 1))
79
+ conn.commit()
80
+ print("✅ Admin account created (username: admin, password: admin123)")
81
+ except sqlite3.IntegrityError:
82
+ print("✅ Admin account already exists")
83
+
84
+ conn.commit()
85
+ return conn
86
+
87
+ # Global database connection
88
+ db_conn = init_database()
89
+ db_lock = threading.Lock()
90
+
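The single connection above is opened with check_same_thread=False and shared across Gradio's worker threads, with db_lock serialising access; a minimal sketch of that pattern on its own (the table and values here are illustrative, not from app.py):

import sqlite3
import threading

conn = sqlite3.connect(":memory:", check_same_thread=False)   # one connection shared by all threads
lock = threading.Lock()
conn.execute("CREATE TABLE kv (k TEXT PRIMARY KEY, v TEXT)")

def put(key, value):
    with lock:                                                 # mirrors the db_lock usage in app.py
        conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", (key, value))
        conn.commit()

put("plan", "free")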
91
+ # Plan limits with 3-hour rolling window
92
+ PLAN_LIMITS = {
93
+ 'free': {
94
+ 'nano_messages': -1,
95
+ 'mini_messages': -1,
96
+ 'fast_messages': 10,
97
+ 'large_messages': 8,
98
+ 'can_choose_model': False,
99
+ 'max_tokens': 256,
100
+ 'reset_hours': 3
101
+ },
102
+ 'plus': {
103
+ 'nano_messages': -1,
104
+ 'mini_messages': -1,
105
+ 'fast_messages': -1,
106
+ 'large_messages': 20,
107
+ 'can_choose_model': True,
108
+ 'max_tokens': 384,
109
+ 'reset_hours': 3
110
+ },
111
+ 'pro': {
112
+ 'nano_messages': -1,
113
+ 'mini_messages': -1,
114
+ 'fast_messages': -1,
115
+ 'large_messages': -1,
116
+ 'can_choose_model': True,
117
+ 'max_tokens': 512,
118
+ 'reset_hours': 3
119
+ }
120
+ }
121
+
122
+ def get_model_type(model_name):
123
+ """Get model type from model name."""
124
+ if 'Nano' in model_name:
125
+ return 'nano'
126
+ elif 'Mini' in model_name:
127
+ return 'mini'
128
+ elif 'Fast' in model_name:
129
+ return 'fast'
130
+ elif 'Large' in model_name:
131
+ return 'large'
132
+ return 'nano'
133
+
134
  # ==============================================================================
135
+ # User Management Functions
136
+ # ==============================================================================
137
+ def hash_password(password):
138
+ return hashlib.sha256(password.encode()).hexdigest()
139
+
140
+ def create_user(username, password, email=""):
141
+ with db_lock:
142
+ try:
143
+ c = db_conn.cursor()
144
+ now = datetime.now(AUSTRALIA_TZ).isoformat()
145
+ c.execute("INSERT INTO users (username, password_hash, email, rate_limit_start) VALUES (?, ?, ?, ?)",
146
+ (username, hash_password(password), email, now))
147
+ db_conn.commit()
148
+ return True, "Account created successfully!"
149
+ except sqlite3.IntegrityError:
150
+ return False, "Username already exists!"
151
+
152
+ def authenticate_user(username, password):
153
+ with db_lock:
154
+ c = db_conn.cursor()
155
+ c.execute("SELECT id, password_hash, plan, is_admin FROM users WHERE username = ?", (username,))
156
+ result = c.fetchone()
157
+
158
+ if result and result[1] == hash_password(password):
159
+ return True, {"id": result[0], "username": username, "plan": result[2], "is_admin": bool(result[3])}
160
+ return False, None
161
+
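A quick usage sketch of the two account helpers above (the username and password are made up; new accounts default to the 'free' plan per the table schema):

ok, msg = create_user("alice", "s3cret-passphrase")   # stores a SHA-256 hash, never the raw password
print(ok, msg)                                        # -> True "Account created successfully!"

ok, user = authenticate_user("alice", "s3cret-passphrase")
if ok:
    print(user["plan"], user["is_admin"])             # -> free False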
162
+ def check_and_reset_limits(user_id):
163
+ """Check if 3-hour window has passed and reset limits if needed."""
164
+ with db_lock:
165
+ c = db_conn.cursor()
166
+ c.execute("SELECT rate_limit_start, plan FROM users WHERE id = ?", (user_id,))
167
+ result = c.fetchone()
168
+
169
+ if not result:
170
+ return
171
+
172
+ rate_limit_start_str, plan = result
173
+ reset_hours = PLAN_LIMITS[plan]['reset_hours']
174
+
175
+ if rate_limit_start_str:
176
+ rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
177
+ now = datetime.now(AUSTRALIA_TZ)
178
+
179
+ if now - rate_limit_start >= timedelta(hours=reset_hours):
180
+ new_start = now.isoformat()
181
+ c.execute("""UPDATE users
182
+ SET rate_limit_start = ?,
183
+ messages_used_nano = 0,
184
+ messages_used_mini = 0,
185
+ messages_used_fast = 0,
186
+ messages_used_large = 0
187
+ WHERE id = ?""", (new_start, user_id))
188
+ db_conn.commit()
189
+
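The reset above is a fixed-length window anchored at rate_limit_start: once reset_hours have elapsed, all four counters go back to zero and the window restarts at "now". The same comparison in isolation (timestamps are illustrative):

from datetime import datetime, timedelta
import pytz

tz = pytz.timezone("Australia/Sydney")
window_start = tz.localize(datetime(2025, 1, 1, 9, 0))   # window opened at 09:00
now = tz.localize(datetime(2025, 1, 1, 12, 5))           # 3h05m later

if now - window_start >= timedelta(hours=3):
    window_start = now                                   # counters would be reset here
    print("window expired, counters reset at", window_start.isoformat())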
190
+ def get_user_limits_info(user_id):
191
+ """Get user's current usage and limits with reset time."""
192
+ check_and_reset_limits(user_id)
193
+
194
+ with db_lock:
195
+ c = db_conn.cursor()
196
+ c.execute("""SELECT plan, rate_limit_start,
197
+ messages_used_nano, messages_used_mini,
198
+ messages_used_fast, messages_used_large
199
+ FROM users WHERE id = ?""", (user_id,))
200
+ result = c.fetchone()
201
+
202
+ if not result:
203
+ return None
204
+
205
+ plan, rate_limit_start_str, nano_used, mini_used, fast_used, large_used = result
206
+ limits = PLAN_LIMITS[plan]
207
+
208
+ if rate_limit_start_str:
209
+ rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
210
+ reset_time = rate_limit_start + timedelta(hours=limits['reset_hours'])
211
+ now = datetime.now(AUSTRALIA_TZ)
212
+ time_until_reset = reset_time - now
213
+
214
+ hours, remainder = divmod(int(time_until_reset.total_seconds()), 3600)
215
+ minutes, seconds = divmod(remainder, 60)
216
+ reset_str = f"{hours}h {minutes}m"
217
+ else:
218
+ reset_str = "N/A"
219
+
220
+ return {
221
+ 'plan': plan,
222
+ 'nano_used': nano_used,
223
+ 'mini_used': mini_used,
224
+ 'fast_used': fast_used,
225
+ 'large_used': large_used,
226
+ 'nano_limit': limits['nano_messages'],
227
+ 'mini_limit': limits['mini_messages'],
228
+ 'fast_limit': limits['fast_messages'],
229
+ 'large_limit': limits['large_messages'],
230
+ 'can_choose_model': limits['can_choose_model'],
231
+ 'max_tokens': limits['max_tokens'],
232
+ 'reset_in': reset_str
233
+ }
234
+
235
+ def can_use_model(user_id, model_name):
236
+ """Check if user can use a specific model."""
237
+ info = get_user_limits_info(user_id)
238
+ if not info:
239
+ return False, "User not found"
240
+
241
+ model_type = get_model_type(model_name)
242
+ used_key = f"{model_type}_used"
243
+ limit_key = f"{model_type}_limit"
244
+
245
+ used = info[used_key]
246
+ limit = info[limit_key]
247
+
248
+ if limit == -1:
249
+ return True, "OK"
250
+
251
+ if used >= limit:
252
+ return False, f"Limit reached for {model_type.upper()} model ({used}/{limit}). Resets in {info['reset_in']}"
253
+
254
+ return True, "OK"
255
+
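Typical call sequence for the quota helpers, roughly as the chat handler would use them (user id 1 is assumed to be the seeded admin account; the refusal string is just an example of the format returned above):

allowed, reason = can_use_model(1, "SAM-X-1-Large")
if allowed:
    increment_model_usage(1, "SAM-X-1-Large")   # count the message against the 3-hour window
else:
    print(reason)                               # e.g. "Limit reached for LARGE model (8/8). Resets in 2h 41m"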
256
+ def increment_model_usage(user_id, model_name):
257
+ """Increment usage counter for a model."""
258
+ model_type = get_model_type(model_name)
259
+ column = f"messages_used_{model_type}"
260
+
261
+ with db_lock:
262
+ c = db_conn.cursor()
263
+ c.execute(f"UPDATE users SET {column} = {column} + 1 WHERE id = ?", (user_id,))
264
+ db_conn.commit()
265
+
266
+ def get_available_models_for_user(user_id):
267
+ """Get list of models user can currently use."""
268
+ info = get_user_limits_info(user_id)
269
+ if not info:
270
+ return []
271
+
272
+ available = []
273
+
274
+ for model_type in ['nano', 'mini', 'fast', 'large']:
275
+ used = info[f'{model_type}_used']
276
+ limit = info[f'{model_type}_limit']
277
+
278
+ if limit == -1 or used < limit:
279
+ for model_name in available_models.keys():
280
+ if get_model_type(model_name) == model_type:
281
+ available.append(model_name)
282
+ break
283
+
284
+ return available
285
+
286
+ def log_usage(user_id, tokens, model):
287
+ with db_lock:
288
+ c = db_conn.cursor()
289
+ c.execute("INSERT INTO usage_logs (user_id, tokens_used, model_used) VALUES (?, ?, ?)",
290
+ (user_id, tokens, model))
291
+ db_conn.commit()
292
+
293
+ def request_upgrade(user_id, plan, reason):
294
+ with db_lock:
295
+ try:
296
+ c = db_conn.cursor()
297
+ c.execute("INSERT INTO upgrade_requests (user_id, requested_plan, reason) VALUES (?, ?, ?)",
298
+ (user_id, plan, reason))
299
+ db_conn.commit()
300
+ return True, "Upgrade request submitted! Admin will review soon."
301
+ except Exception as e:
302
+ return False, f"Error: {str(e)}"
303
+
304
+ def get_all_users():
305
+ with db_lock:
306
+ c = db_conn.cursor()
307
+ c.execute("""SELECT id, username, email, plan, created_at, is_admin,
308
+ messages_used_nano, messages_used_mini,
309
+ messages_used_fast, messages_used_large,
310
+ rate_limit_start
311
+ FROM users ORDER BY created_at DESC""")
312
+ return c.fetchall()
313
+
314
+ def get_pending_requests():
315
+ with db_lock:
316
+ c = db_conn.cursor()
317
+ c.execute("""SELECT r.id, u.username, r.requested_plan, r.reason, r.created_at
318
+ FROM upgrade_requests r
319
+ JOIN users u ON r.user_id = u.id
320
+ WHERE r.status = 'pending'
321
+ ORDER BY r.created_at DESC""")
322
+ return c.fetchall()
323
+
324
+ def update_user_plan(username, new_plan):
325
+ with db_lock:
326
+ try:
327
+ c = db_conn.cursor()
328
+ now = datetime.now(AUSTRALIA_TZ).isoformat()
329
+ c.execute("""UPDATE users
330
+ SET plan = ?,
331
+ rate_limit_start = ?,
332
+ messages_used_nano = 0,
333
+ messages_used_mini = 0,
334
+ messages_used_fast = 0,
335
+ messages_used_large = 0
336
+ WHERE username = ?""", (new_plan, now, username))
337
+ db_conn.commit()
338
+ return True, f"User {username} upgraded to {new_plan}!"
339
+ except Exception as e:
340
+ return False, f"Error: {str(e)}"
341
+
342
+ def approve_request(request_id):
343
+ with db_lock:
344
+ try:
345
+ c = db_conn.cursor()
346
+ c.execute("SELECT user_id, requested_plan FROM upgrade_requests WHERE id = ?", (request_id,))
347
+ result = c.fetchone()
348
+
349
+ if result:
350
+ user_id, plan = result
351
+ now = datetime.now(AUSTRALIA_TZ).isoformat()
352
+ c.execute("""UPDATE users
353
+ SET plan = ?,
354
+ rate_limit_start = ?,
355
+ messages_used_nano = 0,
356
+ messages_used_mini = 0,
357
+ messages_used_fast = 0,
358
+ messages_used_large = 0
359
+ WHERE id = ?""", (plan, now, user_id))
360
+ c.execute("UPDATE upgrade_requests SET status = 'approved' WHERE id = ?", (request_id,))
361
+ db_conn.commit()
362
+ return True, "Request approved!"
363
+ return False, "Request not found"
364
+ except Exception as e:
365
+ return False, f"Error: {str(e)}"
366
+
367
+ def deny_request(request_id):
368
+ with db_lock:
369
+ try:
370
+ c = db_conn.cursor()
371
+ c.execute("UPDATE upgrade_requests SET status = 'denied' WHERE id = ?", (request_id,))
372
+ db_conn.commit()
373
+ return True, "Request denied"
374
+ except Exception as e:
375
+ return False, f"Error: {str(e)}"
376
+
377
+ # ==============================================================================
378
+ # Model Architecture
379
  # ==============================================================================
380
  @keras.saving.register_keras_serializable()
381
  class RotaryEmbedding(keras.layers.Layer):
 
392
  t = tf.range(self.max_len, dtype=tf.float32)
393
  freqs = tf.einsum("i,j->ij", t, inv_freq)
394
  emb = tf.concat([freqs, freqs], axis=-1)
 
395
  self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
396
  self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
397
  self.built_cache = True
 
406
  dtype = q.dtype
407
  cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
408
  sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
 
409
  q_rotated = (q * cos) + (self.rotate_half(q) * sin)
410
  k_rotated = (k * cos) + (self.rotate_half(k) * sin)
 
411
  return q_rotated, k_rotated
412
 
413
  def get_config(self):
 
415
  config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
416
  return config
417
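For intuition on the cached tables above: cos_cached and sin_cached come from the standard RoPE angle construction. A NumPy sketch of the same einsum, assuming the usual inverse-frequency definition (the inv_freq line is not shown in this hunk, and dim/max_len below are illustrative):

import numpy as np

dim, max_len, theta = 8, 16, 10000.0
inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))  # one frequency per channel pair
t = np.arange(max_len, dtype=np.float32)
freqs = np.einsum("i,j->ij", t, inv_freq)            # (max_len, dim/2) position-by-frequency angles
emb = np.concatenate([freqs, freqs], axis=-1)        # duplicated to match the head dimension
cos_cached, sin_cached = np.cos(emb), np.sin(emb)    # shapes: (max_len, dim)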
 
 
418
  @keras.saving.register_keras_serializable()
419
  class RMSNorm(keras.layers.Layer):
420
  def __init__(self, epsilon=1e-5, **kwargs):
 
433
  config.update({"epsilon": self.epsilon})
434
  return config
435
 
 
436
  @keras.saving.register_keras_serializable()
437
  class TransformerBlock(keras.layers.Layer):
438
  def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
 
445
  self.rope_theta = rope_theta
446
  self.head_dim = d_model // n_heads
447
  self.layer_idx = layer_idx
 
448
  self.pre_attn_norm = RMSNorm()
449
  self.pre_ffn_norm = RMSNorm()
 
450
  self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
451
  self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
452
  self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
453
  self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
 
454
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
 
455
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
456
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
457
  self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
 
458
  self.dropout = keras.layers.Dropout(dropout)
459
 
460
  def call(self, x, training=None):
461
  B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
462
  dtype = x.dtype
 
463
  res = x
464
  y = self.pre_attn_norm(x)
 
465
  q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
466
  k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
467
  v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
 
468
  q, k = self.rope(q, k)
 
469
  scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
470
+ mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))

471
  scores += mask
472
  attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
 
473
  attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
474
  x = res + self.dropout(self.out_proj(attn), training=training)
 
475
  res = x
476
  y = self.pre_ffn_norm(x)
477
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
 
478
  return res + self.dropout(ffn, training=training)
479
 
480
  def get_config(self):
481
  config = super().get_config()
482
+ # PART 2 - Continue from Part 1
+ config.update({"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len, "rope_theta": self.rope_theta, "layer_idx": self.layer_idx})
+ return config
489
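The causal mask built in call() above uses tf.linalg.band_part to keep the lower triangle and pushes everything else to a large negative value before the softmax. A standalone sketch for a length-4 sequence (not part of the commit):

import tensorflow as tf

T = 4
lower = tf.linalg.band_part(tf.ones([T, T]), -1, 0)   # ones on and below the diagonal
mask = tf.where(lower == 0, -1e9, 0.0)                # -1e9 where a position must not attend
print(mask.numpy())                                   # row i attends only to columns <= i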
 
490
  @keras.saving.register_keras_serializable()
491
  class SAM1Model(keras.Model):
 
497
  self.cfg = kwargs
498
  else:
499
  self.cfg = kwargs.get('cfg', kwargs)
 
500
  self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
 
501
  ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
502
+ block_args = {'d_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']}

503
  self.blocks = []
504
  for i in range(self.cfg['n_layers']):
505
  block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
506
  self.blocks.append(block)
 
507
  self.norm = RMSNorm(name="final_norm")
508
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
509
 
510
  def call(self, input_ids, training=None):
511
  x = self.embed(input_ids)
 
512
  for block in self.blocks:
513
  x = block(x, training=training)
 
514
  return self.lm_head(self.norm(x))
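  # Editor's note: the forward pass is embed -> n_layers stacked TransformerBlocks ->
  # final RMSNorm -> lm_head, returning raw logits of shape [batch, seq_len, vocab_size].
  # No softmax is applied here; the sampling loop further below works on logits directly.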
515
 
516
  def get_config(self):
 
518
  base_config['config'] = self.cfg
519
  return base_config
520
 
 
 
 
 
521
  def count_parameters(model):
 
522
  total_params = 0
523
  non_zero_params = 0
 
524
  for weight in model.weights:
525
  w = weight.numpy()
526
  total_params += w.size
527
  non_zero_params += np.count_nonzero(w)
 
528
  return total_params, non_zero_params
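  # Editor's note: counts are taken from the NumPy view of every weight tensor, so
  # non_zero_params reflects exactly-zero (pruned) weights after loading; the sparsity
  # shown in the UI below is 1 - non_zero/total, e.g. 80M non-zero of 100M total gives 20%.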
529
 
 
530
  def format_param_count(count):
 
531
  if count >= 1e9:
532
  return f"{count/1e9:.2f}B"
533
  elif count >= 1e6:
 
537
  else:
538
  return str(count)
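  # Editor's note (illustrative, assuming the elided branches mirror the 1e9 case with
  # 1e6/1e3 divisors): format_param_count(1_250_000_000) -> "1.25B",
  # format_param_count(3_500_000) -> "3.50M".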
539
 
 
 
 
 
540
  class ModelBackend(ABC):
541
  @abstractmethod
542
  def predict(self, input_ids):
543
  pass
 
544
  @abstractmethod
545
  def get_name(self):
546
  pass
 
547
  @abstractmethod
548
  def get_info(self):
549
  pass
550
 
 
551
  class KerasBackend(ModelBackend):
552
  def __init__(self, model, name, display_name):
553
  self.model = model
554
  self.name = name
555
  self.display_name = display_name
556
+ @tf.function(input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)], jit_compile=True)
557
  def fast_predict(inputs):
558
  return model(inputs, training=False)
 
559
  self.fast_predict = fast_predict
 
 
560
  print(f" 🔥 Warming up {display_name}...")
561
  dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
562
  _ = self.fast_predict(dummy)
563
  print(f" ✅ Compilation complete!")
 
 
564
  total, non_zero = count_parameters(model)
565
  self.total_params = total
566
  self.non_zero_params = non_zero
567
  self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
 
 
568
  self.n_heads = model.cfg.get('n_heads', 0)
569
  self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
570
 
 
585
  info += f" Sparsity: {self.sparsity:.1f}%\n"
586
  return info
587
 
 
 
 
 
588
  MODEL_REGISTRY = [
589
  ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
590
  ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast.weights.h5", "sam1_fast_config.json"),
 
592
  ("SAM-X-1-Nano ⚡⚡", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano_finetuned.weights.h5", "sam1_nano_finetuned_config.json"),
593
  ]
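  # Editor's note: each registry entry is (display_name, hf_repo_id, weights_filename,
  # config_filename-or-None); a None config file means the model is built from the shared
  # base_model_config defined further below.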
594
 
595
  def estimate_prompt_complexity(prompt):
 
596
  prompt_lower = prompt.lower()
 
 
597
  complexity_score = 0
 
 
598
  word_count = len(prompt.split())
599
  if word_count > 100:
600
  complexity_score += 3
 
602
  complexity_score += 2
603
  elif word_count > 20:
604
  complexity_score += 1
605
+ hard_keywords = ['analyze', 'explain', 'compare', 'evaluate', 'prove', 'derive', 'calculate', 'solve', 'reason', 'why', 'how does', 'complex', 'algorithm', 'mathematics', 'philosophy', 'theory', 'logic', 'detailed', 'comprehensive', 'thorough', 'in-depth']
 
606
  for keyword in hard_keywords:
607
  if keyword in prompt_lower:
608
  complexity_score += 2
609
+ medium_keywords = ['write', 'create', 'generate', 'summarize', 'describe', 'list', 'what is', 'tell me', 'explain briefly']
 
610
  for keyword in medium_keywords:
611
  if keyword in prompt_lower:
612
  complexity_score += 1
 
 
613
  if any(word in prompt_lower for word in ['code', 'function', 'program', 'debug', 'implement']):
614
  complexity_score += 2
 
 
615
  if any(word in prompt_lower for word in ['first', 'then', 'next', 'finally', 'step']):
616
  complexity_score += 1
 
 
617
  question_marks = prompt.count('?')
618
  if question_marks > 1:
619
  complexity_score += 1
 
620
  return complexity_score
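  # Editor's note, a rough worked example of the scoring above: for
  # "Explain why this sorting algorithm is O(n log n)" the keyword hits are 'explain' (+2),
  # 'why' (+2) and 'algorithm' (+2), with no length, code, or multi-question bonus,
  # giving a score of 6, well above the <=2 fast path that picks the Nano model below.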
621
 
622
+ def select_model_auto(prompt, available_models_dict, user_available_models):
 
623
  complexity = estimate_prompt_complexity(prompt)
624
+ accessible = {k: v for k, v in available_models_dict.items() if k in user_available_models}
625
+ if not accessible:
626
+ return None
 
 
 
 
627
  if complexity <= 2:
628
  preferred = "SAM-X-1-Nano ⚡⚡"
629
  fallback_order = ["SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Large"]
 
636
  else:
637
  preferred = "SAM-X-1-Large"
638
  fallback_order = ["SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Nano ⚡⚡"]
639
+ if preferred in accessible:
640
+ return accessible[preferred]
 
 
 
 
 
641
  for model_name in fallback_order:
642
+ if model_name in accessible:
643
+ return accessible[model_name]
644
+ return list(accessible.values())[0]
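  # Editor's note: selection degrades gracefully: the preferred model if the user has
  # access to it, otherwise the first accessible entry in fallback_order, otherwise any
  # remaining accessible model; None is returned only when nothing is accessible at all.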
 
 
 
645
 
 
 
 
646
  CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
 
647
  print("="*80)
648
  print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
649
  print("="*80)
 
 
650
  print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
651
  config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
652
  tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
 
 
653
  with open(config_path, 'r') as f:
654
  base_config = json.load(f)
 
655
  print(f"✅ Base config loaded")
656
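  # Editor's note: the next line maps HF-style config keys onto SAM1Model kwargs; ff_mult
  # is recovered as intermediate_size / hidden_size so ff_dim can later be rebuilt as
  # int(d_model * ff_mult), e.g. hidden_size=768 with intermediate_size=3072 gives ff_mult=4.0.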
+ base_model_config = {'vocab_size': base_config['vocab_size'], 'd_model': base_config['hidden_size'], 'n_heads': base_config['num_attention_heads'], 'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'], 'dropout': base_config.get('dropout', 0.0), 'max_len': base_config['max_position_embeddings'], 'rope_theta': base_config['rope_theta'], 'n_layers': base_config['num_hidden_layers']}
 
657
  print("\n🔤 Recreating tokenizer...")
658
  tokenizer = Tokenizer.from_pretrained("gpt2")
 
 
659
  eos_token = "<|endoftext|>"
660
  eos_token_id = tokenizer.token_to_id(eos_token)
 
661
  if eos_token_id is None:
 
662
  tokenizer.add_special_tokens([eos_token])
663
  eos_token_id = tokenizer.token_to_id(eos_token)
 
 
664
  custom_tokens = ["<think>", "<think/>"]
665
  for token in custom_tokens:
666
  if tokenizer.token_to_id(token) is None:
667
  tokenizer.add_special_tokens([token])
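  # Editor's note: "<think>" and "<think/>" are registered as special tokens so the
  # model's reasoning markers survive tokenization as single IDs; the streaming loop
  # below watches decoded chunks for them to flag "thinking" spans in the UI.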
 
668
  tokenizer.no_padding()
669
  tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
 
670
  print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
671
  print(f" EOS token: '{eos_token}' (ID: {eos_token_id})")
 
 
672
  if eos_token_id is None:
673
+ raise ValueError("❌ Failed to set EOS token ID!")
 
 
674
  print("\n" + "="*80)
675
  print("📦 LOADING MODELS".center(80))
676
  print("="*80)
 
677
  available_models = {}
678
  dummy_input = tf.zeros((1, 1), dtype=tf.int32)
 
679
  for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
680
  try:
681
  print(f"\n⏳ Loading: {display_name}")
682
  print(f" Repo: {repo_id}")
683
  print(f" Weights: {weights_filename}")
 
 
684
  weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
 
 
685
  if config_filename:
686
  print(f" Config: {config_filename}")
687
  custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
688
  with open(custom_config_path, 'r') as f:
689
  model_config = json.load(f)
690
+ print(f" 📐 Custom architecture: {model_config['n_heads']} heads")
691
  else:
692
  model_config = base_model_config.copy()
 
 
693
  model = SAM1Model(**model_config)
694
  model(dummy_input)
695
  model.load_weights(weights_path)
696
  model.trainable = False
 
 
697
  backend = KerasBackend(model, display_name, display_name)
698
  available_models[display_name] = backend
 
 
699
  print(f" ✅ Loaded successfully!")
700
  print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
 
 
 
701
  except Exception as e:
702
  print(f" ⚠️ Failed to load: {e}")
 
 
703
  if not available_models:
704
+ raise RuntimeError("❌ No models loaded!")
 
705
  print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
 
 
706
  current_backend = list(available_models.values())[0]
 
 
707
  stop_generation = threading.Event()
708
 
 
 
 
 
709
  def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=256):
 
710
  global stop_generation
711
  stop_generation.clear()
 
712
  if backend is None:
713
  backend = current_backend
 
 
714
  encoded_prompt = tokenizer.encode(prompt)
715
  input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
716
  generated = input_ids.copy()
 
717
  current_text = ""
718
  in_thinking = False
 
 
719
  max_len = backend.model.cfg['max_len']
 
 
720
  start_time = time.time()
721
  tokens_generated = 0
 
 
722
  decode_buffer = []
723
+ decode_every = 2
724
  last_speed_check = start_time
 
 
725
  for step in range(max_tokens):
 
726
  if stop_generation.is_set():
 
 
727
  elapsed = time.time() - start_time
728
  final_speed = tokens_generated / elapsed if elapsed > 0 else 0
729
+ yield "", False, -1, final_speed, True
730
  return
 
731
  current_input = generated[-max_len:]
 
 
732
  next_token_logits = backend.predict(current_input)
 
 
 
733
  if tokens_generated > 5 and tokens_generated % 10 == 0:
734
  current_time = time.time()
735
  elapsed_since_check = current_time - last_speed_check
736
  if elapsed_since_check > 0:
737
  recent_speed = 10 / elapsed_since_check
 
738
  if recent_speed > 25:
739
+ decode_every = 8
740
  elif recent_speed > 15:
741
+ decode_every = 5
742
  elif recent_speed > 8:
743
+ decode_every = 3
744
  else:
745
+ decode_every = 2
746
  last_speed_check = current_time
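  # Editor's note: decode_every adapts to the measured generation speed; the faster the
  # model streams, the more tokens are batched per tokenizer.decode() call, trading a
  # slightly chunkier UI update for less per-token CPU overhead.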
 
747
  if temperature > 0:
748
  next_token_logits = next_token_logits / temperature
749
  top_k = 5
 
755
  next_token = top_k_indices[np.random.choice(top_k, p=probs)]
756
  else:
757
  next_token = np.argmax(next_token_logits)
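  # Editor's note: with temperature > 0 this branch does temperature-scaled top-k sampling
  # (k=5); temperature == 0 falls back to greedy argmax. The elided lines presumably do
  # something like the following (illustrative sketch only, not the committed code):
  #   top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
  #   top_k_logits = next_token_logits[top_k_indices]
  #   probs = np.exp(top_k_logits - top_k_logits.max()); probs /= probs.sum()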
 
 
758
  if next_token == eos_token_id:
 
759
  break
 
760
  generated.append(int(next_token))
761
  decode_buffer.append(int(next_token))
762
  tokens_generated += 1
763
+ should_decode = (len(decode_buffer) >= decode_every or step == max_tokens - 1)
 
 
 
 
764
  if should_decode:
765
  new_text = tokenizer.decode(generated[len(input_ids):])
766
  if len(new_text) > len(current_text):
767
  new_chunk = new_text[len(current_text):]
768
  current_text = new_text
 
769
  if "<think>" in new_chunk:
770
  in_thinking = True
771
  elif "</think>" in new_chunk or "<think/>" in new_chunk:
772
  in_thinking = False
 
 
773
  elapsed = time.time() - start_time
774
  tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
 
775
  yield new_chunk, in_thinking, tokens_per_sec, tokens_per_sec, False
776
  decode_buffer = []
 
 
777
  elapsed = time.time() - start_time
778
  final_tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
779
  yield "", False, final_tokens_per_sec, final_tokens_per_sec, False
780