xiezhe22 commited on
Commit
9166423
·
1 Parent(s): 171a80a
Files changed (2) hide show
  1. __pycache__/app.cpython-313.pyc +0 -0
  2. app.py +19 -24
__pycache__/app.cpython-313.pyc ADDED
Binary file (17.1 kB). View file
 
app.py CHANGED
@@ -13,20 +13,21 @@ from transformers import (
13
 
14
  # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
15
  # Default to 8B but keep both variants resident on the GPU.
16
- DEFAULT_MODEL_NAME = "bytedance-research/ChatTS-8B"
17
  AVAILABLE_MODEL_NAMES = [
18
- "bytedance-research/ChatTS-8B",
19
- "bytedance-research/ChatTS-14B"
20
  ]
21
 
22
  MODEL_REGISTRY = {}
23
 
24
  for name in AVAILABLE_MODEL_NAMES:
25
  print(f"Loading model into memory: {name}")
26
- tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
27
- proc = AutoProcessor.from_pretrained(name, trust_remote_code=True, tokenizer=tok)
 
28
  mdl = AutoModelForCausalLM.from_pretrained(
29
- name,
30
  trust_remote_code=True,
31
  device_map="auto",
32
  torch_dtype=torch.float16
@@ -66,8 +67,9 @@ def load_model_by_name(name: str):
66
 
67
 
68
  def switch_model(selected_model_name: str):
69
- """Wrapper for Gradio to switch models; returns status text."""
70
- return load_model_by_name(selected_model_name)
 
71
 
72
  # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
73
 
@@ -332,27 +334,20 @@ with gr.Blocks(title="ChatTS Demo") as demo:
332
  with gr.Column(scale=1):
333
  # Model selection UI
334
  model_radio = gr.Radio(
335
- choices=["bytedance-research/ChatTS-8B", "bytedance-research/ChatTS-14B"],
336
  value=CURRENT_MODEL_NAME,
337
  label="Model Version"
338
  )
339
 
340
- model_btn = gr.Button("Load Model")
341
-
342
- model_status = gr.Textbox(
343
- label="Model Status",
344
- value=f"Models in memory: {', '.join(AVAILABLE_MODEL_NAMES)}. Active: {CURRENT_MODEL_NAME}",
345
- interactive=False
346
- )
347
-
348
  upload = gr.File(
349
  label="Upload CSV File",
350
  file_types=[".csv"],
351
- type="filepath"
 
352
  )
353
 
354
  prompt_input = gr.Textbox(
355
- lines=6,
356
  placeholder="Enter your question here...",
357
  label="Analysis Prompt",
358
  value="Please analyze all the given time series and provide insights about the local fluctuations in the time series in detail."
@@ -362,11 +357,11 @@ with gr.Blocks(title="ChatTS Demo") as demo:
362
 
363
  with gr.Column(scale=2):
364
  series_selector = gr.Dropdown(
365
- label="Select a Column to Visualize",
366
  choices=[],
367
  value=None
368
  )
369
- plot_out = gr.LinePlot(value=pd.DataFrame(), label="Time Series Visualization")
370
  file_status = gr.Textbox(
371
  label="File Status",
372
  interactive=False,
@@ -410,11 +405,11 @@ with gr.Blocks(title="ChatTS Demo") as demo:
410
  outputs=[text_out]
411
  )
412
 
413
- # Wire model loading button
414
- model_btn.click(
415
  fn=switch_model,
416
  inputs=[model_radio],
417
- outputs=[model_status]
418
  )
419
 
420
  if __name__ == '__main__':
 
13
 
14
  # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
15
  # Default to 8B but keep both variants resident on the GPU.
16
+ DEFAULT_MODEL_NAME = "ChatTS-8B"
17
  AVAILABLE_MODEL_NAMES = [
18
+ "ChatTS-8B",
19
+ "ChatTS-14B"
20
  ]
21
 
22
  MODEL_REGISTRY = {}
23
 
24
  for name in AVAILABLE_MODEL_NAMES:
25
  print(f"Loading model into memory: {name}")
26
+ model_path = "bytedance-research/" + name
27
+ tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
28
+ proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, tokenizer=tok)
29
  mdl = AutoModelForCausalLM.from_pretrained(
30
+ model_path,
31
  trust_remote_code=True,
32
  device_map="auto",
33
  torch_dtype=torch.float16
 
67
 
68
 
69
  def switch_model(selected_model_name: str):
70
+ """Wrapper for Gradio to switch models via radio selection."""
71
+ load_model_by_name(selected_model_name)
72
+ return None
73
 
74
  # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
75
 
 
334
  with gr.Column(scale=1):
335
  # Model selection UI
336
  model_radio = gr.Radio(
337
+ choices=["ChatTS-8B", "ChatTS-14B"],
338
  value=CURRENT_MODEL_NAME,
339
  label="Model Version"
340
  )
341
 
 
 
 
 
 
 
 
 
342
  upload = gr.File(
343
  label="Upload CSV File",
344
  file_types=[".csv"],
345
+ type="filepath",
346
+ height=120
347
  )
348
 
349
  prompt_input = gr.Textbox(
350
+ lines=5,
351
  placeholder="Enter your question here...",
352
  label="Analysis Prompt",
353
  value="Please analyze all the given time series and provide insights about the local fluctuations in the time series in detail."
 
357
 
358
  with gr.Column(scale=2):
359
  series_selector = gr.Dropdown(
360
+ label="Select a Channel to Visualize (All Channels Will be Input to ChatTS)",
361
  choices=[],
362
  value=None
363
  )
364
+ plot_out = gr.LinePlot(value=pd.DataFrame(), label="Channel Visualization (All Channels Will be Input to ChatTS)")
365
  file_status = gr.Textbox(
366
  label="File Status",
367
  interactive=False,
 
405
  outputs=[text_out]
406
  )
407
 
408
+ # Model selection reacts immediately; no separate button needed
409
+ model_radio.change(
410
  fn=switch_model,
411
  inputs=[model_radio],
412
+ outputs=[]
413
  )
414
 
415
  if __name__ == '__main__':