# NOTE: the two lines that followed ("Spaces:" / "Runtime error") were page
# artifacts from the Hugging Face Spaces listing, not part of the program.
# --- Module setup: imports and model loading ---
import re
import threading

import gradio as gr
import spaces
import transformers
from transformers import pipeline

# Load the model and tokenizer as a text-generation pipeline.
model_name = "meta-llama/Llama-3.1-8B-Instruct"

# gr.NO_RELOAD guards heavyweight initialization so it only runs once and is
# skipped on Gradio's dev-mode auto-reload passes.
if gr.NO_RELOAD:
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",   # spread layers across available devices
        torch_dtype="auto",  # pick dtype from checkpoint/hardware
    )
# Marker used to detect (and strip from user input / flag in output) the
# final-answer section.  NOTE: the literal is Korean text that was
# mojibake-encoded upstream; it is preserved byte-for-byte because it is
# compared against model output at runtime.
ANSWER_MARKER = "**λ΅λ³**"

# Seed sentences that start each step of the step-by-step reasoning
# (Korean, mojibake-encoded; preserved verbatim — they are sent to the model).
rethink_prepends = [
    "μ, μ΄μ λ€μμ νμ ν΄μΌ ν©λλ€ ",
    "μ μκ°μλ ",
    "λ€μ μ¬νμ΄ λ§λμ§ νμΈν΄ λ³΄κ² μ΅λλ€ ",
    "λν κΈ°μ΅ν΄μΌ ν κ²μ ",
    "λ λ€λ₯Έ μ£Όλͺ©ν μ μ ",
    "κ·Έλ¦¬κ³ μ λ λ€μκ³Ό κ°μ μ¬μ€λ κΈ°μ΅ν©λλ€ ",
    "μ΄μ μΆ©λΆν μ΄ν΄νλ€κ³ μκ°ν©λλ€ ",
]

# Prompt template appended before generating the final answer; filled with
# the original question, the reasoning conclusion, and ANSWER_MARKER.
final_answer_prompt = """
μ§κΈκΉμ§μ μΆλ‘ κ³Όμ μ λ°νμΌλ‘, μλ μ§λ¬Έμ μ¬μ©λ μΈμ΄λ‘ λ΅λ³νκ² μ΅λλ€:
{question}
μλλ λ΄κ° μΆλ‘ ν κ²°λ‘ μ λλ€:
{reasoning_conclusion}
μ μΆλ‘ μ κΈ°λ°μΌλ‘ μ΅μ’ λ΅λ³:
{ANSWER_MARKER}
"""

# Delimiter config so Gradio's Chatbot renders $$...$$ (display) and
# $...$ (inline) math via KaTeX.
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
    """Rewrite MathJax delimiters into the KaTeX-style ones Gradio renders.

    ``\\[ ... \\]`` becomes ``$$ ... $$`` (display math) and ``\\( ... \\)``
    becomes ``$ ... $`` (inline math), matching ``latex_delimiters``.  This is
    a display workaround; no delimiter config was found that handles the
    MathJax forms directly.
    """
    display_pat = re.compile(r"\\\[\s*(.*?)\s*\\\]", re.DOTALL)
    inline_pat = re.compile(r"\\\(\s*(.*?)\s*\\\)", re.DOTALL)
    return inline_pat.sub(r"$\1$", display_pat.sub(r"$$\1$$", text))
def user_input(message, history_original, history_thinking):
    """Append the user's message to both chat histories and clear the textbox.

    The answer marker is stripped from the message so users cannot inject it.
    Returns ("", updated_original_history, updated_thinking_history).
    """
    cleaned = message.replace(ANSWER_MARKER, "")
    turn_original = history_original + [gr.ChatMessage(role="user", content=cleaned)]
    turn_thinking = history_thinking + [gr.ChatMessage(role="user", content=cleaned)]
    return "", turn_original, turn_thinking
def rebuild_messages(history: list):
    """Rebuild model-ready messages from history, dropping "thinking" entries.

    Entries whose metadata carries a "title" are collapsible thought bubbles,
    not real chat turns, and are excluded.

    Args:
        history: Mixed list of plain message dicts and gr.ChatMessage objects.

    Returns:
        A list of {"role": ..., "content": ...} dicts for the pipeline.
    """
    messages = []
    for h in history:
        if isinstance(h, dict):
            # FIX: "metadata" may be present but None (common for Gradio
            # "messages"-format dicts); the old `h.get("metadata", {})`
            # returned None and crashed on `.get("title", ...)`.
            if not (h.get("metadata") or {}).get("title", False):
                messages.append(h)
        elif (
            isinstance(h, gr.ChatMessage)
            and h.metadata.get("title", None) is None
            and isinstance(h.content, str)
        ):
            messages.append({"role": h.role, "content": h.content})
    return messages
def bot_original(
    history: list,
    max_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Stream the plain model's answer (no reasoning scaffold) into history.

    A worker thread runs generation while this generator consumes the token
    stream, yielding the updated history after every token so the Gradio
    chat refreshes incrementally.
    """
    # Streamer hands us tokens as the worker thread produces them.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )

    # Empty assistant placeholder that gets filled token by token.
    history.append(gr.ChatMessage(role="assistant", content=""))

    # Prompt = everything shown so far, minus the placeholder just added.
    prompt_messages = rebuild_messages(history[:-1])

    generation_kwargs = dict(
        max_new_tokens=max_num_tokens,
        streamer=streamer,
        do_sample=do_sample,
        temperature=temperature,
    )
    worker = threading.Thread(target=pipe, args=(prompt_messages,), kwargs=generation_kwargs)
    worker.start()

    for token in streamer:
        # Accumulate, then normalize math delimiters for display.
        history[-1].content = reformat_math(history[-1].content + token)
        yield history
    worker.join()
    yield history
def bot_thinking(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Answer the question with an explicit multi-step reasoning phase.

    Streams each reasoning step into a collapsible "thinking" bubble, then
    generates and streams a separate final-answer message.

    Args:
        history: Chat history; the last entry is the user's question.
        max_num_tokens: Max new tokens per reasoning step.
        final_num_tokens: Max new tokens for the final answer.
        do_sample: Whether to sample (vs. greedy decoding).
        temperature: Sampling temperature.

    Yields:
        The updated history after each streamed token.
    """
    # Streamer so a worker thread can generate while we consume tokens here.
    # NOTE(review): the same streamer instance is reused across all
    # generations below — assumes each generation fully drains it. Confirm.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )

    # Keep the question so prompts can re-insert it if needed.
    question = history[-1]["content"]

    # Assistant placeholder; the "title" metadata renders it as a collapsible
    # thinking bubble in the Gradio chat. (Title string is mojibake Korean,
    # preserved verbatim — it is user-facing runtime text.)
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "π§ μκ° μ€...", "status": "pending"},
        )
    )

    # Messages to feed the model. rebuild_messages drops titled entries, so
    # the placeholder just appended is excluded and messages[-1] is the
    # user's question.
    messages = rebuild_messages(history)

    # Accumulates the full reasoning text across steps.
    full_reasoning = ""

    # Run the reasoning steps: seed each with a prepend sentence, then let
    # the model continue from there.
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        # .format(question=...) is a no-op unless the prepend contains
        # "{question}" (none currently do).
        messages[-1]["content"] += prepend.format(question=question)
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()

        # Mirror the seed text into the visible history, then stream tokens.
        history[-1].content += prepend.format(question=question)
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()

        # Save each reasoning step's cumulative result.
        full_reasoning = history[-1].content

    # Reasoning finished; mark the thinking bubble as done.
    history[-1].metadata = {"title": "π μ¬κ³ κ³Όμ ", "status": "done"}

    # Extract the conclusion (roughly the last 1-2 paragraphs) of the reasoning.
    reasoning_parts = full_reasoning.split("\n\n")
    reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning

    # Add the message that will hold the final answer.
    history.append(gr.ChatMessage(role="assistant", content=""))

    # Build the final-answer prompt (placeholder excluded), appending the
    # formatted template to the last surviving message.
    final_messages = rebuild_messages(history[:-1])
    final_prompt = final_answer_prompt.format(
        question=question,
        reasoning_conclusion=reasoning_conclusion,
        ANSWER_MARKER=ANSWER_MARKER
    )
    final_messages[-1]["content"] += final_prompt

    # Generate the final answer in a worker thread.
    t = threading.Thread(
        target=pipe,
        args=(final_messages,),
        kwargs=dict(
            max_new_tokens=final_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    t.start()

    # Stream the final answer into the new history entry.
    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()

    yield history
# --- Gradio UI: side-by-side comparison of plain vs. reasoning-augmented chat ---
# (All user-facing labels/examples are mojibake Korean, preserved verbatim.)
with gr.Blocks(fill_height=True, title="Vidraft ThinkFlow") as demo:
    # Title and description
    gr.Markdown("# Vidraft ThinkFlow")
    gr.Markdown("### μΆλ‘ κΈ°λ₯μ΄ μλ LLM λͺ¨λΈμ μμ μμ΄λ μΆλ‘ κΈ°λ₯μ μλμΌλ‘ μ μ©νλ LLM μΆλ‘ μμ± νλ«νΌ")

    # Two chat panels: left = raw model, right = reasoning-augmented.
    with gr.Row(scale=1):
        with gr.Column(scale=2):
            gr.Markdown("## Before (Original)")
            chatbot_original = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Original Model (No Reasoning)"
            )
        with gr.Column(scale=2):
            gr.Markdown("## After (Thinking)")
            chatbot_thinking = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Model with Reasoning"
            )

    with gr.Row():
        # Define the msg textbox first (the Examples block below targets it).
        msg = gr.Textbox(
            submit_btn=True,
            label="",
            show_label=False,
            placeholder="μ¬κΈ°μ μ§λ¬Έμ μ λ ₯νμΈμ.",
            autofocus=True,
        )

    # Examples section — placed after the msg definition it feeds.
    with gr.Accordion("EXAMPLES", open=False):
        examples = gr.Examples(
            examples=[
                "[μΆμ²: MATH-500)] μ²μ 100κ°μ μμ μ μ μ€μμ 3, 4, 5λ‘ λλμ΄ λ¨μ΄μ§λ μλ λͺ κ°μ λκΉ?",
                "[μΆμ²: MATH-500)] μν¬μ λ μμ λ μμ€ν μ λ νΉν©λλ€. νΈλ§ν· 1κ°λ λΈλ§ν· 4κ°μ κ°κ³ , λΈλ§ν· 3κ°λ λλ§ν¬ 7κ°μ κ°μ΅λλ€. νΈλ§ν·μμ λλ§ν¬ 56κ°μ κ°μΉλ μΌλ§μ λκΉ?",
                "[μΆμ²: MATH-500)] μμ΄λ―Έ, λ²€, ν¬λ¦¬μ€μ νκ· λμ΄λ 6μ΄μ λλ€. 4λ μ ν¬λ¦¬μ€λ μ§κΈ μμ΄λ―Έμ κ°μ λμ΄μμ΅λλ€. 4λ ν λ²€μ λμ΄λ κ·Έλ μμ΄λ―Έμ λμ΄μ $\\frac{3}{5}$κ° λ κ²μ λλ€. ν¬λ¦¬μ€λ μ§κΈ λͺ μ΄μ λκΉ?",
                "[μΆμ²: MATH-500)] λ Έλμκ³Ό νλμ ꡬμ¬μ΄ λ€μ΄ μλ κ°λ°©μ΄ μμ΅λλ€. νμ¬ νλμ ꡬμ¬κ³Ό λ Έλμ ꡬμ¬μ λΉμ¨μ 4:3μ λλ€. νλμ κ΅¬μ¬ 5κ°λ₯Ό λνκ³ λ Έλμ κ΅¬μ¬ 3κ°λ₯Ό μ κ±°νλ©΄ λΉμ¨μ 7:3μ΄ λ©λλ€. λ λ£κΈ° μ μ κ°λ°©μ νλμ ꡬμ¬μ΄ λͺ κ° μμμ΅λκΉ?"
            ],
            inputs=msg
        )

    # Generation parameter controls.
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## λ§€κ°λ³μ μ‘°μ """)
            num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="μΆλ‘ λ¨κ³λΉ μ΅λ ν ν° μ",
                interactive=True,
            )
            final_num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="μ΅μ’ λ΅λ³μ μ΅λ ν ν° μ",
                interactive=True,
            )
            do_sample = gr.Checkbox(True, label="μνλ§ μ¬μ©")
            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="μ¨λ")

    # On submit: record the user turn in both panels, then run the two bots
    # sequentially (original first, then reasoning-augmented).
    msg.submit(
        user_input,
        [msg, chatbot_original, chatbot_thinking],  # inputs
        [msg, chatbot_original, chatbot_thinking],  # outputs
    ).then(
        bot_original,
        [
            chatbot_original,
            num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_original,  # save the new history in the output
    ).then(
        bot_thinking,
        [
            chatbot_thinking,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_thinking,  # save the new history in the output
    )

if __name__ == "__main__":
    # queue() enables streaming/generator handlers; launch() starts the server.
    demo.queue().launch()