Fix bugs
Browse files- __pycache__/app.cpython-313.pyc +0 -0
- app.py +19 -24
__pycache__/app.cpython-313.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
app.py
CHANGED
|
@@ -13,20 +13,21 @@ from transformers import (
|
|
| 13 |
|
| 14 |
# βββ MODEL SETUP ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
# Default to 8B but keep both variants resident on the GPU.
|
| 16 |
-
DEFAULT_MODEL_NAME = "
|
| 17 |
AVAILABLE_MODEL_NAMES = [
|
| 18 |
-
"
|
| 19 |
-
"
|
| 20 |
]
|
| 21 |
|
| 22 |
MODEL_REGISTRY = {}
|
| 23 |
|
| 24 |
for name in AVAILABLE_MODEL_NAMES:
|
| 25 |
print(f"Loading model into memory: {name}")
|
| 26 |
-
|
| 27 |
-
|
|
|
|
| 28 |
mdl = AutoModelForCausalLM.from_pretrained(
|
| 29 |
-
|
| 30 |
trust_remote_code=True,
|
| 31 |
device_map="auto",
|
| 32 |
torch_dtype=torch.float16
|
|
@@ -66,8 +67,9 @@ def load_model_by_name(name: str):
|
|
| 66 |
|
| 67 |
|
| 68 |
def switch_model(selected_model_name: str):
|
| 69 |
-
"""Wrapper for Gradio to switch models
|
| 70 |
-
|
|
|
|
| 71 |
|
| 72 |
# βββ HELPER FUNCTIONS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
|
|
@@ -332,27 +334,20 @@ with gr.Blocks(title="ChatTS Demo") as demo:
|
|
| 332 |
with gr.Column(scale=1):
|
| 333 |
# Model selection UI
|
| 334 |
model_radio = gr.Radio(
|
| 335 |
-
choices=["
|
| 336 |
value=CURRENT_MODEL_NAME,
|
| 337 |
label="Model Version"
|
| 338 |
)
|
| 339 |
|
| 340 |
-
model_btn = gr.Button("Load Model")
|
| 341 |
-
|
| 342 |
-
model_status = gr.Textbox(
|
| 343 |
-
label="Model Status",
|
| 344 |
-
value=f"Models in memory: {', '.join(AVAILABLE_MODEL_NAMES)}. Active: {CURRENT_MODEL_NAME}",
|
| 345 |
-
interactive=False
|
| 346 |
-
)
|
| 347 |
-
|
| 348 |
upload = gr.File(
|
| 349 |
label="Upload CSV File",
|
| 350 |
file_types=[".csv"],
|
| 351 |
-
type="filepath"
|
|
|
|
| 352 |
)
|
| 353 |
|
| 354 |
prompt_input = gr.Textbox(
|
| 355 |
-
lines=
|
| 356 |
placeholder="Enter your question here...",
|
| 357 |
label="Analysis Prompt",
|
| 358 |
value="Please analyze all the given time series and provide insights about the local fluctuations in the time series in detail."
|
|
@@ -362,11 +357,11 @@ with gr.Blocks(title="ChatTS Demo") as demo:
|
|
| 362 |
|
| 363 |
with gr.Column(scale=2):
|
| 364 |
series_selector = gr.Dropdown(
|
| 365 |
-
label="Select a
|
| 366 |
choices=[],
|
| 367 |
value=None
|
| 368 |
)
|
| 369 |
-
plot_out = gr.LinePlot(value=pd.DataFrame(), label="
|
| 370 |
file_status = gr.Textbox(
|
| 371 |
label="File Status",
|
| 372 |
interactive=False,
|
|
@@ -410,11 +405,11 @@ with gr.Blocks(title="ChatTS Demo") as demo:
|
|
| 410 |
outputs=[text_out]
|
| 411 |
)
|
| 412 |
|
| 413 |
-
#
|
| 414 |
-
|
| 415 |
fn=switch_model,
|
| 416 |
inputs=[model_radio],
|
| 417 |
-
outputs=[
|
| 418 |
)
|
| 419 |
|
| 420 |
if __name__ == '__main__':
|
|
|
|
| 13 |
|
| 14 |
# βββ MODEL SETUP ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
# Default to 8B but keep both variants resident on the GPU.
|
| 16 |
+
DEFAULT_MODEL_NAME = "ChatTS-8B"
|
| 17 |
AVAILABLE_MODEL_NAMES = [
|
| 18 |
+
"ChatTS-8B",
|
| 19 |
+
"ChatTS-14B"
|
| 20 |
]
|
| 21 |
|
| 22 |
MODEL_REGISTRY = {}
|
| 23 |
|
| 24 |
for name in AVAILABLE_MODEL_NAMES:
|
| 25 |
print(f"Loading model into memory: {name}")
|
| 26 |
+
model_path = "bytedance-research/" + name
|
| 27 |
+
tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 28 |
+
proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, tokenizer=tok)
|
| 29 |
mdl = AutoModelForCausalLM.from_pretrained(
|
| 30 |
+
model_path,
|
| 31 |
trust_remote_code=True,
|
| 32 |
device_map="auto",
|
| 33 |
torch_dtype=torch.float16
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
def switch_model(selected_model_name: str):
|
| 70 |
+
"""Wrapper for Gradio to switch models via radio selection."""
|
| 71 |
+
load_model_by_name(selected_model_name)
|
| 72 |
+
return None
|
| 73 |
|
| 74 |
# βββ HELPER FUNCTIONS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
|
|
|
|
| 334 |
with gr.Column(scale=1):
|
| 335 |
# Model selection UI
|
| 336 |
model_radio = gr.Radio(
|
| 337 |
+
choices=["ChatTS-8B", "ChatTS-14B"],
|
| 338 |
value=CURRENT_MODEL_NAME,
|
| 339 |
label="Model Version"
|
| 340 |
)
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
upload = gr.File(
|
| 343 |
label="Upload CSV File",
|
| 344 |
file_types=[".csv"],
|
| 345 |
+
type="filepath",
|
| 346 |
+
height=120
|
| 347 |
)
|
| 348 |
|
| 349 |
prompt_input = gr.Textbox(
|
| 350 |
+
lines=5,
|
| 351 |
placeholder="Enter your question here...",
|
| 352 |
label="Analysis Prompt",
|
| 353 |
value="Please analyze all the given time series and provide insights about the local fluctuations in the time series in detail."
|
|
|
|
| 357 |
|
| 358 |
with gr.Column(scale=2):
|
| 359 |
series_selector = gr.Dropdown(
|
| 360 |
+
label="Select a Channel to Visualize (All Channels Will be Input to ChatTS)",
|
| 361 |
choices=[],
|
| 362 |
value=None
|
| 363 |
)
|
| 364 |
+
plot_out = gr.LinePlot(value=pd.DataFrame(), label="Channel Visualization (All Channels Will be Input to ChatTS)")
|
| 365 |
file_status = gr.Textbox(
|
| 366 |
label="File Status",
|
| 367 |
interactive=False,
|
|
|
|
| 405 |
outputs=[text_out]
|
| 406 |
)
|
| 407 |
|
| 408 |
+
# Model selection reacts immediately; no separate button needed
|
| 409 |
+
model_radio.change(
|
| 410 |
fn=switch_model,
|
| 411 |
inputs=[model_radio],
|
| 412 |
+
outputs=[]
|
| 413 |
)
|
| 414 |
|
| 415 |
if __name__ == '__main__':
|