#!/usr/bin/env python3
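# "What Comes Next": a Gradio app that generates a continuation of a stored prompt
# one sampled token at a time (one token every SECS_PER_TOKEN seconds) and lets
# visitors log their own predictions of what the model will write next.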
import os, json, time, random, threading, logging
from datetime import datetime, timezone

import torch
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
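
# Configuration. full_prompts.json is expected to hold a JSON array of prompt strings
# (prompts[idx] is concatenated directly with the generated text below). The oracle
# emits TOKENS_PER_PROMPT tokens per prompt, one every SECS_PER_TOKEN seconds,
# sampling with temperature TEMP and nucleus top-p TOP_P within a MAX_CTX-token window.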
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
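
# JSON helpers: _rj reads a JSON file, falling back to a default on any error;
# _aw writes atomically by writing a temp file and then os.replace()-ing it into place.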
def _rj(p, d):
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return d

def _aw(p, o):
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as f:
        f.write(json.dumps(o, ensure_ascii=False, indent=2))
    os.replace(t, p)

prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise Exception("No prompts found in full_prompts.json")
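
# Load tokenizer and model on CPU in float32. HF_READ_TOKEN is expected to hold a
# Hugging Face read token with access to the (gated) Llama model repository.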
tok = os.environ.get("HF_READ_TOKEN")
log.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=False,
    token=tok
)
model.to("cpu")
model.eval()
log.info("Model is ready.")

lock = threading.Lock()
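
# _init returns the persisted generation state, starting a fresh state with a randomly
# chosen prompt whenever no state exists or the previous prompt has finished.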
def _init():
    state = _rj(STATE_PATH, {})
    if not state or state.get("finished"):
        idx = random.randrange(len(prompts))
        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, state)
    return state
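
# Format the elapsed time since start_time as "Xh Ym Zs".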
def _es(start_time):
    elapsed = int(time.time() - start_time)
    h, rem = divmod(elapsed, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"
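
# Background generation loop: each iteration samples exactly one new token for the
# current prompt, appends it to the stored continuation, persists the state, and then
# sleeps for SECS_PER_TOKEN seconds.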
def _loop():
    while True:
        with lock:
            st = _init()
        if st["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        context = st["p"] + st["g"]
        # Tokenize prompt + generated text; truncation keeps the first MAX_CTX tokens.
        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad():
            out = model.generate(
                ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMP,
                top_p=TOP_P
            )
        # Decode only the single newly generated token.
        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            st["g"] += next_token
            st["c"] += 1
            if st["c"] >= TOKENS_PER_PROMPT:
                st["finished"] = True
            _aw(STATE_PATH, st)
        time.sleep(SECS_PER_TOKEN)
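
# Run the generator in a background daemon thread alongside the Gradio server.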
threading.Thread(target=_loop, daemon=True).start()
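
# Return the current prompt, the oracle's partial response, and the elapsed time
# for display in the UI.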
def _fetch():
    state = _rj(STATE_PATH, {})
    if not state:
        return "...", "", "0h 0m 0s"
    return state["p"], state["g"], _es(state["t"])
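
# Append the visitor's prediction as one JSON line in data.json, e.g.
# {"ts": ..., "prompt": ..., "time": ..., "resp": ..., "prediction": ..., "summary": ...}.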
def _submit_prediction(detailed, summary):
    det = detailed.strip()
    if not det:
        # Leave the input boxes untouched when nothing was entered.
        return gr.update(value="Please enter a detailed prediction."), gr.update(), gr.update()
    prompt_text, oracle_resp, elapsed = _fetch()
    record = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt_text,
        "time": elapsed,
        "resp": oracle_resp,
        "prediction": det,
        "summary": summary.strip()
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")
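
# Gradio UI: the prompt, the oracle's output so far, and the elapsed time are shown in
# one row; below them, visitors enter a detailed prediction (and an optional summary)
# and submit it.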
with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown(
        "# What Comes Next\n"
        "Enter what you think will come next in the text.\n"
        "Provide a detailed continuation and optionally a brief summary for context."
    )
    with gr.Row():
        prompt_md = gr.Markdown()
        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
    detailed = gr.Textbox(
        label="Your Detailed Prediction",
        placeholder="Enter the full text continuation you expect...",
        lines=3
    )
    summary = gr.Textbox(
        label="Prediction Summary (Optional)",
        placeholder="Optionally, summarize your prediction in a few words...",
        lines=2
    )
    status = gr.Textbox(interactive=False, label="Status")
    submit_btn = gr.Button("Submit Prediction")
    refresh_btn = gr.Button("Refresh Oracle")
    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
    submit_btn.click(
        _submit_prediction,
        inputs=[detailed, summary],
        outputs=[status, detailed, summary]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)