Update app.py
app.py CHANGED
@@ -1,8 +1,5 @@
#!/usr/bin/env python3

-
-
-
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
@@ -18,77 +15,139 @@ TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15
TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192

-
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

-def _rj(p,d):
-
-
-    except:
-
+# Read a JSON file, falling back to the given default on any error.
+def _rj(p, d):
+    try:
+        return json.load(open(p, encoding="utf-8"))
+    except:
+        return d

-def _aw(p,o):
-    t=p+".tmp"; open(t,"w",encoding="utf-8").write(json.dumps(o,ensure_ascii=False,indent=2)); os.replace(t,p)
+# Atomically write JSON: dump to a temp file, then replace it into place.
+def _aw(p, o):
+    t = p + ".tmp"
+    open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2))
+    os.replace(t, p)

-
-
+prompts = _rj(PROMPTS_PATH, [])
+if not prompts:
+    raise Exception("No prompts found in full_prompts.json")

-
-
-
-model=AutoModelForCausalLM.from_pretrained(MODEL_NAME,torch_dtype=torch.float32,low_cpu_mem_usage=False,token=tok)
-model.to("cpu");model.eval()
-log.info("model up")
+tok = os.environ.get("HF_READ_TOKEN")
+log.info("Loading model...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    torch_dtype=torch.float32,
+    low_cpu_mem_usage=False,
+    token=tok
+)
+model.to("cpu"); model.eval()
+log.info("Model is ready.")

-
+lock = threading.Lock()

def _init():
-
-    if not
-
-
-    _aw(STATE_PATH,
-    return
-
-def _es(
-
+    state = _rj(STATE_PATH, {})
+    if not state or state.get("finished"):
+        idx = random.randrange(len(prompts))
+        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
+        _aw(STATE_PATH, state)
+    return state
+
+def _es(start_time):
+    elapsed = int(time.time() - start_time)
+    h, rem = divmod(elapsed, 3600)
+    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"

def _loop():
    while True:
-        with lock: s=_init()
-        if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
-        c=s["p"]+s["g"]
-        ids=tokenizer(c,return_tensors="pt",truncation=True,max_length=MAX_CTX).input_ids
-        with torch.no_grad(): out=model.generate(ids,max_new_tokens=1,do_sample=True,temperature=TEMP,top_p=TOP_P)
-        nt=tokenizer.decode(out[0,-1],skip_special_tokens=True,clean_up_tokenization_spaces=False)
        with lock:
-
-
-
+            st = _init()
+        if st["finished"]:
+            time.sleep(SECS_PER_TOKEN)
+            continue
+        context = st["p"] + st["g"]
+        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
+        with torch.no_grad():
+            out = model.generate(
+                ids,
+                max_new_tokens=1,
+                do_sample=True,
+                temperature=TEMP,
+                top_p=TOP_P
+            )
+        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        with lock:
+            st["g"] += next_token
+            st["c"] += 1
+            if st["c"] >= TOKENS_PER_PROMPT:
+                st["finished"] = True
+            _aw(STATE_PATH, st)
        time.sleep(SECS_PER_TOKEN)
-
+
+# Advance the oracle by one token per tick in a background daemon thread.
+threading.Thread(target=_loop, daemon=True).start()

def _fetch():
-
-    if not
-
-
-
-
-
-
-
-
-
+    state = _rj(STATE_PATH, {})
+    if not state:
+        return "...", "", "0h 0m 0s"
+    return state["p"], state["g"], _es(state["t"])
+
+def _submit_prediction(detailed, summary):
+    det = detailed.strip()
+    if not det:
+        return gr.update(value="Please enter at least a detailed prediction."), gr.update(value=""), gr.update(value="")
+    prompt_text, oracle_resp, elapsed = _fetch()
+    record = {
+        "ts": datetime.now(timezone.utc).isoformat(),
+        "prompt": prompt_text,
+        "time": elapsed,
+        "resp": oracle_resp,
+        "prediction": det,
+        "summary": summary.strip()
+    }
+    with lock:
+        open(DATA_PATH, "a", encoding="utf-8").write(json.dumps(record, ensure_ascii=False) + "\n")
+    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")

with gr.Blocks(theme="darkdefault") as demo:
-    gr.Markdown(
-
-
-
-
-
-
-
+    gr.Markdown(
+        "# What Comes Next\n"
+        "Enter what you think will come next in the text.\n"
+        "Provide a detailed continuation and optionally a brief summary for context."
+    )
+    with gr.Row():
+        prompt_md = gr.Markdown()
+        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
+        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
+
+    detailed = gr.Textbox(
+        label="Your Detailed Prediction",
+        placeholder="Enter the full text continuation you expect...",
+        lines=3
+    )
+    summary = gr.Textbox(
+        label="Prediction Summary (Optional)",
+        placeholder="Optionally, summarize your prediction in a few words...",
+        lines=2
+    )
+    status = gr.Textbox(interactive=False, label="Status")
+    submit_btn = gr.Button("Submit Prediction")
+    refresh_btn = gr.Button("Refresh Oracle")
+
+    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
+    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
+    submit_btn.click(
+        _submit_prediction,
+        inputs=[detailed, summary],
+        outputs=[status, detailed, summary]
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
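For reference, `_submit_prediction` appends one JSON object per line to DATA_PATH (DATA_PATH itself is defined in lines not shown in this diff). Below is a minimal sketch, not part of the commit, of how that log could be read back for analysis; it assumes a JSON-Lines file, and the filename used is hypothetical.

import json
from collections import Counter

def load_predictions(path):
    # One JSON record per line, as written by _submit_prediction above.
    records = []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records

if __name__ == "__main__":
    recs = load_predictions("predictions.jsonl")  # hypothetical value of DATA_PATH
    print(f"{len(recs)} predictions logged")
    by_prompt = Counter(r["prompt"][:40] for r in recs)
    for prompt, n in by_prompt.most_common(5):
        print(f"{n:3d}  {prompt}...")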