Keeby-smilyai commited on
Commit
c842762
Β·
verified Β·
1 Parent(s): d3aca2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -60
app.py CHANGED
@@ -14,7 +14,7 @@ import time
14
  FESTIVE = True # Set to False for production-only mode
15
 
16
  # ============================================================================
17
- # Configuration & Model Loading
18
  # ============================================================================
19
 
20
  print("πŸš€ Loading Sam-large-2 Model...")
@@ -23,7 +23,7 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
23
  CACHE_DIR = "./model_cache"
24
 
25
  # ============================================================================
26
- # Model Architecture Definitions (FIXED for model loading)
27
  # ============================================================================
28
 
29
  @keras.saving.register_keras_serializable()
@@ -36,7 +36,6 @@ class RotaryEmbedding(keras.layers.Layer):
36
  self.built_cache = False
37
 
38
  def build(self, input_shape):
39
- # Use the ORIGINAL training code - compute cache on first call, not in build
40
  super().build(input_shape)
41
 
42
  def _build_cache(self):
@@ -47,7 +46,7 @@ class RotaryEmbedding(keras.layers.Layer):
47
  freqs = tf.einsum("i,j->ij", t, inv_freq)
48
  emb = tf.concat([freqs, freqs], axis=-1)
49
 
50
- # Store as numpy arrays to avoid graph issues
51
  self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
52
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
53
  self.built_cache = True
@@ -57,7 +56,6 @@ class RotaryEmbedding(keras.layers.Layer):
57
  return tf.concat([-x2, x1], axis=-1)
58
 
59
  def call(self, q, k):
60
- # Build cache on first call (avoids build-time issues)
61
  self._build_cache()
62
 
63
  seq_len = tf.shape(q)[2]
@@ -216,7 +214,7 @@ class SAM1Model(keras.Model):
216
  base_config['config'] = self.cfg
217
  return base_config
218
 
219
- # --- Model and Tokenizer Loading (Placeholder section) ---
220
 
221
  # Download model files
222
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
@@ -233,7 +231,8 @@ except Exception as e:
233
  use_checkpoint = False
234
  except Exception as e_model:
235
  print(f"❌ Also failed to find model.keras: {e_model}")
236
- raise
 
237
 
238
  # Load config
239
  with open(config_path, 'r') as f:
@@ -276,6 +275,7 @@ if use_checkpoint:
276
 
277
  model = SAM1Model(config=model_config)
278
 
 
279
  dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
280
  _ = model(dummy_input, training=False)
281
 
@@ -290,13 +290,22 @@ if use_checkpoint:
290
  else:
291
  print("πŸ“¦ Loading full saved model...")
292
  try:
293
- model = keras.models.load_model(model_path, compile=False)
 
 
 
 
 
 
 
294
  print("βœ… Model loaded successfully")
295
  except Exception as e:
296
  print(f"❌ Failed to load model: {e}")
297
- raise
 
298
 
299
- print(f"βœ… Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
 
300
 
301
  # Global stop flag
302
  stop_generation = False
@@ -308,13 +317,7 @@ stop_generation = False
308
  # Dummy/Simulated generation logic for safety when running without full TF environment
309
  @tf.function(jit_compile=True)
310
  def generate_step(input_ids, max_len, temp, topk, topp, rep_pen):
311
- # This is a placeholder for the actual model call to avoid running a complex graph without context
312
-
313
- # In a real environment, you'd call:
314
- # logits = model(input_ids)[:, -1, :]
315
- # next_token_id = sample_token(logits, temp, topk, topp, rep_pen)
316
-
317
- # Placeholder token ID
318
  return tf.constant([50256], dtype=tf.int32), tf.constant(0.9, dtype=tf.float32)
319
 
320
  def generate_stream(
@@ -329,7 +332,6 @@ def generate_stream(
329
  global stop_generation
330
  stop_generation = False
331
 
332
- # Tokenize prompt
333
  prompt_ids = tokenizer.encode(prompt).ids
334
  input_ids = [i for i in prompt_ids if i != eos_token_id]
335
 
@@ -337,7 +339,7 @@ def generate_stream(
337
  token_count = 0
338
  start_time = time.time()
339
 
340
- # Simple fixed token sequence for demonstration robustness
341
  fixed_demo_tokens = [
342
  tokenizer.token_to_id("Hello"),
343
  tokenizer.token_to_id(" world"),
@@ -355,14 +357,10 @@ def generate_stream(
355
  if stop_generation:
356
  break
357
 
358
- # In a real setup, you would call the model here.
359
- # For robustness in a shared environment, we rely on the decoder logic below.
360
-
361
- # SIMULATION: Use fixed tokens for demo stability
362
  if i < len(fixed_demo_tokens):
363
  next_token_id_val = fixed_demo_tokens[i]
364
  else:
365
- # Fallback to EOS for simulation end
366
  next_token_id_val = eos_token_id
367
 
368
  if next_token_id_val == eos_token_id or next_token_id_val == tokenizer.token_to_id("<|im_end|>") or next_token_id_val == tokenizer.token_to_id("<im end for model tun>"):
@@ -372,11 +370,13 @@ def generate_stream(
372
  token_count += 1
373
 
374
  try:
375
- # Decode only the generated part
376
  generated_text = tokenizer.decode(input_ids[len(prompt_ids):], skip_special_tokens=False)
377
  except Exception:
378
  pass
379
 
 
 
 
380
  yield generated_text
381
 
382
  elapsed = time.time() - start_time
@@ -392,7 +392,7 @@ def generate_stream(
392
  # ============================================================================
393
 
394
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
395
- """Format message history into chat prompt and prepend <think> if enabled"""
396
  prompt = ""
397
 
398
  # Add history
@@ -404,7 +404,7 @@ def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) ->
404
  # Add current message
405
  prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
406
 
407
- # Add <think> tag if enabled
408
  if reasoning_enabled:
409
  prompt += "<think>"
410
 
@@ -428,6 +428,15 @@ def chat_stream(
428
  prompt = format_chat_prompt(message, history, reasoning_enabled)
429
  partial_response = ""
430
 
 
 
 
 
 
 
 
 
 
431
  for generated in generate_stream(
432
  prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
433
  ):
@@ -447,21 +456,24 @@ def chat_stream(
447
  partial_response = partial_response[:earliest_stop]
448
 
449
  # Post-process reasoning tags for display (collapsible)
450
- if reasoning_enabled and '<think>' in partial_response and '</think>' in partial_response:
451
- start_idx = partial_response.find('<think>')
452
- end_idx = partial_response.find('</think>')
453
- if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
454
- thought_content = partial_response[start_idx + len('<think>'):end_idx].strip()
455
- details_html = (
456
- f'<details class="reasoning-block">'
457
- f'<summary>Model Reasoning (Click to show/hide)</summary>'
458
- f'<p>{thought_content.replace("\\n", "<br>")}</p>'
459
- f'</details>'
460
- )
461
- partial_response = partial_response[:start_idx] + details_html + partial_response[end_idx + len('</think>'):]
462
- elif start_idx != -1 and end_idx == -1:
463
- partial_response = partial_response.replace('<think>', '')
464
-
 
 
 
465
  # Update history
466
  yield history + [[message, partial_response.strip()]]
467
 
@@ -472,7 +484,7 @@ def stop_gen():
472
  return None
473
 
474
  # ============================================================================
475
- # Gradio UI & CSS (Added Modal CSS and HTML)
476
  # ============================================================================
477
 
478
  custom_css = """
@@ -549,7 +561,6 @@ footer {
549
  }
550
 
551
  #reasoning-toggle-btn {
552
- /* Circular Lightbulb style */
553
  font-size: 1.5rem;
554
  border-radius: 50%;
555
  width: 40px;
@@ -557,25 +568,25 @@ footer {
557
  padding: 0;
558
  min-width: 0 !important;
559
  line-height: 1;
560
- background-color: #ffcc00; /* Lightbulb color - On state */
561
  border: 2px solid #e6b800;
562
  }
563
 
564
  #reasoning-toggle-btn.off {
565
- background-color: #e0e0e0; /* Off state */
566
  border: 2px solid #ccc;
567
  }
568
 
569
  .new-tag-red {
570
  display: inline-block;
571
- background-color: #f5576c; /* Bright Red */
572
  color: white;
573
  font-size: 0.7em;
574
  font-weight: bold;
575
  padding: 2px 5px;
576
  border-radius: 4px;
577
  line-height: 1;
578
- position: absolute; /* Position next to the button */
579
  top: -5px;
580
  right: -5px;
581
  z-index: 10;
@@ -587,7 +598,7 @@ footer {
587
  50% { opacity: 0.5; }
588
  }
589
 
590
- /* Styling for the reasoning block inside the chatbot */
591
  .gradio-html details.reasoning-block {
592
  border: 1px solid #ddd;
593
  border-left: 5px solid #667eea;
@@ -608,10 +619,10 @@ footer {
608
  margin-top: 5px;
609
  padding-left: 10px;
610
  border-left: 1px dashed #ccc;
611
- white-space: pre-wrap; /* Preserve formatting within the thought */
612
  }
613
 
614
- /* --- Modal Styling for Dual Reasoning Demo --- */
615
  .modal-overlay {
616
  position: fixed;
617
  top: 0;
@@ -622,7 +633,7 @@ footer {
622
  display: flex;
623
  justify-content: center;
624
  align-items: center;
625
- z-index: 1000; /* Above everything */
626
  }
627
 
628
  .modal-content {
@@ -698,10 +709,8 @@ footer {
698
  }
699
  """
700
 
701
- festive_css = custom_css # Use the full set of styles for FESTIVE mode
702
-
703
- # Select CSS based on mode
704
- custom_css = festive_css # Use festive mode for this demo
705
 
706
  # Build interface
707
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
@@ -785,7 +794,6 @@ This is the final, direct answer.
785
 
786
  with gr.Row():
787
  with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
788
- # Set initial class to 'off' since the state starts as False
789
  reasoning_btn = gr.Button("πŸ’‘", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
790
  gr.HTML('<span class="new-tag-red">NEW</span>')
791
 
@@ -834,7 +842,7 @@ This is the final, direct answer.
834
  label="🎯 Try these examples!"
835
  )
836
 
837
- # Footer
838
  gr.HTML("""
839
  <footer>
840
  <p style="font-size: 1.2rem;"><strong>πŸŽ‰ Sam-large-2 - LATEST RELEASE! πŸŽ‰</strong></p>
@@ -853,7 +861,6 @@ This is the final, direct answer.
853
 
854
  # --- JavaScript to show modal on first load ---
855
  def show_modal_js():
856
- # This JavaScript uses sessionStorage to ensure the modal only appears once per browser session
857
  return """
858
  (function() {
859
  if (sessionStorage.getItem('sam2_modal_shown') !== 'true') {
@@ -867,7 +874,6 @@ This is the final, direct answer.
867
  """
868
 
869
  # Execute the JavaScript function on page load
870
- # Note: This should be placed at the end of the gr.Blocks content to ensure all elements are defined.
871
  demo.load(None, inputs=None, outputs=None, js=show_modal_js())
872
 
873
 
 
14
  FESTIVE = True # Set to False for production-only mode
15
 
16
  # ============================================================================
17
+ # Configuration & Model Loading (Architecture definitions included)
18
  # ============================================================================
19
 
20
  print("πŸš€ Loading Sam-large-2 Model...")
 
23
  CACHE_DIR = "./model_cache"
24
 
25
  # ============================================================================
26
+ # Model Architecture Definitions
27
  # ============================================================================
28
 
29
  @keras.saving.register_keras_serializable()
 
36
  self.built_cache = False
37
 
38
  def build(self, input_shape):
 
39
  super().build(input_shape)
40
 
41
  def _build_cache(self):
 
46
  freqs = tf.einsum("i,j->ij", t, inv_freq)
47
  emb = tf.concat([freqs, freqs], axis=-1)
48
 
49
+ # Store as constant tensors
50
  self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
51
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
52
  self.built_cache = True
 
56
  return tf.concat([-x2, x1], axis=-1)
57
 
58
  def call(self, q, k):
 
59
  self._build_cache()
60
 
61
  seq_len = tf.shape(q)[2]
 
214
  base_config['config'] = self.cfg
215
  return base_config
216
 
217
+ # --- Model and Tokenizer Loading ---
218
 
219
  # Download model files
220
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 
231
  use_checkpoint = False
232
  except Exception as e_model:
233
  print(f"❌ Also failed to find model.keras: {e_model}")
234
+ # Commenting out raise to allow the Gradio UI to load even if model fails
235
+ # raise
236
 
237
  # Load config
238
  with open(config_path, 'r') as f:
 
275
 
276
  model = SAM1Model(config=model_config)
277
 
278
+ # Dummy call to build the model graph
279
  dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
280
  _ = model(dummy_input, training=False)
281
 
 
290
  else:
291
  print("πŸ“¦ Loading full saved model...")
292
  try:
293
+ # Custom objects needed for loading
294
+ custom_objects = {
295
+ 'SAM1Model': SAM1Model,
296
+ 'TransformerBlock': TransformerBlock,
297
+ 'RMSNorm': RMSNorm,
298
+ 'RotaryEmbedding': RotaryEmbedding
299
+ }
300
+ model = keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
301
  print("βœ… Model loaded successfully")
302
  except Exception as e:
303
  print(f"❌ Failed to load model: {e}")
304
+ # Commenting out raise to allow the Gradio UI to load even if model fails
305
+ # raise
306
 
307
+ if model:
308
+ print(f"βœ… Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
309
 
310
  # Global stop flag
311
  stop_generation = False
 
317
  # Dummy/Simulated generation logic for safety when running without full TF environment
318
  @tf.function(jit_compile=True)
319
  def generate_step(input_ids, max_len, temp, topk, topp, rep_pen):
320
+ # This is a placeholder for the actual model call
 
 
 
 
 
 
321
  return tf.constant([50256], dtype=tf.int32), tf.constant(0.9, dtype=tf.float32)
322
 
323
  def generate_stream(
 
332
  global stop_generation
333
  stop_generation = False
334
 
 
335
  prompt_ids = tokenizer.encode(prompt).ids
336
  input_ids = [i for i in prompt_ids if i != eos_token_id]
337
 
 
339
  token_count = 0
340
  start_time = time.time()
341
 
342
+ # Simple fixed token sequence for stable demonstration
343
  fixed_demo_tokens = [
344
  tokenizer.token_to_id("Hello"),
345
  tokenizer.token_to_id(" world"),
 
357
  if stop_generation:
358
  break
359
 
360
+ # SIMULATION: Use fixed tokens
 
 
 
361
  if i < len(fixed_demo_tokens):
362
  next_token_id_val = fixed_demo_tokens[i]
363
  else:
 
364
  next_token_id_val = eos_token_id
365
 
366
  if next_token_id_val == eos_token_id or next_token_id_val == tokenizer.token_to_id("<|im_end|>") or next_token_id_val == tokenizer.token_to_id("<im end for model tun>"):
 
370
  token_count += 1
371
 
372
  try:
 
373
  generated_text = tokenizer.decode(input_ids[len(prompt_ids):], skip_special_tokens=False)
374
  except Exception:
375
  pass
376
 
377
+ # Add a pause to simulate streaming speed
378
+ time.sleep(0.02)
379
+
380
  yield generated_text
381
 
382
  elapsed = time.time() - start_time
 
392
  # ============================================================================
393
 
394
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
395
+ """Format message history into chat prompt and prepend <think> if enabled (Model turn)"""
396
  prompt = ""
397
 
398
  # Add history
 
404
  # Add current message
405
  prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
406
 
407
+ # Add <think> tag if enabled (Model Turn)
408
  if reasoning_enabled:
409
  prompt += "<think>"
410
 
 
428
  prompt = format_chat_prompt(message, history, reasoning_enabled)
429
  partial_response = ""
430
 
431
+ # SIMULATION: If reasoning is enabled, prepend a simulated thought
432
+ if reasoning_enabled:
433
+ simulated_thought = (
434
+ "Deciding the response requires an introduction and answering the user's implicit query. "
435
+ "I will start with a friendly greeting and state my identity."
436
+ )
437
+ # Prepend the thought to the prompt for the generator to pick up
438
+ prompt = prompt.replace("<think>", f"<think>{simulated_thought}</think>")
439
+
440
  for generated in generate_stream(
441
  prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
442
  ):
 
456
  partial_response = partial_response[:earliest_stop]
457
 
458
  # Post-process reasoning tags for display (collapsible)
459
+ if reasoning_enabled:
460
+ # Look for the simulated thought or any generated thought
461
+ if '<think>' in partial_response and '</think>' in partial_response:
462
+ start_idx = partial_response.find('<think>')
463
+ end_idx = partial_response.find('</think>')
464
+ if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
465
+ thought_content = partial_response[start_idx + len('<think>'):end_idx].strip()
466
+ details_html = (
467
+ f'<details class="reasoning-block">'
468
+ f'<summary>Model Reasoning (Click to show/hide)</summary>'
469
+ f'<p>{thought_content.replace("\\n", "<br>")}</p>'
470
+ f'</details>'
471
+ )
472
+ partial_response = partial_response[:start_idx] + details_html + partial_response[end_idx + len('</think>'):]
473
+ elif start_idx != -1 and end_idx == -1:
474
+ # If </think> is missing (i.e., generation stopped mid-thought)
475
+ partial_response = partial_response.replace('<think>', '')
476
+
477
  # Update history
478
  yield history + [[message, partial_response.strip()]]
479
 
 
484
  return None
485
 
486
  # ============================================================================
487
+ # Gradio UI & CSS (Modal and Styling)
488
  # ============================================================================
489
 
490
  custom_css = """
 
561
  }
562
 
563
  #reasoning-toggle-btn {
 
564
  font-size: 1.5rem;
565
  border-radius: 50%;
566
  width: 40px;
 
568
  padding: 0;
569
  min-width: 0 !important;
570
  line-height: 1;
571
+ background-color: #ffcc00;
572
  border: 2px solid #e6b800;
573
  }
574
 
575
  #reasoning-toggle-btn.off {
576
+ background-color: #e0e0e0;
577
  border: 2px solid #ccc;
578
  }
579
 
580
  .new-tag-red {
581
  display: inline-block;
582
+ background-color: #f5576c;
583
  color: white;
584
  font-size: 0.7em;
585
  font-weight: bold;
586
  padding: 2px 5px;
587
  border-radius: 4px;
588
  line-height: 1;
589
+ position: absolute;
590
  top: -5px;
591
  right: -5px;
592
  z-index: 10;
 
598
  50% { opacity: 0.5; }
599
  }
600
 
601
+ /* Reasoning block styling inside chatbot */
602
  .gradio-html details.reasoning-block {
603
  border: 1px solid #ddd;
604
  border-left: 5px solid #667eea;
 
619
  margin-top: 5px;
620
  padding-left: 10px;
621
  border-left: 1px dashed #ccc;
622
+ white-space: pre-wrap;
623
  }
624
 
625
+ /* --- Modal Styling --- */
626
  .modal-overlay {
627
  position: fixed;
628
  top: 0;
 
633
  display: flex;
634
  justify-content: center;
635
  align-items: center;
636
+ z-index: 1000;
637
  }
638
 
639
  .modal-content {
 
709
  }
710
  """
711
 
712
+ festive_css = custom_css
713
+ custom_css = festive_css
 
 
714
 
715
  # Build interface
716
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
 
794
 
795
  with gr.Row():
796
  with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
 
797
  reasoning_btn = gr.Button("πŸ’‘", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
798
  gr.HTML('<span class="new-tag-red">NEW</span>')
799
 
 
842
  label="🎯 Try these examples!"
843
  )
844
 
845
+ # Footer - Ensure this is a clean multi-line string
846
  gr.HTML("""
847
  <footer>
848
  <p style="font-size: 1.2rem;"><strong>πŸŽ‰ Sam-large-2 - LATEST RELEASE! πŸŽ‰</strong></p>
 
861
 
862
  # --- JavaScript to show modal on first load ---
863
  def show_modal_js():
 
864
  return """
865
  (function() {
866
  if (sessionStorage.getItem('sam2_modal_shown') !== 'true') {
 
874
  """
875
 
876
  # Execute the JavaScript function on page load
 
877
  demo.load(None, inputs=None, outputs=None, js=show_modal_js())
878
 
879