| """ | |
| ZeroGPU-friendly Gradio entrypoint for OMada demo. | |
| - Downloads checkpoint + assets + style centroids from Hugging Face Hub | |
| - Instantiates OmadaDemo once (global) | |
| - Exposes 10 modalities via Gradio tabs | |
| - Uses @spaces.GPU only on inference handlers so GPU is allocated per request | |
| """ | |
import os
import sys
import subprocess
import importlib
from pathlib import Path
from typing import List

import gradio as gr
import spaces  # import early so ZeroGPU can manage CUDA initialization
from packaging.version import parse as parse_version
# ---------------------------
# Project roots & sys.path
# ---------------------------
PROJECT_ROOT = Path(__file__).resolve().parent
MMADA_ROOT = PROJECT_ROOT / "MMaDA"
if str(MMADA_ROOT) not in sys.path:
    sys.path.insert(0, str(MMADA_ROOT))
EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
if str(EMOVA_ROOT) not in sys.path:
    sys.path.insert(0, str(EMOVA_ROOT))
# ---------------------------
# HuggingFace Hub helper
# ---------------------------
def ensure_hf_hub(target: str = "0.36.0"):
    """
    Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
    """
    try:
        import huggingface_hub as hub
    except ImportError:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
        )
        import huggingface_hub as hub
    if parse_version(hub.__version__) >= parse_version("1.0.0"):
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
        )
        hub = importlib.reload(hub)
    # Backfill missing constants in older hub versions to avoid AttributeError.
    try:
        import huggingface_hub.constants as hub_consts  # type: ignore
    except Exception:
        hub_consts = None
    if hub_consts is not None and not hasattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER"):
        setattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER", False)
    return hub
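# Resolve snapshot_download through the pinned hub module so every download below
# goes through the version check above.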
snapshot_download = ensure_hf_hub().snapshot_download
# ---------------------------
# OMada demo imports
# ---------------------------
from inference.gradio_multimodal_demo_inst import (  # noqa: E402
    OmadaDemo,
    CUSTOM_CSS,
    FORCE_LIGHT_MODE_JS,
)
# ---------------------------
# HF download helpers
# ---------------------------
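# Environment variables honored by the helpers below:
#   ASSET_REPO_ID / ASSET_REVISION  - demo assets dataset (logo, sample prompts/media)
#   STYLE_REPO_ID / STYLE_REVISION  - style centroid dataset
#   MODEL_REPO_ID / MODEL_REVISION  - model checkpoint repo
#   MODEL_CHECKPOINT_PATH           - local checkpoint override (skips the Hub download)
#   HF_TOKEN                        - auth token for private repos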
def download_assets() -> Path:
    """Download demo assets (logo + sample prompts/media) and return the root path."""
    repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
    revision = os.getenv("ASSET_REVISION", "main")
    token = os.getenv("HF_TOKEN")
    cache_dir = PROJECT_ROOT / "_asset_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            repo_type="dataset",
            local_dir=cache_dir,
            local_dir_use_symlinks=False,
            token=token,
        )
    )


def download_style() -> Path:
    """Download the style centroid dataset and return the root path."""
    repo_id = os.getenv("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid")
    revision = os.getenv("STYLE_REVISION", "main")
    token = os.getenv("HF_TOKEN")
    cache_dir = PROJECT_ROOT / "_style_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            repo_type="dataset",
            local_dir=cache_dir,
            local_dir_use_symlinks=False,
            token=token,
        )
    )
def download_checkpoint() -> Path:
    """Download the checkpoint snapshot and return an `unwrapped_model` directory."""
    local_override = os.getenv("MODEL_CHECKPOINT_PATH")
    if local_override:
        override_path = Path(local_override).expanduser()
        if override_path.name != "unwrapped_model":
            nested = override_path / "unwrapped_model"
            if nested.is_dir():
                override_path = nested
        if not override_path.exists():
            raise FileNotFoundError(f"MODEL_CHECKPOINT_PATH does not exist: {override_path}")
        return override_path
    repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
    revision = os.getenv("MODEL_REVISION", "main")
    token = os.getenv("HF_TOKEN")
    cache_dir = PROJECT_ROOT / "_ckpt_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            repo_type="model",
            local_dir=cache_dir,
            local_dir_use_symlinks=False,
            token=token,
        )
    )
    if snapshot_path.name == "unwrapped_model":
        return snapshot_path
    nested = snapshot_path / "unwrapped_model"
    if nested.is_dir():
        return nested
    aliased = snapshot_path.parent / "unwrapped_model"
    if not aliased.exists():
        aliased.symlink_to(snapshot_path, target_is_directory=True)
    return aliased
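# Example (hypothetical path): point the demo at a local checkpoint and skip the download:
#   MODEL_CHECKPOINT_PATH=/data/checkpoints/unwrapped_model python app.py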
# ---------------------------
# Assets (for examples + logo)
# ---------------------------
ASSET_ROOT = download_assets()
STYLE_ROOT = download_style()
LOGO_PATH = ASSET_ROOT / "logo.png"  # optional
def _load_text_examples(path: Path):
    if not path.exists():
        return []
    lines = [
        ln.strip()
        for ln in path.read_text(encoding="utf-8").splitlines()
        if ln.strip()
    ]
    return [[ln] for ln in lines]


def _load_media_examples(subdir: str, suffixes):
    d = ASSET_ROOT / subdir
    if not d.exists():
        return []
    ex = []
    for p in sorted(d.iterdir()):
        if p.is_file() and p.suffix.lower() in suffixes:
            ex.append([str(p)])
    return ex
def _load_i2i_examples():
    d = ASSET_ROOT / "i2i"
    if not d.exists():
        return []
    # Image files (image1.jpeg, image2.png, ...)
    image_files = sorted(
        [p for p in d.iterdir() if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}]
    )
    # Text files (text1.txt, text2.txt, ...)
    text_files = sorted(
        [p for p in d.iterdir() if p.suffix.lower() == ".txt"]
    )
    n = min(len(image_files), len(text_files))
    examples = []
    for i in range(n):
        img_path = image_files[i]
        txt_path = text_files[i]
        instruction = txt_path.read_text(encoding="utf-8").strip()
        if not instruction:
            continue
        # Gradio Examples format: [image, instruction_text]
        examples.append([str(img_path), instruction])
    return examples
def _load_ti2ti_examples():
    """Load TI2TI examples: pairs of source image + instruction text."""
    d = ASSET_ROOT / "ti2ti"
    if not d.exists():
        return []
    src_files = sorted(
        [p for p in d.iterdir() if p.is_file() and p.name.endswith("_src.png")],
    )
    txt_files = {
        p.name.replace("_instr.txt", ""): p
        for p in d.iterdir()
        if p.is_file() and p.name.endswith("_instr.txt")
    }
    examples = []
    for src in src_files:
        stem = src.name.replace("_src.png", "")
        txt = txt_files.get(stem)
        if not txt:
            continue
        instruction = txt.read_text(encoding="utf-8").strip()
        if not instruction:
            continue
        examples.append([str(src), instruction])
    return examples
# text-based examples
T2S_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2s" / "text.txt")
CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt")
T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
I2I_EXAMPLES = _load_i2i_examples()
TI2TI_EXAMPLES = _load_ti2ti_examples()
# audio / video / image examples
S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"})
V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
# no dedicated v2s asset folder yet, so V2S reuses the v2t clips
V2S_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
# MMU images (and fallback for I2S)
MMU_DIR = ASSET_ROOT / "mmu"
MMU_EXAMPLES: List[List[str]] = []
if MMU_DIR.exists():
    for path in sorted(
        [
            p
            for p in MMU_DIR.iterdir()
            if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}
        ]
    ):
        MMU_EXAMPLES.append([
            str(path),
            "Describe the important objects and their relationships in this image.",
        ])
I2S_EXAMPLES = _load_media_examples("i2s", {".png", ".jpg", ".jpeg", ".webp"})
if not I2S_EXAMPLES and MMU_EXAMPLES:
    # use the first MMU sample image if no dedicated I2S example exists
    I2S_EXAMPLES = [[MMU_EXAMPLES[0][0]]]
# ---------------------------
# Global OmadaDemo instance
# ---------------------------
APP = None  # type: ignore


def get_app() -> OmadaDemo:
    global APP
    if APP is not None:
        return APP
    ckpt_dir = download_checkpoint()
    # Wire style centroids to expected locations
    style_targets = [
        MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
        PROJECT_ROOT
        / "EMOVA_speech_tokenizer"
        / "emova_speech_tokenizer"
        / "speech_tokenization"
        / "condition_style_centroid",
    ]
    for starget in style_targets:
        if not starget.exists():
            starget.parent.mkdir(parents=True, exist_ok=True)
            starget.symlink_to(STYLE_ROOT, target_is_directory=True)
    default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
    legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
    train_config = os.getenv("TRAIN_CONFIG_PATH")
    if not train_config:
        train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
    device = os.getenv("DEVICE", "cuda")
    APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
    return APP
# ---------------------------
# ZeroGPU-wrapped handlers
# ---------------------------
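# Each handler below is wrapped with @spaces.GPU so ZeroGPU attaches a GPU only for
# the duration of the call; get_app() lazily builds the shared OmadaDemo on first use.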
@spaces.GPU
def t2s_handler(text, max_tokens, steps, block_len, temperature, cfg_scale, gender, emotion, speed, pitch):
    app = get_app()
    audio, status = app.run_t2s(
        text=text,
        max_new_tokens=int(max_tokens),
        steps=int(steps),
        block_length=int(block_len),
        temperature=float(temperature),
        cfg_scale=float(cfg_scale),
        gender_choice=gender,
        emotion_choice=emotion,
        speed_choice=speed,
        pitch_choice=pitch,
    )
    return audio, status


@spaces.GPU
def s2s_handler(audio_path, max_tokens, steps, block_len, temperature, cfg_scale):
    app = get_app()
    audio, status = app.run_s2s(
        audio_path=audio_path,
        max_new_tokens=int(max_tokens),
        steps=int(steps),
        block_length=int(block_len),
        temperature=float(temperature),
        cfg_scale=float(cfg_scale),
    )
    return audio, status


@spaces.GPU
def s2t_handler(audio_path, steps, block_len, max_tokens, remasking):
    app = get_app()
    text, status = app.run_s2t(
        audio_path=audio_path,
        steps=int(steps),
        block_length=int(block_len),
        max_new_tokens=int(max_tokens),
        remasking=str(remasking),
    )
    return text, status


@spaces.GPU
def v2t_handler(video, steps, block_len, max_tokens):
    app = get_app()
    text, status = app.run_v2t(
        video_path=video,
        steps=int(steps),
        block_length=int(block_len),
        max_new_tokens=int(max_tokens),
    )
    return text, status


@spaces.GPU
def v2s_handler(video, message, max_tokens, steps, block_len, temperature, cfg_scale):
    app = get_app()
    audio, status = app.run_v2s(
        video_path=video,
        message=message,
        max_new_tokens=int(max_tokens),
        steps=int(steps),
        block_length=int(block_len),
        temperature=float(temperature),
        cfg_scale=float(cfg_scale),
    )
    return audio, status


@spaces.GPU
def i2s_handler(image, message, max_tokens, steps, block_len, temperature, cfg_scale):
    app = get_app()
    audio, status = app.run_i2s(
        image=image,
        message=message,
        max_new_tokens=int(max_tokens),
        steps=int(steps),
        block_length=int(block_len),
        temperature=float(temperature),
        cfg_scale=float(cfg_scale),
    )
    return audio, status


@spaces.GPU
def chat_handler(message, max_tokens, steps, block_len, temperature):
    app = get_app()
    text, status = app.run_chat(
        message=message,
        max_new_tokens=int(max_tokens),
        steps=int(steps),
        block_length=int(block_len),
        temperature=float(temperature),
    )
    return text, status


@spaces.GPU
def mmu_handler(image, question, max_tokens, steps, block_len, temperature):
    app = get_app()
    text, status = app.run_mmu(
        images=image,
        message=question,
        max_new_tokens=int(max_tokens),
        steps=int(steps),
        block_length=int(block_len),
        temperature=float(temperature),
    )
    return text, status


@spaces.GPU
def t2i_handler(prompt, timesteps, temperature, guidance):
    app = get_app()
    image, status = app.run_t2i(
        prompt=prompt,
        timesteps=int(timesteps),
        temperature=float(temperature),
        guidance_scale=float(guidance),
    )
    return image, status


@spaces.GPU
def i2i_handler(instruction, image, timesteps, temperature, guidance):
    app = get_app()
    image_out, status = app.run_i2i(
        instruction=instruction,
        source_image=image,
        timesteps=int(timesteps),
        temperature=float(temperature),
        guidance_scale=float(guidance),
    )
    return image_out, status


@spaces.GPU
def ti2ti_handler(instruction, image, text_tokens, timesteps_image, timesteps_text, temperature, guidance):
    app = get_app()
    image_out, text_out, status = app.run_ti2ti(
        instruction=instruction,
        source_image=image,
        text_tokens=int(text_tokens),
        timesteps_image=int(timesteps_image),
        timesteps_text=int(timesteps_text),
        temperature=float(temperature),
        guidance_scale=float(guidance),
    )
    return image_out, text_out, status
# ---------------------------
# Gradio UI (10 tabs + examples)
# ---------------------------
theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")

with gr.Blocks(
    title="AIDAS Lab @ SNU - Omni-modal Diffusion",
    css=CUSTOM_CSS,
    theme=theme,
    js=FORCE_LIGHT_MODE_JS,
) as demo:
    with gr.Row():
        if LOGO_PATH.exists():
            gr.Image(
                value=str(LOGO_PATH),
                show_label=False,
                height=80,
                interactive=False,
            )
        gr.Markdown(
            "## Omni-modal Diffusion Foundation Model\n"
            "### AIDAS Lab @ SNU"
        )
    # ---- T2S ----
    with gr.Tab("Text → Speech (T2S)"):
        with gr.Row():
            t2s_text = gr.Textbox(
                label="Input text",
                lines=4,
                placeholder="Type the speech you want to synthesize...",
            )
            t2s_audio = gr.Audio(label="Generated speech", type="numpy")
        t2s_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length")
            t2s_steps = gr.Slider(2, 512, value=128, step=2, label="Total refinement steps")
            t2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
            t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
            t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="CFG scale")
            with gr.Row():
                t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="Gender")
                t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="Emotion")
            with gr.Row():
                t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
                t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
        if T2S_EXAMPLES:
            with gr.Accordion("Sample prompts", open=False):
                gr.Examples(
                    examples=T2S_EXAMPLES,
                    inputs=[t2s_text],
                    examples_per_page=6,
                )
        t2s_btn = gr.Button("Generate speech", variant="primary")
        t2s_btn.click(
            t2s_handler,
            inputs=[
                t2s_text,
                t2s_max_tokens,
                t2s_steps,
                t2s_block,
                t2s_temperature,
                t2s_cfg,
                t2s_gender,
                t2s_emotion,
                t2s_speed,
                t2s_pitch,
            ],
            outputs=[t2s_audio, t2s_status],
        )
    # ---- S2S ----
    with gr.Tab("Speech → Speech (S2S)"):
        s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
        s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
        s2s_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            s2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
            s2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
            s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
            s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature")
            s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale")
        if S2S_EXAMPLES:
            with gr.Accordion("Sample clips", open=False):
                gr.Examples(
                    examples=S2S_EXAMPLES,
                    inputs=[s2s_audio_in],
                    examples_per_page=4,
                )
        s2s_btn = gr.Button("Generate reply speech", variant="primary")
        s2s_btn.click(
            s2s_handler,
            inputs=[
                s2s_audio_in,
                s2s_max_tokens,
                s2s_steps,
                s2s_block,
                s2s_temperature,
                s2s_cfg,
            ],
            outputs=[s2s_audio_out, s2s_status],
        )
    # ---- S2T ----
    with gr.Tab("Speech → Text (S2T)"):
        s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
        s2t_text_out = gr.Textbox(label="Transcription", lines=4)
        s2t_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            s2t_steps = gr.Slider(2, 512, value=128, step=2, label="Denoising steps")
            s2t_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
            s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="Max new tokens")
            s2t_remasking = gr.Dropdown(
                ["low_confidence", "random"],
                value="low_confidence",
                label="Remasking strategy",
            )
        if S2T_EXAMPLES:
            with gr.Accordion("Sample clips", open=False):
                gr.Examples(
                    examples=S2T_EXAMPLES,
                    inputs=[s2t_audio_in],
                    examples_per_page=4,
                )
        s2t_btn = gr.Button("Transcribe", variant="primary")
        s2t_btn.click(
            s2t_handler,
            inputs=[s2t_audio_in, s2t_steps, s2t_block, s2t_max_tokens, s2t_remasking],
            outputs=[s2t_text_out, s2t_status],
        )
    # ---- V2T ----
    with gr.Tab("Video → Text (V2T)"):
        v2t_video_in = gr.Video(
            label="Upload or record video",
            height=256,
            sources=["upload", "webcam"],
        )
        v2t_text_out = gr.Textbox(label="Caption / answer", lines=4)
        v2t_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
            v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
            v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
        if V2T_EXAMPLES:
            with gr.Accordion("Sample videos", open=False):
                gr.Examples(
                    examples=V2T_EXAMPLES,
                    inputs=[v2t_video_in],
                    examples_per_page=4,
                )
        v2t_btn = gr.Button("Generate caption", variant="primary")
        v2t_btn.click(
            v2t_handler,
            inputs=[v2t_video_in, v2t_steps, v2t_block, v2t_max_tokens],
            outputs=[v2t_text_out, v2t_status],
        )
    # ---- V2S ----
    with gr.Tab("Video → Speech (V2S)"):
        v2s_video_in = gr.Video(
            label="Upload or record video",
            height=256,
            sources=["upload", "webcam"],
        )
        v2s_prompt = gr.Textbox(
            label="Optional instruction",
            placeholder="(Optional) e.g., 'Describe this scene in spoken form.'",
        )
        v2s_audio_out = gr.Audio(type="numpy", label="Generated speech")
        v2s_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            v2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
            v2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
            v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
            v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
            v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
        # V2S_EXAMPLES currently reuses the v2t clips; if you later add a dedicated
        # 'v2s' asset folder, wire it up in the same pattern.
        if V2S_EXAMPLES:
            with gr.Accordion("Sample videos", open=False):
                gr.Examples(
                    examples=V2S_EXAMPLES,
                    inputs=[v2s_video_in],
                    examples_per_page=4,
                )
        v2s_btn = gr.Button("Generate speech from video", variant="primary")
        v2s_btn.click(
            v2s_handler,
            inputs=[
                v2s_video_in,
                v2s_prompt,
                v2s_max_tokens,
                v2s_steps,
                v2s_block,
                v2s_temperature,
                v2s_cfg,
            ],
            outputs=[v2s_audio_out, v2s_status],
        )
    # ---- T2I ----
    with gr.Tab("Text → Image (T2I)"):
        t2i_prompt = gr.Textbox(
            label="Prompt",
            lines=4,
            placeholder="Describe the image you want to generate...",
        )
        t2i_image_out = gr.Image(label="Generated image")
        t2i_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
            t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
            t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
        if T2I_EXAMPLES:
            with gr.Accordion("Sample prompts", open=False):
                gr.Examples(
                    examples=T2I_EXAMPLES,
                    inputs=[t2i_prompt],
                    examples_per_page=6,
                )
        t2i_btn = gr.Button("Generate image", variant="primary")
        t2i_btn.click(
            t2i_handler,
            inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
            outputs=[t2i_image_out, t2i_status],
        )
    # ---- I2I ----
    with gr.Tab("Image Editing (I2I)"):
        i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
        i2i_instr = gr.Textbox(
            label="Editing instruction",
            lines=4,
            placeholder="Describe how you want to edit the image...",
        )
        i2i_image_out = gr.Image(label="Edited image")
        i2i_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            i2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
            i2i_temperature = gr.Slider(0.0, 2.0, value=0.3, step=0.05, label="Sampling temperature")
            i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
        if I2I_EXAMPLES:
            with gr.Accordion("Sample edits", open=False):
                gr.Examples(
                    examples=I2I_EXAMPLES,
                    inputs=[i2i_image_in, i2i_instr],
                    examples_per_page=4,
                )
        i2i_btn = gr.Button("Apply edit", variant="primary")
        i2i_btn.click(
            i2i_handler,
            inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
            outputs=[i2i_image_out, i2i_status],
        )
    # ---- TI2TI ----
    with gr.Tab("Text+Image → Text+Image (TI2TI)"):
        ti2ti_image_in = gr.Image(type="pil", label="Source image", sources=["upload"])
        ti2ti_instr = gr.Textbox(
            label="Editing instruction",
            lines=4,
            placeholder="Describe how you want the image edited and what to say about it...",
        )
        ti2ti_image_out = gr.Image(label="Edited image")
        ti2ti_text_out = gr.Textbox(label="Generated text", lines=4)
        ti2ti_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            ti2ti_text_tokens = gr.Slider(8, 256, value=64, step=4, label="Text placeholder tokens")
            ti2ti_img_steps = gr.Slider(4, 128, value=64, step=2, label="Image timesteps")
            ti2ti_text_steps = gr.Slider(4, 128, value=64, step=2, label="Text timesteps")
            ti2ti_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
            ti2ti_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
        if TI2TI_EXAMPLES:
            with gr.Accordion("Sample edits", open=False):
                gr.Examples(
                    examples=TI2TI_EXAMPLES,
                    inputs=[ti2ti_image_in, ti2ti_instr],
                    examples_per_page=4,
                )
        ti2ti_btn = gr.Button("Generate edited image + text", variant="primary")
        ti2ti_btn.click(
            ti2ti_handler,
            inputs=[
                ti2ti_instr,
                ti2ti_image_in,
                ti2ti_text_tokens,
                ti2ti_img_steps,
                ti2ti_text_steps,
                ti2ti_temperature,
                ti2ti_guidance,
            ],
            outputs=[ti2ti_image_out, ti2ti_text_out, ti2ti_status],
        )
    # ---- I2S ----
    with gr.Tab("Image → Speech (I2S)"):
        i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
        i2s_prompt = gr.Textbox(
            label="Optional question",
            placeholder="(Optional) e.g., 'Describe this image aloud.'",
        )
        i2s_audio_out = gr.Audio(type="numpy", label="Spoken description")
        i2s_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
            i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
            i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
            i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
            i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
        if I2S_EXAMPLES:
            with gr.Accordion("Sample images", open=False):
                gr.Examples(
                    examples=I2S_EXAMPLES,
                    inputs=[i2s_image_in],
                    examples_per_page=4,
                )
        i2s_btn = gr.Button("Generate spoken description", variant="primary")
        i2s_btn.click(
            i2s_handler,
            inputs=[
                i2s_image_in,
                i2s_prompt,
                i2s_max_tokens,
                i2s_steps,
                i2s_block,
                i2s_temperature,
                i2s_cfg,
            ],
            outputs=[i2s_audio_out, i2s_status],
        )
    # ---- Chat ----
    with gr.Tab("Text Chat"):
        chat_in = gr.Textbox(
            label="Message",
            lines=4,
            placeholder="Ask anything. The model will reply in text.",
        )
        chat_out = gr.Textbox(label="Assistant reply", lines=6)
        chat_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
            chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
            chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
            chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
        if CHAT_EXAMPLES:
            with gr.Accordion("Sample prompts", open=False):
                gr.Examples(
                    examples=CHAT_EXAMPLES,
                    inputs=[chat_in],
                    examples_per_page=6,
                )
        chat_btn = gr.Button("Send", variant="primary")
        chat_btn.click(
            chat_handler,
            inputs=[
                chat_in,
                chat_max_tokens,
                chat_steps,
                chat_block,
                chat_temperature_slider,
            ],
            outputs=[chat_out, chat_status],
        )
    # ---- MMU ----
    with gr.Tab("MMU (Image → Text)"):
        mmu_img = gr.Image(type="pil", label="Input image", sources=["upload"])
        mmu_question = gr.Textbox(
            label="Question",
            lines=3,
            placeholder="Ask about the scene, objects, or context of the image.",
        )
        mmu_answer = gr.Textbox(label="Answer", lines=6)
        mmu_status = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced settings", open=False):
            mmu_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Answer max tokens")
            mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
            mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
            mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
        if MMU_EXAMPLES:
            with gr.Accordion("Sample MMU prompts", open=False):
                gr.Examples(
                    examples=MMU_EXAMPLES,
                    inputs=[mmu_img, mmu_question],
                    examples_per_page=1,
                )
        mmu_btn = gr.Button("Answer about the image", variant="primary")
        mmu_btn.click(
            mmu_handler,
            inputs=[
                mmu_img,
                mmu_question,
                mmu_max_tokens,
                mmu_steps,
                mmu_block,
                mmu_temperature,
            ],
            outputs=[mmu_answer, mmu_status],
        )
if __name__ == "__main__":
    demo.launch()
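# Note: on Spaces the Gradio runtime imports this module and serves `demo` itself, so the
# __main__ guard above matters only for local runs (e.g., `python app.py`).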