import functools
import os
import sys
from pathlib import Path
import spaces
# === Import project modules ===
PROJECT_ROOT = Path(__file__).resolve().parent
MMADA_ROOT = PROJECT_ROOT / "MMaDA"
if str(MMADA_ROOT) not in sys.path:
sys.path.insert(0, str(MMADA_ROOT))
from inference.gradio_multimodal_demo_inst import OmadaDemo
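# OmadaDemo exposes one inference method per modality (run_t2s, run_s2s,
# run_s2t, run_v2t, run_v2s, run_i2s, run_chat, run_mmu, run_t2i, run_i2i);
# these are wired to the tabs built below.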
import gradio as gr
# ----------------------------------------------------------------------
# 1. Asset Loading (Downloaded by entrypoint)
# ----------------------------------------------------------------------
ASSET_ROOT = PROJECT_ROOT / "_asset_cache" / "AIDAS-Omni-Modal-Diffusion-assets"
DEMO_ROOT = ASSET_ROOT # asset repo already modality-split
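# Expected layout (inferred from the example-loading code below): one subfolder
# per task under the asset repo (t2s/, s2s/, s2t/, v2t/, v2s/, i2s/, chat/,
# mmu/, t2i/, i2i/), each holding example .txt, .wav, .mp4, or image files.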
# ----------------------------------------------------------------------
# 2. GPU Handler Wrapper
# ----------------------------------------------------------------------
def gpu_handler(fn):
    """
    Wrap an inference function with `spaces.GPU` so that a ZeroGPU device is
    allocated only for the duration of each call.
    """
    @spaces.GPU
    @functools.wraps(fn)  # preserve the wrapped function's name for Gradio endpoints
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner
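# Each Gradio event below gets its own wrapped handler, e.g.
# gpu_handler(app.run_t2s), so GPU time is requested per request rather than
# being held for the lifetime of the app.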
# ----------------------------------------------------------------------
# 3. Build Demo UI With Examples
# ----------------------------------------------------------------------
def build_zero_gpu_demo(app: OmadaDemo):
with gr.Blocks(title="AIDAS Omni-Modal Diffusion (ZeroGPU)") as demo:
# ---------------- Header ----------------
gr.Markdown(
"<h1 style='text-align:center'>AIDAS Omni-Modal Diffusion Model</h1>"
)
        # Optional logo; skipped if the bundled asset is missing.
        logo_path = "/mnt/data/A2E36E9F-F389-487D-9984-FFF21C9228E3.png"
        if os.path.exists(logo_path):
            gr.Image(logo_path, elem_id="logo", show_label=False, height=120)
gr.Markdown("### Multimodal Inference Demo (ZeroGPU Optimized)")
gr.Markdown("---")
# ---------------- Tabs ----------------
with gr.Tabs():
# ============================================================
# 1) TEXT β†’ SPEECH (T2S)
# ============================================================
with gr.Tab("Text β†’ Speech (T2S)"):
t2s_in = gr.Textbox(label="Input Text")
t2s_btn = gr.Button("Generate")
t2s_audio = gr.Audio(label="Speech Output")
t2s_status = gr.Textbox(label="Status", interactive=False)
t2s_examples = []
t2s_dir = DEMO_ROOT / "t2s"
if t2s_dir.exists():
for f in t2s_dir.glob("*.txt"):
txt = f.read_text().strip()
t2s_examples.append([txt])
if len(t2s_examples) > 0:
gr.Examples(
examples=t2s_examples,
inputs=[t2s_in],
outputs=[t2s_audio, t2s_status],
fn=gpu_handler(app.run_t2s),
)
t2s_btn.click(
gpu_handler(app.run_t2s),
inputs=[t2s_in],
outputs=[t2s_audio, t2s_status],
)
# ============================================================
# 2) SPEECH β†’ SPEECH (S2S)
# ============================================================
with gr.Tab("Speech β†’ Speech (S2S)"):
s2s_in = gr.Audio(type="filepath", label="Input Speech")
s2s_btn = gr.Button("Generate")
s2s_audio = gr.Audio(label="Output Speech")
s2s_status = gr.Textbox(label="Status", interactive=False)
s2s_examples = []
s2s_dir = DEMO_ROOT / "s2s"
if s2s_dir.exists():
for f in s2s_dir.glob("*.wav"):
s2s_examples.append([str(f)])
if len(s2s_examples) > 0:
gr.Examples(
examples=s2s_examples,
inputs=[s2s_in],
outputs=[s2s_audio, s2s_status],
fn=gpu_handler(app.run_s2s),
)
s2s_btn.click(
gpu_handler(app.run_s2s),
inputs=[s2s_in],
outputs=[s2s_audio, s2s_status]
)
# ============================================================
# 3) SPEECH β†’ TEXT (S2T)
# ============================================================
with gr.Tab("Speech β†’ Text (S2T)"):
s2t_in = gr.Audio(type="filepath", label="Input Speech")
s2t_btn = gr.Button("Transcribe")
s2t_text = gr.Textbox(label="Transcribed Text")
s2t_status = gr.Textbox(label="Status", interactive=False)
s2t_examples = []
s2t_dir = DEMO_ROOT / "s2t"
if s2t_dir.exists():
for f in s2t_dir.glob("*.wav"):
s2t_examples.append([str(f)])
if len(s2t_examples) > 0:
gr.Examples(
examples=s2t_examples,
inputs=[s2t_in],
outputs=[s2t_text, s2t_status],
fn=gpu_handler(app.run_s2t),
)
s2t_btn.click(
gpu_handler(app.run_s2t),
inputs=[s2t_in],
outputs=[s2t_text, s2t_status],
)
# ============================================================
# 4) VIDEO β†’ TEXT (V2T)
# ============================================================
with gr.Tab("Video β†’ Text (V2T)"):
                v2t_in = gr.Video(label="Input Video")  # gr.Video passes a filepath by default; it takes no `type` kwarg
v2t_btn = gr.Button("Generate Caption")
v2t_text = gr.Textbox(label="Caption")
v2t_status = gr.Textbox(label="Status")
v2t_examples = []
v2t_dir = DEMO_ROOT / "v2t"
if v2t_dir.exists():
for f in v2t_dir.glob("*.mp4"):
v2t_examples.append([str(f)])
if len(v2t_examples) > 0:
gr.Examples(
examples=v2t_examples,
inputs=[v2t_in],
outputs=[v2t_text, v2t_status],
fn=gpu_handler(app.run_v2t),
)
v2t_btn.click(
gpu_handler(app.run_v2t),
inputs=[v2t_in],
outputs=[v2t_text, v2t_status],
)
# ============================================================
# 5) VIDEO β†’ SPEECH (V2S)
# ============================================================
with gr.Tab("Video β†’ Speech (V2S)"):
                v2s_in = gr.Video(label="Input Video")
v2s_btn = gr.Button("Generate Speech")
v2s_audio = gr.Audio(label="Speech Output")
v2s_status = gr.Textbox(label="Status")
v2s_examples = []
v2s_dir = DEMO_ROOT / "v2s"
if v2s_dir.exists():
for f in v2s_dir.glob("*.mp4"):
v2s_examples.append([str(f)])
if len(v2s_examples) > 0:
gr.Examples(
examples=v2s_examples,
inputs=[v2s_in],
outputs=[v2s_audio, v2s_status],
fn=gpu_handler(app.run_v2s),
)
v2s_btn.click(
gpu_handler(app.run_v2s),
inputs=[v2s_in],
outputs=[v2s_audio, v2s_status],
)
# ============================================================
# 6) IMAGE β†’ SPEECH (I2S)
# ============================================================
with gr.Tab("Image β†’ Speech (I2S)"):
i2s_in = gr.Image(type="filepath", label="Input Image")
i2s_btn = gr.Button("Generate Speech")
i2s_audio = gr.Audio(label="Speech")
i2s_status = gr.Textbox(label="Status")
# Only if folder exists
i2s_examples = []
i2s_dir = DEMO_ROOT / "i2s"
if i2s_dir.exists():
for f in i2s_dir.glob("*.*"):
i2s_examples.append([str(f)])
if len(i2s_examples) > 0:
gr.Examples(
examples=i2s_examples,
inputs=[i2s_in],
outputs=[i2s_audio, i2s_status],
fn=gpu_handler(app.run_i2s),
)
i2s_btn.click(
gpu_handler(app.run_i2s),
inputs=[i2s_in],
outputs=[i2s_audio, i2s_status],
)
# ============================================================
# 7) CHAT
# ============================================================
with gr.Tab("Chat (Text)"):
chat_in = gr.Textbox(label="Message")
chat_btn = gr.Button("Send")
chat_out = gr.Textbox(label="Response")
chat_status = gr.Textbox(label="Status")
chat_examples = []
chat_dir = DEMO_ROOT / "chat"
if chat_dir.exists():
for f in chat_dir.glob("*.txt"):
txt = f.read_text().strip()
chat_examples.append([txt])
if len(chat_examples) > 0:
gr.Examples(
examples=chat_examples,
inputs=[chat_in],
outputs=[chat_out, chat_status],
fn=gpu_handler(app.run_chat),
)
chat_btn.click(
gpu_handler(app.run_chat),
inputs=[chat_in],
outputs=[chat_out, chat_status],
)
# ============================================================
# 8) MMU (single image β†’ text)
# ============================================================
with gr.Tab("MMU (Image β†’ Text)"):
mmu_img = gr.Image(type="filepath", label="Input Image")
mmu_prompt = gr.Textbox(label="Prompt")
mmu_btn = gr.Button("Run MMU")
mmu_out = gr.Textbox(label="Output")
mmu_status = gr.Textbox(label="Status")
mmu_examples = []
mmu_dir = DEMO_ROOT / "mmu"
if mmu_dir.exists():
for f in mmu_dir.glob("*.png"):
mmu_examples.append([
str(f),
"Describe the main subject of this image."
])
if len(mmu_examples) > 0:
gr.Examples(
examples=mmu_examples,
inputs=[mmu_img, mmu_prompt],
outputs=[mmu_out, mmu_status],
fn=gpu_handler(app.run_mmu),
)
mmu_btn.click(
gpu_handler(app.run_mmu),
inputs=[mmu_img, mmu_prompt],
outputs=[mmu_out, mmu_status]
)
# ============================================================
# 9) TEXT β†’ IMAGE (T2I)
# ============================================================
with gr.Tab("Text β†’ Image (T2I)"):
t2i_in = gr.Textbox(label="Prompt")
t2i_btn = gr.Button("Generate Image")
t2i_img = gr.Image(label="Generated Image")
t2i_status = gr.Textbox(label="Status")
t2i_examples = []
t2i_dir = DEMO_ROOT / "t2i"
if t2i_dir.exists():
for f in t2i_dir.glob("*.txt"):
txt = f.read_text().strip()
t2i_examples.append([txt])
if len(t2i_examples) > 0:
gr.Examples(
examples=t2i_examples,
inputs=[t2i_in],
outputs=[t2i_img, t2i_status],
fn=gpu_handler(app.run_t2i),
)
t2i_btn.click(
gpu_handler(app.run_t2i),
inputs=[t2i_in],
outputs=[t2i_img, t2i_status],
)
# ============================================================
# 10) IMAGE EDITING (I2I)
# ============================================================
with gr.Tab("Image Editing (I2I)"):
i2i_in = gr.Image(type="filepath", label="Input Image")
i2i_prompt = gr.Textbox(label="Edit Instruction")
i2i_btn = gr.Button("Apply Edit")
i2i_img = gr.Image(label="Edited Image")
i2i_status = gr.Textbox(label="Status")
i2i_examples = []
i2i_dir = DEMO_ROOT / "i2i"
if i2i_dir.exists():
for f in i2i_dir.glob("*.*"):
i2i_examples.append([str(f), "Make it more vibrant."])
if len(i2i_examples) > 0:
gr.Examples(
examples=i2i_examples,
inputs=[i2i_in, i2i_prompt],
outputs=[i2i_img, i2i_status],
fn=gpu_handler(app.run_i2i),
)
i2i_btn.click(
gpu_handler(app.run_i2i),
inputs=[i2i_in, i2i_prompt],
outputs=[i2i_img, i2i_status]
)
# End Tabs
return demo
# ----------------------------------------------------------------------
# 4. Entry Point for Space
# ----------------------------------------------------------------------
# Note: main() itself must not be decorated with @spaces.GPU; only the
# per-call handlers created by gpu_handler() should request a ZeroGPU slot.
def main():
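    # The checkpoint path can be overridden via the MODEL_CHECKPOINT_DIR
    # environment variable (defaults to the local "_ckpt_cache/omada" cache).
    # Loading on CPU here keeps startup off the GPU; inference runs through the
    # gpu_handler-wrapped calls above.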
app = OmadaDemo(
train_config=str(MMADA_ROOT / "inference/demo/demo.yaml"),
checkpoint=os.getenv("MODEL_CHECKPOINT_DIR", "_ckpt_cache/omada"),
device="cpu"
)
    demo = build_zero_gpu_demo(app)
    demo.queue()  # queue requests; recommended for long-running GPU inference
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
if __name__ == "__main__":
main()