""" ZeroGPU-friendly Gradio entrypoint for OMada demo. - Downloads checkpoint + assets + style centroids from Hugging Face Hub - Instantiates OmadaDemo once (global) - Exposes 10 modalities via Gradio tabs - Uses @spaces.GPU only on inference handlers so GPU is allocated per request """ import os import sys import subprocess import importlib from pathlib import Path from typing import List import gradio as gr import spaces from packaging.version import parse as parse_version # --------------------------- # Project roots & sys.path # --------------------------- PROJECT_ROOT = Path(__file__).resolve().parent MMADA_ROOT = PROJECT_ROOT / "MMaDA" if str(MMADA_ROOT) not in sys.path: sys.path.insert(0, str(MMADA_ROOT)) EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer" if str(EMOVA_ROOT) not in sys.path: sys.path.insert(0, str(EMOVA_ROOT)) # --------------------------- # HuggingFace Hub helper # --------------------------- def ensure_hf_hub(target: str = "0.36.0"): """ Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers. """ try: import huggingface_hub as hub except ImportError: subprocess.check_call( [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"] ) import huggingface_hub as hub if parse_version(hub.__version__) >= parse_version("1.0.0"): subprocess.check_call( [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"] ) hub = importlib.reload(hub) # Backfill missing constants in older hub versions to avoid AttributeError. try: import huggingface_hub.constants as hub_consts # type: ignore except Exception: hub_consts = None if hub_consts and not hasattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER"): setattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER", False) return hub snapshot_download = ensure_hf_hub().snapshot_download # --------------------------- # OMada demo imports # --------------------------- from inference.gradio_multimodal_demo_inst import ( # noqa: E402 OmadaDemo, CUSTOM_CSS, FORCE_LIGHT_MODE_JS, ) # --------------------------- # HF download helpers # --------------------------- def download_assets() -> Path: """Download demo assets (logo + sample prompts/media) and return the root path.""" repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets") revision = os.getenv("ASSET_REVISION", "main") token = os.getenv("HF_TOKEN") cache_dir = PROJECT_ROOT / "_asset_cache" cache_dir.mkdir(parents=True, exist_ok=True) return Path( snapshot_download( repo_id=repo_id, revision=revision, repo_type="dataset", local_dir=cache_dir, local_dir_use_symlinks=False, token=token, ) ) def download_style() -> Path: """Download style centroid dataset and return the root path.""" repo_id = os.getenv("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid") revision = os.getenv("STYLE_REVISION", "main") token = os.getenv("HF_TOKEN") cache_dir = PROJECT_ROOT / "_style_cache" cache_dir.mkdir(parents=True, exist_ok=True) return Path( snapshot_download( repo_id=repo_id, revision=revision, repo_type="dataset", local_dir=cache_dir, local_dir_use_symlinks=False, token=token, ) ) def download_checkpoint() -> Path: """Download checkpoint snapshot and return an `unwrapped_model` directory.""" local_override = os.getenv("MODEL_CHECKPOINT_PATH") if local_override: override_path = Path(local_override).expanduser() if override_path.name != "unwrapped_model": nested = override_path / "unwrapped_model" if nested.is_dir(): override_path = nested if not override_path.exists(): raise FileNotFoundError(f"MODEL_CHECKPOINT_PATH does not exist: {override_path}") return override_path repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion") revision = os.getenv("MODEL_REVISION", "main") token = os.getenv("HF_TOKEN") cache_dir = PROJECT_ROOT / "_ckpt_cache" cache_dir.mkdir(parents=True, exist_ok=True) snapshot_path = Path( snapshot_download( repo_id=repo_id, revision=revision, repo_type="model", local_dir=cache_dir, local_dir_use_symlinks=False, token=token, ) ) if snapshot_path.name == "unwrapped_model": return snapshot_path nested = snapshot_path / "unwrapped_model" if nested.is_dir(): return nested aliased = snapshot_path.parent / "unwrapped_model" if not aliased.exists(): aliased.symlink_to(snapshot_path, target_is_directory=True) return aliased # --------------------------- # Assets (for examples + logo) # --------------------------- ASSET_ROOT = download_assets() STYLE_ROOT = download_style() LOGO_PATH = ASSET_ROOT / "logo.png" # optional def _load_text_examples(path: Path): if not path.exists(): return [] lines = [ ln.strip() for ln in path.read_text(encoding="utf-8").splitlines() if ln.strip() ] return [[ln] for ln in lines] def _load_media_examples(subdir: str, suffixes): d = ASSET_ROOT / subdir if not d.exists(): return [] ex = [] for p in sorted(d.iterdir()): if p.is_file() and p.suffix.lower() in suffixes: ex.append([str(p)]) return ex def _load_i2i_examples(): d = ASSET_ROOT / "i2i" if not d.exists(): return [] # 이미지 파일들 (image1.jpeg, image2.png, ...) image_files = sorted( [p for p in d.iterdir() if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}] ) # 텍스트 파일들 (text1.txt, text2.txt, ...) text_files = sorted( [p for p in d.iterdir() if p.suffix.lower() == ".txt"] ) n = min(len(image_files), len(text_files)) examples = [] for i in range(2): img_path = image_files[i] txt_path = text_files[i] instruction = txt_path.read_text(encoding="utf-8").strip() if not instruction: continue # Gradio Examples 형식: [image, instruction_text] examples.append([str(img_path), instruction]) return examples def _load_ti2ti_examples(): """Load TI2TI examples: pairs of source image + instruction text.""" d = ASSET_ROOT / "ti2ti" if not d.exists(): return [] src_files = sorted( [p for p in d.iterdir() if p.is_file() and p.name.endswith("_src.png")], ) txt_files = {p.name.replace("_instr.txt", ""): p for p in d.iterdir() if p.is_file() and p.name.endswith("_instr.txt")} examples = [] for src in src_files: stem = src.name.replace("_src.png", "") txt = txt_files.get(stem) if not txt: continue instruction = txt.read_text(encoding="utf-8").strip() if not instruction: continue examples.append([str(src), instruction]) return examples # text-based examples T2S_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2s" / "text.txt") CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt") T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt") I2I_EXAMPLES = _load_i2i_examples() TI2TI_EXAMPLES = _load_ti2ti_examples() # audio / video / image examples S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"}) S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"}) V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) V2S_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) # MMU images (and fallback for I2S) MMU_DIR = ASSET_ROOT / "mmu" MMU_EXAMPLES: List[List[str]] = [] if MMU_DIR.exists(): for path in sorted( [ p for p in MMU_DIR.iterdir() if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"} ] ): MMU_EXAMPLES.append([ str(path), "Describe the important objects and their relationships in this image.", ]) I2S_EXAMPLES = _load_media_examples("i2s", {".png", ".jpg", ".jpeg", ".webp"}) if not I2S_EXAMPLES and MMU_EXAMPLES: # use the first MMU sample image if no dedicated I2S example exists I2S_EXAMPLES = [[MMU_EXAMPLES[0][0]]] # --------------------------- # Global OmadaDemo instance # --------------------------- APP = None # type: ignore def get_app() -> OmadaDemo: global APP if APP is not None: return APP ckpt_dir = download_checkpoint() # Wire style centroids to expected locations style_targets = [ MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid", PROJECT_ROOT / "EMOVA_speech_tokenizer" / "emova_speech_tokenizer" / "speech_tokenization" / "condition_style_centroid", ] for starget in style_targets: if not starget.exists(): starget.parent.mkdir(parents=True, exist_ok=True) starget.symlink_to(STYLE_ROOT, target_is_directory=True) default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml" legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml" train_config = os.getenv("TRAIN_CONFIG_PATH") if not train_config: train_config = str(default_cfg if default_cfg.exists() else legacy_cfg) device = os.getenv("DEVICE", "cuda") APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device) return APP # --------------------------- # ZeroGPU-wrapped handlers # --------------------------- # (== 그대로, 생략 없이 둔 부분 ==) @spaces.GPU def t2s_handler(text, max_tokens, steps, block_len, temperature, cfg_scale, gender, emotion, speed, pitch): app = get_app() audio, status = app.run_t2s( text=text, max_new_tokens=int(max_tokens), steps=int(steps), block_length=int(block_len), temperature=float(temperature), cfg_scale=float(cfg_scale), gender_choice=gender, emotion_choice=emotion, speed_choice=speed, pitch_choice=pitch, ) return audio, status @spaces.GPU def s2s_handler(audio_path, max_tokens, steps, block_len, temperature, cfg_scale): app = get_app() audio, status = app.run_s2s( audio_path=audio_path, max_new_tokens=int(max_tokens), steps=int(steps), block_length=int(block_len), temperature=float(temperature), cfg_scale=float(cfg_scale), ) return audio, status @spaces.GPU def s2t_handler(audio_path, steps, block_len, max_tokens, remasking): app = get_app() text, status = app.run_s2t( audio_path=audio_path, steps=int(steps), block_length=int(block_len), max_new_tokens=int(max_tokens), remasking=str(remasking), ) return text, status @spaces.GPU def v2t_handler(video, steps, block_len, max_tokens): app = get_app() text, status = app.run_v2t( video_path=video, steps=int(steps), block_length=int(block_len), max_new_tokens=int(max_tokens), ) return text, status @spaces.GPU def v2s_handler(video, message, max_tokens, steps, block_len, temperature, cfg_scale): app = get_app() audio, status = app.run_v2s( video_path=video, message=message, max_new_tokens=int(max_tokens), steps=int(steps), block_length=int(block_len), temperature=float(temperature), cfg_scale=float(cfg_scale), ) return audio, status @spaces.GPU def i2s_handler(image, message, max_tokens, steps, block_len, temperature, cfg_scale): app = get_app() audio, status = app.run_i2s( image=image, message=message, max_new_tokens=int(max_tokens), steps=int(steps), block_length=int(block_len), temperature=float(temperature), cfg_scale=float(cfg_scale), ) return audio, status @spaces.GPU def chat_handler(message, max_tokens, steps, block_len, temperature): app = get_app() text, status = app.run_chat( message=message, max_new_tokens=int(max_tokens), steps=int(steps), block_length=int(block_len), temperature=float(temperature), ) return text, status @spaces.GPU def mmu_handler(image, question, max_tokens, steps, block_len, temperature): app = get_app() text, status = app.run_mmu( images=image, message=question, max_new_tokens=int(max_tokens), steps=int(steps), block_length=int(block_len), temperature=float(temperature), ) return text, status @spaces.GPU def t2i_handler(prompt, timesteps, temperature, guidance): app = get_app() image, status = app.run_t2i( prompt=prompt, timesteps=int(timesteps), temperature=float(temperature), guidance_scale=float(guidance), ) return image, status @spaces.GPU def i2i_handler(instruction, image, timesteps, temperature, guidance): app = get_app() image_out, status = app.run_i2i( instruction=instruction, source_image=image, timesteps=int(timesteps), temperature=float(temperature), guidance_scale=float(guidance), ) return image_out, status @spaces.GPU def ti2ti_handler(instruction, image, text_tokens, timesteps_image, timesteps_text, temperature, guidance): app = get_app() image_out, text_out, status = app.run_ti2ti( instruction=instruction, source_image=image, text_tokens=int(text_tokens), timesteps_image=int(timesteps_image), timesteps_text=int(timesteps_text), temperature=float(temperature), guidance_scale=float(guidance), ) return image_out, text_out, status # --------------------------- # Gradio UI (10 tabs + examples) # --------------------------- theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray") with gr.Blocks( title="AIDAS Lab @ SNU - Omni-modal Diffusion", css=CUSTOM_CSS, theme=theme, js=FORCE_LIGHT_MODE_JS, ) as demo: with gr.Row(): if LOGO_PATH.exists(): gr.Image( value=str(LOGO_PATH), show_label=False, height=80, interactive=False, ) gr.Markdown( "## Omni-modal Diffusion Foundation Model\n" "### AIDAS Lab @ SNU" ) # ---- T2S ---- with gr.Tab("Text → Speech (T2S)"): with gr.Row(): t2s_text = gr.Textbox( label="Input text", lines=4, placeholder="Type the speech you want to synthesize...", ) t2s_audio = gr.Audio(label="Generated speech", type="numpy") t2s_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length") t2s_steps = gr.Slider(2, 512, value=128, step=2, label="Total refinement steps") t2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length") t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature") t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="CFG scale") with gr.Row(): t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="Gender") t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="Emotion") with gr.Row(): t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed") t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch") if T2S_EXAMPLES: with gr.Accordion("Sample prompts", open=False): gr.Examples( examples=T2S_EXAMPLES, inputs=[t2s_text], examples_per_page=6, ) t2s_btn = gr.Button("Generate speech", variant="primary") t2s_btn.click( t2s_handler, inputs=[ t2s_text, t2s_max_tokens, t2s_steps, t2s_block, t2s_temperature, t2s_cfg, t2s_gender, t2s_emotion, t2s_speed, t2s_pitch, ], outputs=[t2s_audio, t2s_status], ) # ---- S2S ---- with gr.Tab("Speech → Speech (S2S)"): s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"]) s2s_audio_out = gr.Audio(type="numpy", label="Reply speech") s2s_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): s2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length") s2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps") s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length") s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature") s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale") if S2S_EXAMPLES: with gr.Accordion("Sample clips", open=False): gr.Examples( examples=S2S_EXAMPLES, inputs=[s2s_audio_in], examples_per_page=4, ) s2s_btn = gr.Button("Generate reply speech", variant="primary") s2s_btn.click( s2s_handler, inputs=[ s2s_audio_in, s2s_max_tokens, s2s_steps, s2s_block, s2s_temperature, s2s_cfg, ], outputs=[s2s_audio_out, s2s_status], ) # ---- S2T ---- with gr.Tab("Speech → Text (S2T)"): s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"]) s2t_text_out = gr.Textbox(label="Transcription", lines=4) s2t_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): s2t_steps = gr.Slider(2, 512, value=128, step=2, label="Denoising steps") s2t_block = gr.Slider(2, 512, value=128, step=2, label="Block length") s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="Max new tokens") s2t_remasking = gr.Dropdown( ["low_confidence", "random"], value="low_confidence", label="Remasking strategy", ) if S2T_EXAMPLES: with gr.Accordion("Sample clips", open=False): gr.Examples( examples=S2T_EXAMPLES, inputs=[s2t_audio_in], examples_per_page=4, ) s2t_btn = gr.Button("Transcribe", variant="primary") s2t_btn.click( s2t_handler, inputs=[s2t_audio_in, s2t_steps, s2t_block, s2t_max_tokens, s2t_remasking], outputs=[s2t_text_out, s2t_status], ) # ---- V2T ---- with gr.Tab("Video → Text (V2T)"): v2t_video_in = gr.Video( label="Upload or record video", height=256, sources=["upload", "webcam"], ) v2t_text_out = gr.Textbox(label="Caption / answer", lines=4) v2t_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps") v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length") v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens") if V2T_EXAMPLES: with gr.Accordion("Sample videos", open=False): gr.Examples( examples=V2T_EXAMPLES, inputs=[v2t_video_in], examples_per_page=4, ) v2t_btn = gr.Button("Generate caption", variant="primary") v2t_btn.click( v2t_handler, inputs=[v2t_video_in, v2t_steps, v2t_block, v2t_max_tokens], outputs=[v2t_text_out, v2t_status], ) # ---- V2S ---- with gr.Tab("Video → Speech (V2S)"): v2s_video_in = gr.Video( label="Upload or record video", height=256, sources=["upload", "webcam"], ) v2s_prompt = gr.Textbox( label="Optional instruction", placeholder="(Optional) e.g., 'Describe this scene in spoken form.'", ) v2s_audio_out = gr.Audio(type="numpy", label="Generated speech") v2s_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): v2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length") v2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps") v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length") v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature") v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale") # (optional v2s examples: if you later add 'v2s' folder, same 패턴으로 붙이면 됨) if V2T_EXAMPLES: with gr.Accordion("Sample videos", open=False): gr.Examples( examples=V2T_EXAMPLES, inputs=[v2t_video_in], examples_per_page=4, ) v2s_btn = gr.Button("Generate speech from video", variant="primary") v2s_btn.click( v2s_handler, inputs=[ v2s_video_in, v2s_prompt, v2s_max_tokens, v2s_steps, v2s_block, v2s_temperature, v2s_cfg, ], outputs=[v2s_audio_out, v2s_status], ) # ---- T2I ---- with gr.Tab("Text → Image (T2I)"): t2i_prompt = gr.Textbox( label="Prompt", lines=4, placeholder="Describe the image you want to generate...", ) t2i_image_out = gr.Image(label="Generated image") t2i_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps") t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature") t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale") if T2I_EXAMPLES: with gr.Accordion("Sample prompts", open=False): gr.Examples( examples=T2I_EXAMPLES, inputs=[t2i_prompt], examples_per_page=6, ) t2i_btn = gr.Button("Generate image", variant="primary") t2i_btn.click( t2i_handler, inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance], outputs=[t2i_image_out, t2i_status], ) # ---- I2I ---- with gr.Tab("Image Editing (I2I)"): i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"]) i2i_instr = gr.Textbox( label="Editing instruction", lines=4, placeholder="Describe how you want to edit the image...", ) i2i_image_out = gr.Image(label="Edited image") i2i_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): i2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps") i2i_temperature = gr.Slider(0.0, 2.0, value=0.3, step=0.05, label="Sampling temperature") i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale") if I2I_EXAMPLES: with gr.Accordion("Sample edits", open=False): gr.Examples( examples=I2I_EXAMPLES, inputs=[i2i_image_in, i2i_instr], examples_per_page=4, ) i2i_btn = gr.Button("Apply edit", variant="primary") i2i_btn.click( i2i_handler, inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance], outputs=[i2i_image_out, i2i_status], ) # ---- TI2TI ---- with gr.Tab("Text+Image → Text+Image (TI2TI)"): ti2ti_image_in = gr.Image(type="pil", label="Source image", sources=["upload"]) ti2ti_instr = gr.Textbox( label="Editing instruction", lines=4, placeholder="Describe how you want the image edited and what to say about it...", ) ti2ti_image_out = gr.Image(label="Edited image") ti2ti_text_out = gr.Textbox(label="Generated text", lines=4) ti2ti_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): ti2ti_text_tokens = gr.Slider(8, 256, value=64, step=4, label="Text placeholder tokens") ti2ti_img_steps = gr.Slider(4, 128, value=64, step=2, label="Image timesteps") ti2ti_text_steps = gr.Slider(4, 128, value=64, step=2, label="Text timesteps") ti2ti_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature") ti2ti_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale") if TI2TI_EXAMPLES: with gr.Accordion("Sample edits", open=False): gr.Examples( examples=TI2TI_EXAMPLES, inputs=[ti2ti_image_in, ti2ti_instr], examples_per_page=4, ) ti2ti_btn = gr.Button("Generate edited image + text", variant="primary") ti2ti_btn.click( ti2ti_handler, inputs=[ ti2ti_instr, ti2ti_image_in, ti2ti_text_tokens, ti2ti_img_steps, ti2ti_text_steps, ti2ti_temperature, ti2ti_guidance, ], outputs=[ti2ti_image_out, ti2ti_text_out, ti2ti_status], ) # ---- I2S ---- with gr.Tab("Image → Speech (I2S)"): i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"]) i2s_prompt = gr.Textbox( label="Optional question", placeholder="(Optional) e.g., 'Describe this image aloud.'", ) i2s_audio_out = gr.Audio(type="numpy", label="Spoken description") i2s_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length") i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps") i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length") i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature") i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale") if I2S_EXAMPLES: with gr.Accordion("Sample images", open=False): gr.Examples( examples=I2S_EXAMPLES, inputs=[i2s_image_in], examples_per_page=4, ) i2s_btn = gr.Button("Generate spoken description", variant="primary") i2s_btn.click( i2s_handler, inputs=[ i2s_image_in, i2s_prompt, i2s_max_tokens, i2s_steps, i2s_block, i2s_temperature, i2s_cfg, ], outputs=[i2s_audio_out, i2s_status], ) # ---- Chat ---- with gr.Tab("Text Chat"): chat_in = gr.Textbox( label="Message", lines=4, placeholder="Ask anything. The model will reply in text.", ) chat_out = gr.Textbox(label="Assistant reply", lines=6) chat_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens") chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps") chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length") chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature") if CHAT_EXAMPLES: with gr.Accordion("Sample prompts", open=False): gr.Examples( examples=CHAT_EXAMPLES, inputs=[chat_in], examples_per_page=6, ) chat_btn = gr.Button("Send", variant="primary") chat_btn.click( chat_handler, inputs=[ chat_in, chat_max_tokens, chat_steps, chat_block, chat_temperature_slider, ], outputs=[chat_out, chat_status], ) # ---- MMU ---- with gr.Tab("MMU (Image → Text)"): mmu_img = gr.Image(type="pil", label="Input image", sources=["upload"]) mmu_question = gr.Textbox( label="Question", lines=3, placeholder="Ask about the scene, objects, or context of the image.", ) mmu_answer = gr.Textbox(label="Answer", lines=6) mmu_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("Advanced settings", open=False): mmu_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Answer max tokens") mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps") mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length") mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature") if MMU_EXAMPLES: with gr.Accordion("Sample MMU prompts", open=False): gr.Examples( examples=MMU_EXAMPLES, inputs=[mmu_img, mmu_question], examples_per_page=1, ) mmu_btn = gr.Button("Answer about the image", variant="primary") mmu_btn.click( mmu_handler, inputs=[ mmu_img, mmu_question, mmu_max_tokens, mmu_steps, mmu_block, mmu_temperature, ], outputs=[mmu_answer, mmu_status], ) if __name__ == "__main__": demo.launch()