Spaces: Running on Zero
import os
import sys
from pathlib import Path

# `spaces` must be imported before anything that might initialize CUDA on a
# ZeroGPU Space; keep it ahead of the heavier imports below.
import spaces  # noqa: F401
import gradio as gr

# === Import project modules ===
PROJECT_ROOT = Path(__file__).resolve().parent
MMADA_ROOT = PROJECT_ROOT / "MMaDA"
if str(MMADA_ROOT) not in sys.path:
    sys.path.insert(0, str(MMADA_ROOT))

# This import only resolves once MMADA_ROOT is on sys.path (done above).
from inference.gradio_multimodal_demo_inst import OmadaDemo

# ----------------------------------------------------------------------
# 1. Asset Loading (Downloaded by entrypoint)
# ----------------------------------------------------------------------
ASSET_ROOT = PROJECT_ROOT / "_asset_cache" / "AIDAS-Omni-Modal-Diffusion-assets"
DEMO_ROOT = ASSET_ROOT  # asset repo already modality-split
# ----------------------------------------------------------------------
# 2. GPU Handler Wrapper
# ----------------------------------------------------------------------
def gpu_handler(fn):
    """Wrap an inference function for ZeroGPU execution.

    When the Hugging Face ``spaces`` runtime is importable, *fn* is
    decorated with ``spaces.GPU`` so a GPU is attached for the duration of
    each call — without this decoration a ZeroGPU Space never receives a
    GPU at all.  Outside Spaces (or if decoration fails) the function is
    used unchanged, which keeps the demo runnable on plain CPU machines.

    Parameters
    ----------
    fn : callable
        The inference function to wrap.

    Returns
    -------
    callable
        A wrapper with *fn*'s metadata that forwards all arguments.
    """
    import functools

    try:
        import spaces  # local import: only present on HF Spaces
        fn = spaces.GPU(fn)  # request a ZeroGPU slot per invocation
    except Exception:
        pass  # best-effort fallback: run directly (CPU / local dev)

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner
# ----------------------------------------------------------------------
# 3. Build Demo UI With Examples
# ----------------------------------------------------------------------
def _load_examples(subdir, pattern, *, as_text=False, extra=None):
    """Collect ``gr.Examples`` rows from ``DEMO_ROOT / subdir``.

    Each file matching *pattern* becomes one row: the file's stripped text
    content when *as_text* is True, otherwise its path as a string.  When
    *extra* is given it is appended as a constant second column (e.g. a
    default prompt).  Returns ``[]`` when the directory is missing; rows
    are sorted so the example order is deterministic across filesystems.
    """
    directory = DEMO_ROOT / subdir
    if not directory.exists():
        return []
    rows = []
    for path in sorted(directory.glob(pattern)):
        cell = path.read_text().strip() if as_text else str(path)
        rows.append([cell] if extra is None else [cell, extra])
    return rows


def _wire(button, fn, inputs, outputs, examples):
    """Register optional examples and attach the GPU-wrapped *fn* to *button*."""
    handler = gpu_handler(fn)
    if examples:
        gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=handler)
    button.click(handler, inputs=inputs, outputs=outputs)


def build_zero_gpu_demo(app: OmadaDemo):
    """Build the Gradio Blocks UI: one tab per modality pair of *app*.

    Each tab wires its "run" button (and any ``gr.Examples`` rows found
    under ``DEMO_ROOT``) to the matching ``app.run_*`` method through
    :func:`gpu_handler`.

    Returns
    -------
    gr.Blocks
        The assembled (not yet launched) demo.
    """
    with gr.Blocks(title="AIDAS Omni-Modal Diffusion (ZeroGPU)") as demo:
        # ---------------- Header ----------------
        gr.Markdown(
            "<h1 style='text-align:center'>AIDAS Omni-Modal Diffusion Model</h1>"
        )
        # Best-effort logo: show it only when the file exists, rather than
        # hiding component-construction errors behind a bare ``except``.
        logo_path = "/mnt/data/A2E36E9F-F389-487D-9984-FFF21C9228E3.png"
        if os.path.exists(logo_path):
            gr.Image(logo_path, elem_id="logo", show_label=False, height=120)
        gr.Markdown("### Multimodal Inference Demo (ZeroGPU Optimized)")
        gr.Markdown("---")

        # ---------------- Tabs ----------------
        with gr.Tabs():
            # 1) TEXT → SPEECH (T2S)
            with gr.Tab("Text → Speech (T2S)"):
                t2s_in = gr.Textbox(label="Input Text")
                t2s_btn = gr.Button("Generate")
                t2s_audio = gr.Audio(label="Speech Output")
                t2s_status = gr.Textbox(label="Status", interactive=False)
                _wire(
                    t2s_btn,
                    app.run_t2s,
                    inputs=[t2s_in],
                    outputs=[t2s_audio, t2s_status],
                    examples=_load_examples("t2s", "*.txt", as_text=True),
                )

            # 2) SPEECH → SPEECH (S2S)
            with gr.Tab("Speech → Speech (S2S)"):
                s2s_in = gr.Audio(type="filepath", label="Input Speech")
                s2s_btn = gr.Button("Generate")
                s2s_audio = gr.Audio(label="Output Speech")
                s2s_status = gr.Textbox(label="Status", interactive=False)
                _wire(
                    s2s_btn,
                    app.run_s2s,
                    inputs=[s2s_in],
                    outputs=[s2s_audio, s2s_status],
                    examples=_load_examples("s2s", "*.wav"),
                )

            # 3) SPEECH → TEXT (S2T)
            with gr.Tab("Speech → Text (S2T)"):
                s2t_in = gr.Audio(type="filepath", label="Input Speech")
                s2t_btn = gr.Button("Transcribe")
                s2t_text = gr.Textbox(label="Transcribed Text")
                s2t_status = gr.Textbox(label="Status", interactive=False)
                _wire(
                    s2t_btn,
                    app.run_s2t,
                    inputs=[s2t_in],
                    outputs=[s2t_text, s2t_status],
                    examples=_load_examples("s2t", "*.wav"),
                )

            # 4) VIDEO → TEXT (V2T)
            # NOTE(review): gr.Video dropped the ``type`` kwarg in Gradio 4;
            # this matches the Gradio 3 API — confirm the Space's pinned
            # gradio version.
            with gr.Tab("Video → Text (V2T)"):
                v2t_in = gr.Video(type="filepath", label="Input Video")
                v2t_btn = gr.Button("Generate Caption")
                v2t_text = gr.Textbox(label="Caption")
                v2t_status = gr.Textbox(label="Status")
                _wire(
                    v2t_btn,
                    app.run_v2t,
                    inputs=[v2t_in],
                    outputs=[v2t_text, v2t_status],
                    examples=_load_examples("v2t", "*.mp4"),
                )

            # 5) VIDEO → SPEECH (V2S)
            with gr.Tab("Video → Speech (V2S)"):
                v2s_in = gr.Video(type="filepath", label="Input Video")
                v2s_btn = gr.Button("Generate Speech")
                v2s_audio = gr.Audio(label="Speech Output")
                v2s_status = gr.Textbox(label="Status")
                _wire(
                    v2s_btn,
                    app.run_v2s,
                    inputs=[v2s_in],
                    outputs=[v2s_audio, v2s_status],
                    examples=_load_examples("v2s", "*.mp4"),
                )

            # 6) IMAGE → SPEECH (I2S)
            with gr.Tab("Image → Speech (I2S)"):
                i2s_in = gr.Image(type="filepath", label="Input Image")
                i2s_btn = gr.Button("Generate Speech")
                i2s_audio = gr.Audio(label="Speech")
                i2s_status = gr.Textbox(label="Status")
                _wire(
                    i2s_btn,
                    app.run_i2s,
                    inputs=[i2s_in],
                    outputs=[i2s_audio, i2s_status],
                    examples=_load_examples("i2s", "*.*"),
                )

            # 7) CHAT (text only)
            with gr.Tab("Chat (Text)"):
                chat_in = gr.Textbox(label="Message")
                chat_btn = gr.Button("Send")
                chat_out = gr.Textbox(label="Response")
                chat_status = gr.Textbox(label="Status")
                _wire(
                    chat_btn,
                    app.run_chat,
                    inputs=[chat_in],
                    outputs=[chat_out, chat_status],
                    examples=_load_examples("chat", "*.txt", as_text=True),
                )

            # 8) MMU (single image → text)
            with gr.Tab("MMU (Image → Text)"):
                mmu_img = gr.Image(type="filepath", label="Input Image")
                mmu_prompt = gr.Textbox(label="Prompt")
                mmu_btn = gr.Button("Run MMU")
                mmu_out = gr.Textbox(label="Output")
                mmu_status = gr.Textbox(label="Status")
                _wire(
                    mmu_btn,
                    app.run_mmu,
                    inputs=[mmu_img, mmu_prompt],
                    outputs=[mmu_out, mmu_status],
                    examples=_load_examples(
                        "mmu",
                        "*.png",
                        extra="Describe the main subject of this image.",
                    ),
                )

            # 9) TEXT → IMAGE (T2I)
            with gr.Tab("Text → Image (T2I)"):
                t2i_in = gr.Textbox(label="Prompt")
                t2i_btn = gr.Button("Generate Image")
                t2i_img = gr.Image(label="Generated Image")
                t2i_status = gr.Textbox(label="Status")
                _wire(
                    t2i_btn,
                    app.run_t2i,
                    inputs=[t2i_in],
                    outputs=[t2i_img, t2i_status],
                    examples=_load_examples("t2i", "*.txt", as_text=True),
                )

            # 10) IMAGE EDITING (I2I)
            with gr.Tab("Image Editing (I2I)"):
                i2i_in = gr.Image(type="filepath", label="Input Image")
                i2i_prompt = gr.Textbox(label="Edit Instruction")
                i2i_btn = gr.Button("Apply Edit")
                i2i_img = gr.Image(label="Edited Image")
                i2i_status = gr.Textbox(label="Status")
                _wire(
                    i2i_btn,
                    app.run_i2i,
                    inputs=[i2i_in, i2i_prompt],
                    outputs=[i2i_img, i2i_status],
                    examples=_load_examples(
                        "i2i", "*.*", extra="Make it more vibrant."
                    ),
                )
        # End Tabs
    return demo
# ----------------------------------------------------------------------
# 4. Entry Point for Space
# ----------------------------------------------------------------------
def main():
    """Instantiate the model wrapper and launch the Gradio app.

    The model is loaded on CPU; individual inference calls are moved onto
    a device by ``gpu_handler``, so nothing is held while the app idles.
    The checkpoint directory can be overridden via ``MODEL_CHECKPOINT_DIR``.
    """
    checkpoint_dir = os.getenv("MODEL_CHECKPOINT_DIR", "_ckpt_cache/omada")
    demo_app = OmadaDemo(
        train_config=str(MMADA_ROOT / "inference/demo/demo.yaml"),
        checkpoint=checkpoint_dir,
        device="cpu",
    )
    ui = build_zero_gpu_demo(demo_app)
    # Bind to all interfaces on the standard Spaces port.
    ui.launch(server_name="0.0.0.0", server_port=7860, share=False)


if __name__ == "__main__":
    main()