|
|
import spaces |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
from threading import Thread |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
AutoProcessor, |
|
|
TextIteratorStreamer |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint that is active when the app starts.
DEFAULT_MODEL_NAME = "ChatTS-8B"

# Every checkpoint listed here is preloaded at import time; each name is
# resolved under the "bytedance-research" organisation on the Hugging Face Hub.
AVAILABLE_MODEL_NAMES = ["ChatTS-8B", "ChatTS-14B"]
|
|
|
|
|
# name -> {"tokenizer", "processor", "model"}. Every available checkpoint is
# loaded once here and kept resident, so switching models later is only a
# dictionary lookup.
MODEL_REGISTRY = {}

for name in AVAILABLE_MODEL_NAMES:
    print(f"Loading model into memory: {name}")
    model_path = "bytedance-research/" + name

    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # The processor wraps the tokenizer and also accepts the time-series inputs.
    proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, tokenizer=tok)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    mdl.eval()

    MODEL_REGISTRY[name] = {"tokenizer": tok, "processor": proc, "model": mdl}

CURRENT_MODEL_NAME = DEFAULT_MODEL_NAME

# Module-level handles to the active model; load_model_by_name rebinds them.
tokenizer, processor, model = (
    MODEL_REGISTRY[CURRENT_MODEL_NAME][part] for part in ("tokenizer", "processor", "model")
)
|
|
|
|
|
|
|
|
def load_model_by_name(name: str):
    """Make an already-loaded model the active one.

    No weights are loaded here: the tokenizer, processor and model are looked
    up in MODEL_REGISTRY and bound to the module-level globals that the
    inference code reads.

    Args:
        name: Key of the model in MODEL_REGISTRY.

    Returns:
        ``name`` when the switch happened; otherwise a human-readable message
        explaining why nothing changed.
    """
    global tokenizer, processor, model, CURRENT_MODEL_NAME

    entry = MODEL_REGISTRY.get(name)
    if entry is None:
        return f"Model not available: {name}"

    if name == CURRENT_MODEL_NAME:
        return f"Model already selected: {name}"

    CURRENT_MODEL_NAME = name
    tokenizer, processor, model = entry["tokenizer"], entry["processor"], entry["model"]
    # Already in eval mode from start-up; repeated as a cheap safeguard.
    model.eval()

    print(f"Activated model: {name}")
    return name
|
|
|
|
|
|
|
|
def switch_model(selected_model_name: str):
    """Activate ``selected_model_name``; used as a Gradio callback and before inference.

    The status returned by load_model_by_name is intentionally discarded,
    so this always returns None.
    """
    load_model_by_name(selected_model_name)
    return None
|
|
|
|
|
|
|
|
|
|
|
def create_default_timeseries():
    """Build the demo DataFrame shown before any CSV is uploaded.

    Returns:
        A 256-row DataFrame with two columns:

        * ``TS1``: a sine wave (amplitude 5, period about 63 samples) whose
          level drops suddenly by 10 from index 103 onward.
        * ``TS2``: a linear trend (slope 0.01 per step) with a single upward
          spike of +10 at index 100.
    """
    idx = np.arange(256)

    ts1 = np.sin(idx / 10) * 5.0
    ts1[103:] -= 10.0  # level shift: sudden decrease

    ts2 = idx * 0.01
    ts2[100] += 10.0  # isolated upward spike

    return pd.DataFrame({"TS1": ts1, "TS2": ts2})
|
|
|
|
|
def process_csv_file(csv_file):
    """Load and validate an uploaded CSV of time series.

    Each column is one series. Cell values are coerced to numbers
    (unparseable cells become NaN), and trailing NaNs are ignored when
    measuring a series' length. Columns with no numeric value at all are
    dropped.

    Args:
        csv_file: Path to the CSV (``gr.File(type="filepath")`` passes a
            plain ``str``), a file-like object exposing ``.name``, or
            ``None`` when nothing has been uploaded.

    Returns:
        ``(df, message)``. On success ``df`` contains only the valid numeric
        series, in file order, and ``message`` lists them. On failure ``df``
        is ``None`` and ``message`` describes the problem.
    """
    if csv_file is None:
        return None, "No file uploaded"

    # gr.File(type="filepath") delivers a str path; tempfile-style wrappers
    # expose the path as ``.name``. Accept both.
    path = getattr(csv_file, "name", csv_file)

    try:
        df = pd.read_csv(path)

        # Normalise headers, then drop blank-named and entirely empty columns.
        df.columns = [str(c).strip() for c in df.columns]
        df = df.loc[:, [c for c in df.columns if c]]
        df = df.dropna(axis=1, how="all")
        print(f"[LOG] File {path} loaded. {df.columns=}")

        if df.shape[1] == 0:
            return None, "No valid time-series columns found."
        if df.shape[1] > 15:
            return None, f"Too many series ({df.shape[1]}). Max allowed = 15."

        valid = {}
        for name in df.columns:
            try:
                series = pd.to_numeric(df[name], errors="coerce")
            except (TypeError, ValueError):
                return None, f"Series '{name}' cannot be converted to float type."

            last_valid = series.last_valid_index()
            if last_valid is None:
                # No numeric value at all: leave the column out of the result
                # so it cannot reach plotting or inference.
                continue

            length = series.loc[:last_valid].shape[0]
            if length < 16 or length > 1024:
                return None, f"Series '{name}' length {length} invalid. Must be 16 to 1024."
            valid[name] = series

        if not valid:
            return None, "All time series are empty after trimming NaNs."

        # Return the coerced, validated columns (not the raw frame) so callers
        # only ever see numeric data that passed the checks above.
        clean_df = pd.DataFrame(valid, index=df.index)
        return clean_df, f"Successfully loaded {len(valid)} time series: {', '.join(valid)}"

    except Exception as e:
        return None, f"Error processing file: {str(e)}"
|
|
|
|
|
def preview_csv(csv_file, use_default):
    """Render a preview as soon as a CSV is uploaded.

    ``use_default`` is part of the Gradio wiring but is not read here.

    Returns:
        ``(plot, status message, dropdown, use_default flag)``. The flag is
        always ``False``, so later runs use the uploaded file instead of the
        built-in series.
    """
    if csv_file is None:
        return gr.LinePlot(value=pd.DataFrame()), "Please upload a CSV file first", gr.Dropdown(), False

    df, message = process_csv_file(csv_file)
    if df is None:
        return gr.LinePlot(value=pd.DataFrame()), message, gr.Dropdown(), False

    columns = list(df.columns)
    selected = columns[0]

    # LinePlot needs an explicit x column; use the row position.
    plot_df = df.copy()
    plot_df["_internal_idx"] = np.arange(len(df))

    plot = gr.LinePlot(plot_df, x="_internal_idx", y=selected, title=f"Time Series: {selected}")
    dropdown = gr.Dropdown(choices=columns, value=selected, label="Select a Column to Visualize")
    return plot, message, dropdown, False
|
|
|
|
|
def clear_csv():
    """Reset the plot, status text and dropdown after the upload is removed.

    The status text comes from ``process_csv_file(None)`` so it matches the
    message used for a missing upload elsewhere.
    """
    _, status = process_csv_file(None)
    return gr.LinePlot(value=pd.DataFrame()), status, gr.Dropdown()
|
|
|
|
|
|
|
|
def update_plot(csv_file, selected_column, use_default_state):
    """Redraw the line plot for the column chosen in the dropdown.

    Data comes from the uploaded CSV when one is present; otherwise it comes
    from the built-in demo series, provided ``use_default_state`` is true.

    Returns:
        A LinePlot of ``selected_column`` against row position. An empty plot
        is returned when nothing can be shown: no data source, no selection,
        an invalid file, or a ``selected_column`` that is not a column of the
        resolved data (for example, a stale value left over from a previous
        file).
    """
    if selected_column is None or (csv_file is None and not use_default_state):
        return gr.LinePlot(value=pd.DataFrame())

    if csv_file is None:
        df = create_default_timeseries()
    else:
        df, _ = process_csv_file(csv_file)

    # Indexing a missing column would raise KeyError inside the callback.
    if df is None or selected_column not in df.columns:
        return gr.LinePlot(value=pd.DataFrame())

    plot_df = df.copy()
    plot_df["_internal_idx"] = np.arange(len(df))

    return gr.LinePlot(
        plot_df,
        x="_internal_idx",
        y=selected_column,
        title=f"Time Series: {selected_column}",
    )
|
|
|
|
|
def initialize_interface():
    """Populate the UI on page load using the built-in demo series.

    Returns:
        ``(plot of the first column, status message, dropdown listing every
        column, True)``. The final ``True`` marks the default series as the
        active data source.
    """
    demo_df = create_default_timeseries()
    names = list(demo_df.columns)
    shown = names[0]

    indexed = demo_df.copy()
    indexed["_internal_idx"] = np.arange(len(demo_df))

    plot = gr.LinePlot(indexed, x="_internal_idx", y=shown, title=f"Time Series: {shown}")
    dropdown = gr.Dropdown(choices=names, value=shown, label="Select a Column to Visualize")
    status = (
        "Using default time series (TS1 and TS2). "
        "Please select a time series from the dropdown box above for visualization."
    )
    return plot, status, dropdown, True
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def infer_chatts_stream(prompt: str, csv_file, use_default, model_name):
    """Stream a ChatTS analysis of the current time series.

    Activates ``model_name`` and builds the ChatTS chat prompt, with one
    ``<ts><ts/>`` placeholder per series. It then runs ``model.generate`` on
    a background thread and yields the accumulated text after each streamed
    chunk.

    Args:
        prompt: User question. Literal ``<ts>``/``<ts/>`` tags are stripped
            because the placeholders are inserted automatically.
        csv_file: Uploaded CSV (path or file-like object), or ``None``.
        use_default: When true and no file is uploaded, the built-in demo
            series are analysed.
        model_name: Name of the preloaded model to use.

    Yields:
        The model output so far, or a single help/error message.
    """
    switch_model(model_name)

    if not prompt or not prompt.strip():
        yield "Please enter a prompt"
        return

    if csv_file is None and use_default:
        df, error_msg = create_default_timeseries(), None
    else:
        df, error_msg = process_csv_file(csv_file)

    if df is None:
        if csv_file is None:
            yield "Please upload a CSV file first or the file contains errors"
        else:
            # Show the concrete validation problem rather than a generic hint.
            yield f"The uploaded file could not be used: {error_msg}"
        return

    thread = None
    try:
        ts_names, ts_list = [], []
        for name in df.columns:
            # Coerce defensively: non-numeric cells become NaN instead of
            # making the float32 conversion below raise.
            series = pd.to_numeric(df[name], errors="coerce")
            last_valid = series.last_valid_index()
            if last_valid is None:
                continue
            ts_names.append(name)
            ts_list.append(series.loc[:last_valid].to_numpy(dtype=np.float32))

        if not ts_list:
            yield "No valid time series data found. Please upload time series first."
            return

        clean_prompt = prompt.replace("<ts>", "").replace("<ts/>", "")

        prefix = f"I have {len(ts_list)} time series:\n" + "".join(
            f"The {name} is of length {len(arr)}: <ts><ts/>\n"
            for name, arr in zip(ts_names, ts_list)
        )

        full_prompt = f"<|im_start|>system\nYou are a helpful assistant. Your name is ChatTS. You can analyze time series data and provide insights. If user asks who you are, you should give your name and capabilities in the language of the prompt. If user has no format requirement, always output a step-by-step analysis about the time series attributes that mentioned in the question first, and then give a detailed result about the given question.<|im_end|><|im_start|>user\n{prefix}{clean_prompt}<|im_end|><|im_start|>assistant\n"

        print(f"[LOG] model={CURRENT_MODEL_NAME}, {clean_prompt=}, {len(ts_list)=}")

        inputs = processor(
            text=[full_prompt],
            timeseries=ts_list,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
        # NOTE(review): in transformers, temperature only takes effect when
        # sampling is enabled (do_sample=True here or in the model's
        # generation config) - confirm this is the intended decoding mode.
        generate_kwargs = {
            **inputs,
            "max_new_tokens": 512,
            "streamer": streamer,
            "temperature": 0.3,
        }
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        model_output = ""
        for new_text in streamer:
            model_output += new_text
            yield model_output

    except Exception as e:
        # Some exceptions (e.g. a streamer timeout) have an empty message;
        # fall back to the exception type so the user sees something useful.
        yield f"Error during inference: {str(e) or type(e).__name__}"
    finally:
        # Do not leave the generation thread running past this call.
        if thread is not None:
            thread.join()
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="ChatTS Demo") as demo:
    gr.Markdown("## ChatTS: Time Series Understanding and Reasoning")
    gr.HTML("""<div style="display:flex;justify-content: center">
    <a href="https://github.com/NetmanAIOps/ChatTS"><img alt="github" src="https://img.shields.io/badge/Code-GitHub-blue"></a>
    <a href="https://huggingface.co/bytedance-research/ChatTS-14B"><img alt="huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-FFD21E"></a>
    <a href="https://arxiv.org/abs/2412.03104"><img alt="preprint" src="https://img.shields.io/static/v1?label=arXiv&message=2412.03104&color=B31B1B&logo=arXiv"></a>
    </div>""")
    gr.Markdown("Try ChatTS with the default time series, or upload a CSV file (Example: [ts_example.csv](https://github.com/NetManAIOps/ChatTS/blob/main/demo/ts_example.csv)) containing UTS/MTS where each column is a dimension (no index column). All columns will be used as input of ChatTS automatically.")
    gr.Markdown("The length should be between 16 and 1024, with 15 time series at most. Please use English to ask questions. If you like ChatTS, kindly star our [GitHub repo](https://github.com/NetmanAIOps/ChatTS).")

    # True while the built-in demo series are the data source; set to False
    # once a CSV has been uploaded.
    use_default_state = gr.State(value=True)

    with gr.Row():
        with gr.Column(scale=1):
            # Built from AVAILABLE_MODEL_NAMES so the radio cannot drift from
            # the set of preloaded models.
            model_radio = gr.Radio(
                choices=AVAILABLE_MODEL_NAMES,
                value=CURRENT_MODEL_NAME,
                label="Model Version",
            )

            upload = gr.File(
                label="Upload CSV File",
                file_types=[".csv"],
                type="filepath",
                height=80,
            )

            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Enter your question here...",
                label="Analysis Prompt",
                value="Please analyze all the given time series and provide insights about the local fluctuations in the time series in detail.",
            )

        with gr.Column(scale=2):
            series_selector = gr.Dropdown(
                label="Select a Channel to Visualize (All Channels Will be Input to ChatTS)",
                choices=[],
                value=None,
            )
            plot_out = gr.LinePlot(value=pd.DataFrame(), label="Channel Visualization (All Channels Will be Input to ChatTS)")
            file_status = gr.Textbox(
                label="File Status",
                interactive=False,
                lines=1,
            )
            run_btn = gr.Button("Run ChatTS", variant="primary")

    text_out = gr.Textbox(
        lines=10,
        label="ChatTS Analysis Results",
        interactive=False,
    )

    # Show the default series as soon as the page loads.
    demo.load(
        fn=initialize_interface,
        outputs=[plot_out, file_status, series_selector, use_default_state],
    )

    # Preview a newly uploaded CSV; this also switches off the default series.
    upload.upload(
        fn=preview_csv,
        inputs=[upload, use_default_state],
        outputs=[plot_out, file_status, series_selector, use_default_state],
    )

    upload.clear(
        fn=clear_csv,
        inputs=[],
        outputs=[plot_out, file_status, series_selector],
    )

    series_selector.change(
        fn=update_plot,
        inputs=[upload, series_selector, use_default_state],
        outputs=[plot_out],
    )

    run_btn.click(
        fn=infer_chatts_stream,
        inputs=[prompt_input, upload, use_default_state, model_radio],
        outputs=[text_out],
    )

    model_radio.change(
        fn=switch_model,
        inputs=[model_radio],
        outputs=[],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()
|
|
|