| #!/usr/bin/env python3 | |
| """ | |
| Gradio demo for OMada Stage 1.3 checkpoints covering: | |
| * Text-to-Speech (T2S), Speech-to-Text (S2T), and Speech-to-Speech (S2S) | |
| * Video-to-Text (V2T) and Video-to-Speech (V2S) | |
| * Text-to-Image (T2I), Image-to-Image editing (I2I), and Image-to-Speech (I2S) | |
| * Text chat and single-image MMU reasoning | |
| The implementation wraps the existing CLI inference helpers so that a single | |
| checkpoint directory (…/checkpoint-XXXX/unwrapped_model) can be previewed | |
| interactively. Usage: | |
| python MMaDA/inference/gradio_multimodal_demo_inst.py --train-config /t1data/users/snu-lab-d/omada/OMaDA/MMaDA/inference/demo/demo.yaml --checkpoint ../ckpt/checkpoint-400000/unwrapped_model/ --server-port 7860 | |
| If you need remote access, pass `--share`, which simply forwards the flag to | |
| `gradio.Blocks.launch`. For more reliable sharing, run the demo locally without | |
| `--share` and tunnel the chosen port with a tool such as `ngrok http 7860` or | |
| `cloudflared tunnel --url http://localhost:7860` instead of relying on Gradio’s | |
| temporary share links. | |
| """ | |
| import argparse | |
| import base64 | |
| import html | |
| import io | |
| import os | |
| import math | |
| import random | |
| import sys | |
| import tempfile | |
| import wave | |
| from pathlib import Path | |
| import shutil | |
| import time | |
| from typing import Any, List, Optional, Sequence, Tuple, Union | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| CUSTOM_CSS = """ | |
| @import url('https://cdn.jsdelivr.net/gh/orioncactus/[email protected]/dist/web/variable/pretendardvariable-dynamic-subset.css'); | |
| :root { | |
| --omada-font-ui: "Pretendard Variable", -apple-system, BlinkMacSystemFont, | |
| "Segoe UI", system-ui, sans-serif; | |
| } | |
| /* Global default font */ | |
| html, body, .gradio-container { | |
| font-family: var(--omada-font-ui) !important; | |
| } | |
| /* Force every Gradio component onto the same font */ | |
| .gradio-container * { | |
| font-family: var(--omada-font-ui) !important; | |
| } | |
| /* Safety net for components that hard-code their own font */ | |
| .gradio-textbox textarea, | |
| .gradio-dropdown select, | |
| .gradio-radio label, | |
| .gradio-button button, | |
| .gradio-slider label, | |
| .gradio-accordion *, | |
| .gradio-chatbot * { | |
| font-family: var(--omada-font-ui) !important; | |
| } | |
| :root { | |
| --omada-primary: #1e3a8a; | |
| --omada-accent: #1d4ed8; | |
| --omada-surface: #f3f4f6; | |
| --omada-surface-alt: #ffffff; | |
| --omada-border: #d0d7e5; | |
| --omada-text-primary: #111827; | |
| --omada-text-muted: #374151; | |
| color-scheme: light; | |
| } | |
| html, body, body.dark, html.dark { | |
| background: var(--omada-surface) !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-container { | |
| background: var(--omada-surface); | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-page-heading { | |
| margin-bottom: 0; | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-tab-intro p { | |
| font-size: 0.95rem; | |
| color: var(--omada-text-primary); | |
| margin-top: 0; | |
| opacity: 0.9; | |
| } | |
| .omada-card { | |
| background: var(--omada-surface-alt); | |
| border-radius: 16px !important; | |
| padding: 18px !important; | |
| box-shadow: none; | |
| border: 1px solid var(--omada-border); | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-card .gradio-slider .wrap-inner { | |
| gap: 6px; | |
| } | |
| .omada-card .gradio-slider input[type=range]::-webkit-slider-thumb { | |
| background: var(--omada-primary); | |
| } | |
| .gradio-slider input[type=range]::-webkit-slider-runnable-track { | |
| background: rgba(14, 33, 80, 0.2); | |
| } | |
| .omada-section-title p { | |
| text-transform: uppercase; | |
| font-size: 0.78rem; | |
| letter-spacing: 0.14em; | |
| color: rgba(30, 58, 138, 0.85); | |
| margin: 0 0 12px 0; | |
| } | |
| .omada-output .gradio-audio, .omada-output .gradio-textbox { | |
| margin-top: 12px; | |
| } | |
| .gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio { | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label { | |
| color: var(--omada-text-primary); | |
| } | |
| .gradio-dropdown .single-select, .gradio-textbox textarea { | |
| background: #ffffff !important; | |
| border: 1px solid var(--omada-border) !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-dropdown .single-select select, .gradio-textbox textarea { | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-textbox textarea::placeholder { | |
| color: rgba(148, 163, 184, 0.65); | |
| } | |
| .gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider { | |
| background: #ffffff !important; | |
| border: 1px solid var(--omada-border) !important; | |
| border-radius: 12px !important; | |
| } | |
| .full-width-button button { | |
| width: 100%; | |
| background: var(--omada-primary) !important; | |
| color: white !important; | |
| border: none !important; | |
| font-weight: 600; | |
| transition: transform 0.2s ease, box-shadow 0.2s ease; | |
| box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65); | |
| } | |
| .full-width-button button:hover { | |
| transform: translateY(-1px); | |
| box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75); | |
| } | |
| .omada-advanced .gr-accordion-header { | |
| font-size: 0.85rem; | |
| letter-spacing: 0.05em; | |
| color: var(--omada-text-muted); | |
| } | |
| .omada-advanced .gr-accordion { | |
| border: 1px solid var(--omada-border); | |
| border-radius: 12px; | |
| background: #ffffff; | |
| } | |
| .gradio-tabs { | |
| background: transparent; | |
| } | |
| .gradio-tabs ul.tab-list { | |
| background: transparent; | |
| border-bottom: 1px solid var(--omada-border); | |
| } | |
| .gradio-tabs button { | |
| color: var(--omada-text-primary); | |
| } | |
| .gradio-tabs button.selected { | |
| color: var(--omada-text-primary); | |
| background: rgba(14, 33, 80, 0.1); | |
| border-bottom: 2px solid var(--omada-primary); | |
| } | |
| .gradio-container .label { | |
| background: rgba(30, 58, 138, 0.1) !important; | |
| color: var(--omada-primary) !important; | |
| border: 1px solid rgba(30, 58, 138, 0.25) !important; | |
| border-radius: 999px !important; | |
| padding: 4px 12px !important; | |
| } | |
| .gradio-button.primary { | |
| background: var(--omada-primary) !important; | |
| color: #ffffff !important; | |
| border: 1px solid var(--omada-primary) !important; | |
| } | |
| .gradio-accordion { | |
| box-shadow: none; | |
| } | |
| .omada-layout { | |
| gap: 20px !important; | |
| } | |
| .omada-chat-column { | |
| gap: 12px !important; | |
| } | |
| .omada-chat-column .gradio-chatbot { | |
| border-radius: 16px; | |
| box-shadow: none; | |
| border: 1px solid var(--omada-border); | |
| background: #ffffff; | |
| } | |
| .omada-controls { | |
| gap: 16px !important; | |
| } | |
| .omada-mode-panel { | |
| display: flex; | |
| flex-direction: column; | |
| gap: 16px !important; | |
| } | |
| .omada-examples-card { | |
| padding-top: 10px !important; | |
| } | |
| .omada-output-panel .gradio-audio, | |
| .omada-output-panel .gradio-textbox { | |
| margin-top: 8px; | |
| } | |
| .omada-response-container { | |
| display: flex; | |
| flex-direction: column; | |
| gap: 10px; | |
| } | |
| .omada-response-status { | |
| margin: 0; | |
| font-weight: 600; | |
| font-size: 0.95rem; | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-response-block, | |
| .omada-audio-block { | |
| background: rgba(30, 58, 138, 0.05); | |
| border-radius: 12px; | |
| padding: 12px 14px; | |
| color: var(--omada-text-primary); | |
| white-space: pre-wrap; | |
| word-break: break-word; | |
| } | |
| .omada-audio-block audio { | |
| width: 100%; | |
| } | |
| .omada-header { | |
| display: flex; | |
| flex-direction: column; | |
| align-items: center; | |
| justify-content: center; | |
| gap: 18px !important; | |
| margin-bottom: 18px; | |
| text-align: center; | |
| } | |
| .omada-header .gradio-image { | |
| background: transparent !important; | |
| border: none !important; | |
| } | |
| .omada-header img { | |
| object-fit: contain; | |
| display: block; | |
| } | |
| .omada-logo { | |
| max-width: 180px; | |
| padding: 0 !important; | |
| } | |
| .omada-examples { | |
| margin-top: 8px; | |
| padding-top: 4px; | |
| } | |
| .omada-examples .gradio-dataset { | |
| width: 100% !important; | |
| } | |
| .omada-examples .samples-table { | |
| width: 100% !important; | |
| } | |
| .omada-examples table { | |
| width: 100% !important; | |
| } | |
| .omada-examples td { | |
| width: 100% !important; | |
| } | |
| .omada-examples .sample-button, | |
| .omada-examples button { | |
| width: 100% !important; | |
| white-space: pre-wrap !important; | |
| word-wrap: break-word !important; | |
| height: auto !important; | |
| min-height: 40px !important; | |
| text-align: left !important; | |
| padding: 12px 16px !important; | |
| line-height: 1.5 !important; | |
| overflow: visible !important; | |
| text-overflow: clip !important; | |
| display: block !important; | |
| } | |
| .omada-examples button span { | |
| white-space: pre-wrap !important; | |
| word-wrap: break-word !important; | |
| overflow: visible !important; | |
| text-overflow: clip !important; | |
| display: block !important; | |
| width: 100% !important; | |
| } | |
| .omada-logo .gradio-image, | |
| .omada-logo .gradio-image > div, | |
| .omada-logo .gradio-image .container { | |
| background: transparent !important; | |
| border: none !important; | |
| box-shadow: none !important; | |
| padding: 0 !important; | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| } | |
| .omada-logo img { | |
| width: 100%; | |
| height: auto; | |
| } | |
| .omada-logo button { | |
| display: none !important; | |
| } | |
| .gradio-container .gradio-component, | |
| .gradio-container .gradio-panel, | |
| .gradio-container .gradio-box { | |
| background: transparent !important; | |
| color: var(--omada-text-primary); | |
| } | |
| .dark .gradio-container, | |
| .dark .gradio-interface, | |
| .dark .gradio-container * { | |
| background-color: inherit !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .dark .gradio-container .gradio-chatbot, | |
| .dark .gradio-container .gradio-dropdown, | |
| .dark .gradio-container .gradio-textbox, | |
| .dark .gradio-container .gradio-audio, | |
| .dark .gradio-container .gradio-video, | |
| .dark .gradio-container .gradio-slider, | |
| .dark .gradio-container .gradio-accordion, | |
| .dark .gradio-container .gradio-panel, | |
| .dark .gradio-container .gradio-box { | |
| background: #ffffff !important; | |
| border-color: var(--omada-border) !important; | |
| } | |
| .omada-title h2 { | |
| font-size: 2.4rem; | |
| font-weight: 700; | |
| color: var(--omada-text-primary); | |
| margin: 0; | |
| } | |
| .omada-title h3 { | |
| font-size: 1.25rem; | |
| font-weight: 600; | |
| letter-spacing: 0.1em; | |
| text-transform: uppercase; | |
| color: var(--omada-text-muted); | |
| margin: 6px 0 0; | |
| } | |
| .omada-tagline p { | |
| color: var(--omada-text-primary); | |
| font-size: 1rem; | |
| margin: 0; | |
| opacity: 0.9; | |
| line-height: 1.45; | |
| } | |
| .omada-tagline .tagline-speech { color: #f97316; font-weight: 600; } | |
| .omada-tagline .tagline-audio { color: #ec4899; font-weight: 600; } | |
| .omada-tagline .tagline-video { color: #0ea5e9; font-weight: 600; } | |
| .omada-tagline .tagline-text { color: #a855f7; font-weight: 600; } | |
| .omada-tagline .tagline-image { color: #22c55e; font-weight: 600; } | |
| .gradio-container .prose :where(h1, h2, h3, h4, h5, h6) { | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-container .prose :where(p, li) { | |
| color: var(--omada-text-muted) !important; | |
| } | |
| .gradio-container label, .gradio-container span, .gradio-container button { | |
| color: var(--omada-text-primary); | |
| } | |
| .gradio-container .dark { | |
| background: #ffffff !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .omada-logo-img { | |
| max-width: 250px; | |
| width: 100%; | |
| height: auto; | |
| display: block; | |
| margin: 0 auto; | |
| } | |
| .omada-logo-wrapper { | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| } | |
| """ | |
| FORCE_LIGHT_MODE_JS = """ | |
| function() { | |
| document.body.classList.remove('dark'); | |
| document.documentElement.classList.remove('dark'); | |
| const observer = new MutationObserver(function(mutations) { | |
| document.body.classList.remove('dark'); | |
| document.documentElement.classList.remove('dark'); | |
| }); | |
| observer.observe(document.body, { attributes: true, attributeFilter: ['class'] }); | |
| observer.observe(document.documentElement, { attributes: true, attributeFilter: ['class'] }); | |
| } | |
| """ | |
| DEMO_ROOT = Path(__file__).resolve().parent / "demo" | |
| LOGO_PATH = DEMO_ROOT / "logo.png" | |
| T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt" | |
| CHAT_TEXT_PATH = DEMO_ROOT / "chat" / "text.txt" | |
| T2I_TEXT_PATH = DEMO_ROOT / "t2i" / "text.txt" | |
| def _load_logo_data() -> Optional[str]: | |
| if not LOGO_PATH.exists(): | |
| return None | |
| try: | |
| encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8") | |
| except OSError: | |
| return str(LOGO_PATH) | |
| return f"data:image/png;base64,{encoded}" | |
| def _load_t2s_examples(): | |
| if not T2S_TEXT_PATH.exists(): | |
| return [] | |
| lines = [ | |
| line.strip() | |
| for line in T2S_TEXT_PATH.read_text(encoding="utf-8").splitlines() | |
| if line.strip() | |
| ] | |
| return [[line] for line in lines] | |
| def _load_chat_examples(): | |
| if not CHAT_TEXT_PATH.exists(): | |
| return [] | |
| lines = [ | |
| line.strip() | |
| for line in CHAT_TEXT_PATH.read_text(encoding="utf-8").splitlines() | |
| if line.strip() | |
| ] | |
| return [[line] for line in lines] | |
| def _load_t2i_examples(): | |
| if not T2I_TEXT_PATH.exists(): | |
| return [] | |
| lines = [ | |
| line.strip() | |
| for line in T2I_TEXT_PATH.read_text(encoding="utf-8").splitlines() | |
| if line.strip() | |
| ] | |
| return [[line] for line in lines] | |
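| # The three loaders above all return the nested-list shape that `gr.Examples` expects, | |
| # e.g. (illustrative): | |
| #   [["Hello, welcome to the demo."], ["Read this sentence aloud."]] | |
| # One inner list per example row, one cell per input component. | |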
| def _load_i2i_examples(): | |
| """demo/i2i 안의 image*.jpg + text*.txt 쌍을 Examples로 묶어줌.""" | |
| d = DEMO_ROOT / "i2i" | |
| if not d.exists(): | |
| return [] | |
| # Image files (image1.jpeg, image2.png, ...) | |
| image_files = sorted( | |
| [p for p in d.iterdir() if p.is_file() and p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}] | |
| ) | |
| # Text files (text1.txt, text2.txt, ...) | |
| text_files = sorted( | |
| [p for p in d.iterdir() if p.is_file() and p.suffix.lower() == ".txt"] | |
| ) | |
| n = min(len(image_files), len(text_files)) | |
| examples = [] | |
| for i in range(n): | |
| img_path = image_files[i] | |
| txt_path = text_files[i] | |
| instruction = txt_path.read_text(encoding="utf-8").strip() | |
| if not instruction: | |
| continue | |
| # Gradio Examples format: [image, instruction_text] | |
| examples.append([str(img_path), instruction]) | |
| return examples | |
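| # Illustrative output (file names are placeholders): | |
| #   [["demo/i2i/image1.jpg", "Make the sky look like sunset."], | |
| #    ["demo/i2i/image2.png", "Remove the person on the left."]] | |
| # Pairing is positional after sorting, so image and text files must sort in the same order. | |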
| def _load_media_examples(subdir: str, suffixes): | |
| target_dir = DEMO_ROOT / subdir | |
| if not target_dir.exists(): | |
| return [] | |
| examples = [] | |
| for path in sorted(target_dir.iterdir()): | |
| if path.is_file() and path.suffix.lower() in suffixes: | |
| examples.append([str(path)]) | |
| return examples | |
| T2S_EXAMPLES = _load_t2s_examples() | |
| CHAT_EXAMPLES = _load_chat_examples() | |
| T2I_EXAMPLES = _load_t2i_examples() | |
| I2I_EXAMPLES = _load_i2i_examples() | |
| S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"}) | |
| V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) | |
| S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"}) | |
| if not S2S_EXAMPLES: | |
| S2S_EXAMPLES = S2T_EXAMPLES[: min(4, len(S2T_EXAMPLES))] | |
| V2S_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) | |
| if not V2S_EXAMPLES: | |
| V2S_EXAMPLES = V2T_EXAMPLES[: min(4, len(V2T_EXAMPLES))] | |
| I2S_EXAMPLES = _load_media_examples("i2s", {".png", ".jpg", ".jpeg", ".webp"}) | |
| LOGO_DATA_URI = _load_logo_data() | |
| MMU_IMAGE = DEMO_ROOT / "mmu" / "1.jpg" | |
| # MMU_IMAGE_ALT = DEMO_ROOT / "mmu" / "SD_IMG_00235_1.png" | |
| if MMU_IMAGE.exists(): | |
| MMU_EXAMPLES = [ | |
| [ | |
| str(MMU_IMAGE), | |
| "Describe the scene in this image in detail.", | |
| ] | |
| ] | |
| else: | |
| MMU_EXAMPLES = [] | |
| if not I2S_EXAMPLES and MMU_EXAMPLES: | |
| I2S_EXAMPLES = [[example[0]] for example in MMU_EXAMPLES] | |
| def _render_response(status: str, body_html: str = "") -> str: | |
| safe_status = html.escape(status or "") | |
| parts = [] | |
| if safe_status: | |
| parts.append(f"<p class='omada-response-status'>{safe_status}</p>") | |
| if body_html: | |
| parts.append(body_html) | |
| content = "".join(parts) | |
| return f"<div class='omada-response-container'>{content}</div>" | |
| def _render_text_message(status: str, content: Optional[str]) -> str: | |
| # post-processing | |
| if content: | |
| remove_tokens = ["<|eot_id|>", "<|eot_id>", "<eot_id|>", "<eot_id>", "<eot>", "assistant"] | |
| for token in remove_tokens: | |
| content = content.replace(token, "") | |
| content = (content or "").strip() | |
| if not content: | |
| return _render_response(status) | |
| safe_content = html.escape(content).replace("\n", "<br>") | |
| body = f"<div class='omada-response-block'>{safe_content}</div>" | |
| return _render_response(status, body) | |
| def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str: | |
| """Render an inline HTML audio player for chat responses.""" | |
| if not audio: | |
| return _render_response(status) | |
| sample_rate, data = audio | |
| if data is None: | |
| return _render_response(status) | |
| waveform = np.asarray(data, dtype=np.float32) | |
| if waveform.size == 0: | |
| return _render_response(status) | |
| if waveform.ndim == 1: | |
| waveform = waveform[:, None] | |
| channels = waveform.shape[1] | |
| clipped = np.clip(waveform, -1.0, 1.0) | |
| pcm16 = (clipped * 32767.0).astype(np.int16) | |
| buffer = io.BytesIO() | |
| with wave.open(buffer, "wb") as wav_writer: | |
| wav_writer.setnchannels(channels) | |
| wav_writer.setsampwidth(2) # 16-bit PCM | |
| wav_writer.setframerate(int(sample_rate)) | |
| wav_writer.writeframes(pcm16.tobytes()) | |
| encoded = base64.b64encode(buffer.getvalue()).decode("ascii") | |
| audio_tag = ( | |
| "<div class='omada-audio-block'>" | |
| "<audio controls preload='auto' playsinline>" | |
| f"<source src='data:audio/wav;base64,{encoded}' type='audio/wav' /></audio>" | |
| "</div>" | |
| ) | |
| return _render_response(status, audio_tag) | |
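| # Illustrative call: one second of silence at 22050 Hz becomes an inline base64 WAV player | |
| # (mono, 16-bit PCM): | |
| #   _render_audio_message("Speech generated!", (22050, np.zeros(22050, dtype=np.float32))) | |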
| def _render_image_message(status: str, image: Optional[Image.Image]) -> str: | |
| if image is None: | |
| return _render_response(status) | |
| buffer = io.BytesIO() | |
| try: | |
| image.save(buffer, format="PNG") | |
| except Exception: | |
| return _render_response(status) | |
| encoded = base64.b64encode(buffer.getvalue()).decode("ascii") | |
| image_html = ( | |
| "<div class='omada-response-block'>" | |
| "<img src='data:image/png;base64," | |
| f"{encoded}" | |
| "' alt='Generated image' style='max-width:100%;border-radius:12px;' />" | |
| "</div>" | |
| ) | |
| return _render_response(status, image_html) | |
| def _render_image_text_message(status: str, image: Optional[Image.Image], text: str) -> str: | |
| """Render combined text + image output for TI2TI.""" | |
| blocks = [] | |
| text_clean = (text or "").strip() | |
| if text_clean: | |
| safe_text = html.escape(text_clean).replace("\n", "<br>") | |
| blocks.append(f"<div class='omada-response-block'>{safe_text}</div>") | |
| if image is not None: | |
| buffer = io.BytesIO() | |
| try: | |
| image.save(buffer, format="PNG") | |
| encoded = base64.b64encode(buffer.getvalue()).decode("ascii") | |
| blocks.append( | |
| "<div class='omada-response-block'>" | |
| "<img src='data:image/png;base64," | |
| f"{encoded}" | |
| "' alt='Generated image' style='max-width:100%;border-radius:12px;' />" | |
| "</div>" | |
| ) | |
| except Exception: | |
| pass | |
| body = "".join(blocks) | |
| return _render_response(status, body) | |
| def _format_user_message(message: str) -> str: | |
| clean = html.escape(message or "") | |
| return clean.replace("\n", "<br>") | |
| # Ensure project modules (models, training, inference.common, …) are importable when the | |
| # script is launched directly via `python MMaDA/inference/gradio_multimodal_demo_inst.py`. | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from omegaconf import DictConfig, OmegaConf | |
| from PIL import Image | |
| from inference.common import ( | |
| build_uni_prompting, | |
| get_vq_model_audio, | |
| get_vq_model_image, | |
| load_omada_from_checkpoint, | |
| load_train_config, | |
| ) | |
| from models import get_mask_schedule | |
| from models.modeling_omada import add_gumbel_noise, get_num_transfer_tokens | |
| from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION, V2S_INSTRUCTION | |
| from training.utils import image_transform | |
| def _cfg_get(cfg, key, default=None): | |
| """Lightweight helper to read DictConfig/dict/objects safely.""" | |
| if cfg is None: | |
| return default | |
| if isinstance(cfg, dict): | |
| return cfg.get(key, default) | |
| try: | |
| value = getattr(cfg, key) | |
| except AttributeError: | |
| value = None | |
| if value is not None: | |
| return value | |
| getter = getattr(cfg, "get", None) | |
| if callable(getter): | |
| try: | |
| value = getter(key) | |
| except TypeError: | |
| try: | |
| value = getter(key, default) | |
| except Exception: | |
| value = None | |
| if value is not None: | |
| return value | |
| return default | |
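| # Illustrative behaviour: `_cfg_get` reads the same key from plain dicts, OmegaConf nodes, | |
| # or attribute-style objects, falling back to the default: | |
| #   _cfg_get({"resolution": 336}, "resolution", 256)                -> 336 | |
| #   _cfg_get(argparse.Namespace(resolution=336), "crop_size", 224)  -> 224 | |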
| def _resolve_noise_schedule(train_cfg) -> callable: | |
| """Return the diffusion noise schedule used for T2S sampling.""" | |
| schedule_cfg = getattr(train_cfg, "mask_schedule", None) | |
| if schedule_cfg and hasattr(schedule_cfg, "schedule"): | |
| schedule_name = schedule_cfg.schedule | |
| schedule_kwargs = schedule_cfg.get("params", {}) | |
| return get_mask_schedule(schedule_name, **schedule_kwargs) | |
| schedule_name = train_cfg.training.get("mask_schedule", "cosine") | |
| return get_mask_schedule(schedule_name) | |
| def _resolve_mask_schedule(cfg) -> callable: | |
| """Mirror training-time mask schedule resolution for image tasks.""" | |
| schedule_cfg = getattr(cfg, "mask_schedule", None) | |
| if isinstance(schedule_cfg, DictConfig): | |
| schedule_name = getattr(schedule_cfg, "schedule", None) | |
| params_cfg = getattr(schedule_cfg, "params", None) | |
| elif isinstance(schedule_cfg, dict): | |
| schedule_name = schedule_cfg.get("schedule") | |
| params_cfg = schedule_cfg.get("params") | |
| else: | |
| schedule_name = None | |
| params_cfg = None | |
| training_cfg = getattr(cfg, "training", None) | |
| if schedule_name is None: | |
| if training_cfg is None: | |
| schedule_name = "cosine" | |
| else: | |
| schedule_name = getattr(training_cfg, "mask_schedule", None) | |
| if schedule_name is None and isinstance(training_cfg, dict): | |
| schedule_name = training_cfg.get("mask_schedule") | |
| if schedule_name is None: | |
| schedule_name = "cosine" | |
| params = {} | |
| if params_cfg is not None: | |
| if isinstance(params_cfg, DictConfig): | |
| params = OmegaConf.to_container(params_cfg, resolve=True) or {} | |
| elif isinstance(params_cfg, dict): | |
| params = dict(params_cfg) | |
| else: | |
| params = params_cfg | |
| if not isinstance(params, dict): | |
| params = {} | |
| return get_mask_schedule(schedule_name, **params) | |
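| # Both resolvers above accept either of these (illustrative) config shapes: | |
| #   mask_schedule: | |
| #     schedule: cosine | |
| #     params: {} | |
| # or, falling back to the training section: | |
| #   training: | |
| #     mask_schedule: cosine | |
| # with "cosine" used when neither key is present. | |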
| class OmadaDemo: | |
| """Lightweight container that loads all inference assets once.""" | |
| def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None): | |
| ckpt_path = Path(checkpoint) | |
| if ckpt_path.name != "unwrapped_model": | |
| raise ValueError( | |
| "`--checkpoint` must point to an `unwrapped_model` directory. " | |
| f"Received: {checkpoint}" | |
| ) | |
| self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) | |
| self.train_cfg = load_train_config(train_config) | |
| self.uni_prompting, _ = build_uni_prompting(self.train_cfg) | |
| # Core models | |
| self.model = load_omada_from_checkpoint(str(ckpt_path), self.device) | |
| self.vq_audio = get_vq_model_audio(self.train_cfg, self.device) | |
| self.vq_image = get_vq_model_image(self.train_cfg, self.device) | |
| self.model.eval() | |
| self.vq_audio.eval() | |
| self.vq_image.eval() | |
| # Cached constants | |
| self.mask_token_id = int(self.model.config.mask_token_id) | |
| self.noise_schedule = _resolve_noise_schedule(self.train_cfg) | |
| self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050)) | |
| dataset_cfg = getattr(self.train_cfg, "dataset", None) | |
| dataset_params = getattr(dataset_cfg, "params", None) if dataset_cfg else None | |
| preprocess_cfg = getattr(dataset_cfg, "preprocessing", None) if dataset_cfg else None | |
| self.image_resolution = int(_cfg_get(dataset_params, "resolution", 336)) | |
| self.t2i_resolution = int(_cfg_get(dataset_params, "t2i_resolution", self.image_resolution)) | |
| mmu_cfg = _cfg_get(dataset_params, "mmu_interleaved") | |
| self.mmu_image_resolution = int(_cfg_get(mmu_cfg, "resolution", self.image_resolution)) | |
| self.video_resolution = 224 | |
| patch_size = int(_cfg_get(dataset_params, "vq_patch_size", 16)) | |
| self.image_seq_len = int( | |
| _cfg_get( | |
| dataset_params, | |
| "i2i_seq_len", | |
| max(1, (self.image_resolution // patch_size) ** 2), | |
| ) | |
| ) | |
| self.t2i_seq_len = int( | |
| _cfg_get( | |
| dataset_params, | |
| "t2i_seq_len", | |
| max(1, (self.t2i_resolution // patch_size) ** 2), | |
| ) | |
| ) | |
| speech_cfg = _cfg_get(dataset_params, "video_speech_dataset") | |
| self.num_frames_v2t = max(1, int(_cfg_get(speech_cfg, "num_frames_v2t", 5))) | |
| self.num_frames_v2s = max(1, int(_cfg_get(speech_cfg, "num_frames_v2s", 4))) | |
| self.genders = ['female', 'male'] | |
| self.emotions = ['angry', 'happy', 'neutral', 'sad'] | |
| self.speeds = ['normal', 'fast', 'slow'] | |
| self.pitches = ['normal', 'high', 'low'] | |
| # Pre-computed offsets reused across calls | |
| self.text_vocab_size = len(self.uni_prompting.text_tokenizer) | |
| # The current checkpoints assume a fixed 8k vision codebook and 4k speech codebook. | |
| self.codebook_size = 8192 | |
| self.speech_codebook = self.codebook_size | |
| self.audio_codebook_size = 4096 | |
| self.max_audio_len_short = int( | |
| getattr( | |
| self.uni_prompting, | |
| "max_audio_len_short", | |
| getattr(self.train_cfg.dataset.preprocessing, "max_aud_length_short", 256), | |
| ) | |
| ) | |
| self.max_text_len = int(getattr(self.train_cfg.dataset.preprocessing, "max_seq_length", 1024)) | |
| model_seq_len = getattr(self.model.config, "num_vq_tokens", None) | |
| if model_seq_len is None: | |
| model_seq_len = getattr(getattr(self.train_cfg.model, "omada", None), "num_vq_tokens", None) | |
| if model_seq_len is None: | |
| model_seq_len = getattr(getattr(self.train_cfg.model, "vq_model_image", None), "num_vq_tokens", None) | |
| if model_seq_len is not None: | |
| self.image_seq_len = min(self.image_seq_len, int(model_seq_len)) | |
| self.t2i_seq_len = min(self.t2i_seq_len, int(model_seq_len)) | |
| print(self.image_seq_len) | |
| self.image_noise_schedule = _resolve_noise_schedule(self.train_cfg) | |
| self.mask_schedule = _resolve_mask_schedule(self.train_cfg) | |
| training_cfg = getattr(self.train_cfg, "training", None) | |
| self.noise_type = _cfg_get(training_cfg, "noise_type", "mask") | |
| self.predict_all_tokens = bool(_cfg_get(training_cfg, "predict_all_tokens", False)) | |
| self.t2i_default_timesteps = int(_cfg_get(training_cfg, "generation_timesteps", 20)) | |
| self.i2i_default_timesteps = int(_cfg_get(training_cfg, "i2i_eval_timesteps", 24)) | |
| self.audio_condition_default = "gender-female_emotion-neutral_speed-normal_pitch-normal" | |
| style_map = getattr(getattr(self.vq_audio, "config", None), "u2s_style2idx", None) | |
| if isinstance(style_map, dict): | |
| self._valid_conditions = set(style_map.keys()) | |
| if self._valid_conditions and self.audio_condition_default not in self._valid_conditions: | |
| # Ensure the default condition is valid for the tokenizer. | |
| self.audio_condition_default = next(iter(self._valid_conditions)) | |
| else: | |
| self._valid_conditions = set() | |
| self._temp_video_files = [] | |
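| # Illustrative instantiation (paths and sampling parameters are placeholders): | |
| #   demo = OmadaDemo( | |
| #       train_config="MMaDA/inference/demo/demo.yaml", | |
| #       checkpoint="ckpt/checkpoint-400000/unwrapped_model", | |
| #   ) | |
| #   audio, status = demo.run_t2s("HELLO THERE", 256, 64, 64, 1.0, 0.0, | |
| #                                "female", "neutral", "normal", "normal") | |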
| # ------------------------------------------------------------------ | |
| # Text-to-Speech | |
| # ------------------------------------------------------------------ | |
| def run_t2s( | |
| self, | |
| text: str, | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| cfg_scale: float, | |
| gender_choice: str, | |
| emotion_choice: str, | |
| speed_choice: str, | |
| pitch_choice: str, | |
| ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: | |
| if text is None or not text.strip(): | |
| return None, "Please provide text to synthesize." | |
| speech_len, steps, block_length = self._prepare_block_schedule( | |
| max_new_tokens, | |
| steps, | |
| block_length, | |
| ) | |
| gender = self._resolve_choice(gender_choice, self.genders) | |
| emotion = self._resolve_choice(emotion_choice, self.emotions) | |
| speed = self._resolve_choice(speed_choice, self.speeds) | |
| pitch = self._resolve_choice(pitch_choice, self.pitches) | |
| text = text.strip().upper() | |
| prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{random.choice(T2S_INSTRUCTION)}\n{text}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| audio_tokens = torch.full( | |
| (1, speech_len), | |
| fill_value=self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen") | |
| input_ids = input_ids.to(self.device) | |
| attention_mask = attention_mask.to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.t2s_generate_mmu_like( | |
| input_ids=input_ids, | |
| max_new_tokens=int(speech_len), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| temperature=float(temperature), | |
| cfg_scale=float(cfg_scale), | |
| mask_token_id=self.mask_token_id, | |
| attention_mask=attention_mask, | |
| uni_prompting=self.uni_prompting, | |
| codebook_size=self.codebook_size, | |
| ) | |
| if not outputs: | |
| return None, "Generation produced no speech tokens." | |
| rel = outputs[0] | |
| if isinstance(rel, torch.Tensor): | |
| rel_ids = rel.detach().cpu().tolist() | |
| else: | |
| rel_ids = list(rel) | |
| if not rel_ids: | |
| return None, "Generation produced no speech tokens." | |
| speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids) | |
| condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}" | |
| wav = self.vq_audio.decode( | |
| speech_units, | |
| condition=condition, | |
| output_wav_file=os.path.join("/tmp", "omada_t2s.wav"), | |
| ) | |
| audio = (self.sample_rate, wav.astype(np.float32)) | |
| status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})." | |
| return audio, status | |
| # ------------------------------------------------------------------ | |
| # Speech-to-Speech | |
| # ------------------------------------------------------------------ | |
| def run_s2s( | |
| self, | |
| audio_path: Optional[str], | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| cfg_scale: float, | |
| ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: | |
| if not audio_path: | |
| return None, "Please upload source speech first." | |
| try: | |
| user_tokens = self.vq_audio.encode(audio_path) | |
| except Exception as exc: | |
| return None, f"Failed to encode input audio: {exc}" | |
| if not isinstance(user_tokens, torch.Tensor): | |
| user_tokens = torch.tensor(user_tokens) | |
| if user_tokens.dim() == 1: | |
| user_tokens = user_tokens.unsqueeze(0) | |
| user_tokens = user_tokens.to(self.device, dtype=torch.long) | |
| if user_tokens.numel() == 0: | |
| return None, "Uploaded speech clip produced no tokens." | |
| gen_len = max(1, int(max_new_tokens)) | |
| gen_len = min(gen_len, self.max_audio_len_short) | |
| gen_len, steps, block_length = self._prepare_block_schedule( | |
| gen_len, | |
| steps, | |
| block_length, | |
| ) | |
| offset = self.text_vocab_size + self.codebook_size | |
| user_shifted = user_tokens + offset | |
| assistant_placeholder = torch.full( | |
| (1, gen_len), | |
| self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| input_ids, attention_mask = self.uni_prompting( | |
| ([user_shifted], [assistant_placeholder]), | |
| "s2s_gen", | |
| ) | |
| input_ids = input_ids.to(self.device) | |
| attention_mask = attention_mask.to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.t2s_generate_mmu_like( | |
| input_ids=input_ids, | |
| max_new_tokens=gen_len, | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| temperature=float(temperature), | |
| cfg_scale=float(cfg_scale), | |
| mask_token_id=self.mask_token_id, | |
| attention_mask=attention_mask, | |
| uni_prompting=self.uni_prompting, | |
| codebook_size=self.codebook_size, | |
| audio_codebook_size=self.audio_codebook_size, | |
| ) | |
| if not outputs: | |
| return None, "Generation returned no tokens." | |
| generated = outputs[0] | |
| if isinstance(generated, torch.Tensor): | |
| generated = generated.detach().cpu() | |
| eoa_token_id = int(self.uni_prompting.sptids_dict['<|eoa|>'][0].item()) | |
| mask_token_id = int(self.mask_token_id) | |
| token_list = [] | |
| for tok in generated.tolist(): | |
| tok = int(tok) | |
| if tok < 0: | |
| continue | |
| if tok == eoa_token_id: | |
| break | |
| if tok == mask_token_id: | |
| continue | |
| if tok >= self.audio_codebook_size: | |
| continue | |
| token_list.append(tok) | |
| if not token_list: | |
| return None, "Generated sequence was empty after post-processing." | |
| speech_units = "".join(f"<|speech_{tok}|>" for tok in token_list) | |
| condition = self._resolve_condition(self.audio_condition_default) | |
| fd, temp_path = tempfile.mkstemp(prefix="omada_s2s_", suffix=".wav") | |
| os.close(fd) | |
| try: | |
| wav = self.vq_audio.decode( | |
| speech_units, | |
| condition=condition, | |
| output_wav_file=temp_path, | |
| ) | |
| finally: | |
| try: | |
| os.remove(temp_path) | |
| except OSError: | |
| pass | |
| audio = (self.sample_rate, wav.astype(np.float32)) | |
| return audio, f"Speech response generated successfully. (voice: {condition})" | |
| # ------------------------------------------------------------------ | |
| # Speech-to-Text | |
| # ------------------------------------------------------------------ | |
| def run_s2t( | |
| self, | |
| audio_path: Optional[str], | |
| steps: int, | |
| block_length: int, | |
| max_new_tokens: int, | |
| remasking: str, | |
| ) -> Tuple[str, str]: | |
| if not audio_path: | |
| return "", "Please upload an audio file first." | |
| tokens = self.vq_audio.encode(audio_path).to(self.device) | |
| offset = self.text_vocab_size + self.speech_codebook | |
| tokens = tokens + offset | |
| spt = self.uni_prompting.sptids_dict | |
| audio_block = torch.cat( | |
| [ | |
| spt['<|s2t|>'].to(self.device).unsqueeze(0), | |
| spt['<|soa|>'].to(self.device).unsqueeze(0), | |
| tokens.to(self.device), | |
| spt['<|eoa|>'].to(self.device).unsqueeze(0), | |
| ], | |
| dim=1, | |
| ) | |
| prompt_text = random.choice(S2T_INSTRUCTION) | |
| chat_prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{prompt_text}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| prompt_tensor = self.uni_prompting.text_tokenizer( | |
| chat_prompt, | |
| return_tensors="pt", | |
| ).input_ids.to(self.device) | |
| input_ids = torch.cat([audio_block, prompt_tensor], dim=1) | |
| with torch.no_grad(): | |
| output_ids = self.model.mmu_generate( | |
| input_ids, | |
| max_new_tokens=int(max_new_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| remasking=str(remasking), | |
| ) | |
| decoded = self.uni_prompting.text_tokenizer.batch_decode( | |
| output_ids[:, input_ids.shape[1]:], | |
| skip_special_tokens=True, | |
| )[0] | |
| return decoded.strip(), "Transcription generated successfully." | |
| # ------------------------------------------------------------------ | |
| # Video-to-Text | |
| # ------------------------------------------------------------------ | |
| def run_v2t( | |
| self, | |
| video_path: Any, | |
| steps: int, | |
| block_length: int, | |
| max_new_tokens: int, | |
| ) -> Tuple[str, str]: | |
| resolved_path, converted = self._prepare_video_path(video_path) | |
| if not resolved_path: | |
| return "", "Please upload or record a video file first." | |
| try: | |
| video_tokens = self._extract_video_tokens(resolved_path, num_frame=self.num_frames_v2t) | |
| except Exception as exc: | |
| return "", f"Failed to process video: {exc}" | |
| spt = self.uni_prompting.sptids_dict | |
| # Match training eval: fixed detailed-description question | |
| question = "Please provide a detailed description of the video." | |
| prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{question}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| prompt_ids = self.uni_prompting.text_tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| ).input_ids.to(self.device) | |
| input_ids = torch.cat( | |
| [ | |
| spt['<|v2t|>'].to(self.device).unsqueeze(0), | |
| spt['<|soi|>'].to(self.device).unsqueeze(0), | |
| video_tokens, | |
| spt['<|eoi|>'].to(self.device).unsqueeze(0), | |
| spt['<|sot|>'].to(self.device).unsqueeze(0), | |
| prompt_ids, | |
| ], | |
| dim=1, | |
| ).long() | |
| with torch.no_grad(): | |
| output_ids = self.model.mmu_generate( | |
| input_ids, | |
| max_new_tokens=int(max_new_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| ) | |
| print("[V2T] input shape:", input_ids.shape) | |
| print("[V2T] output shape:", output_ids.shape) | |
| raw_all = self.uni_prompting.text_tokenizer.decode(output_ids[0], skip_special_tokens=False) | |
| print("[V2T] RAW ALL:", repr(raw_all)) | |
| decoded = self.uni_prompting.text_tokenizer.batch_decode( | |
| output_ids[:, input_ids.shape[1]:], | |
| skip_special_tokens=True, | |
| )[0] | |
| print("[V2T] DECODED SLICE:", repr(decoded)) | |
| return decoded.strip(), "Video caption generated successfully." | |
| # ------------------------------------------------------------------ | |
| # Text-to-Image | |
| # ------------------------------------------------------------------ | |
| def run_t2i( | |
| self, | |
| prompt: str, | |
| timesteps: int, | |
| temperature: float, | |
| guidance_scale: float, | |
| ) -> Tuple[Optional[Image.Image], str]: | |
| if not prompt or not prompt.strip(): | |
| return None, "Please provide a text prompt." | |
| image_seq_len = 1024 | |
| image_tokens = torch.full( | |
| (1, image_seq_len), | |
| self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| input_ids, attention_mask = self.uni_prompting(([prompt.strip()], image_tokens), "t2i_gen") | |
| input_ids = input_ids.to(self.device) | |
| attention_mask = attention_mask.to(self.device) | |
| if guidance_scale > 0: | |
| uncond_ids, uncond_mask = self.uni_prompting(([""], image_tokens.clone()), "t2i_gen") | |
| uncond_ids = uncond_ids.to(self.device) | |
| uncond_mask = uncond_mask.to(self.device) | |
| else: | |
| uncond_ids = None | |
| uncond_mask = None | |
| with torch.no_grad(): | |
| gen_tokens = self.model.t2i_generate( | |
| input_ids=input_ids, | |
| uncond_input_ids=uncond_ids, | |
| attention_mask=attention_mask, | |
| uncond_attention_mask=uncond_mask, | |
| guidance_scale=float(guidance_scale), | |
| temperature=float(temperature), | |
| timesteps=int(timesteps), | |
| noise_schedule=self.mask_schedule, | |
| noise_type=self.noise_type, | |
| predict_all_tokens=self.predict_all_tokens, | |
| seq_len=image_seq_len, | |
| mask_token_id=self.mask_token_id, | |
| codebook_size=self.codebook_size, | |
| uni_prompting=self.uni_prompting, | |
| config=self.train_cfg, | |
| ) | |
| if gen_tokens is None: | |
| return None, "Image generation failed." | |
| gen_tokens = torch.clamp(gen_tokens, min=0, max=self.codebook_size - 1) | |
| image = self._decode_image_tokens(gen_tokens[0]) | |
| return image, "Image generated from text prompt." | |
| # ------------------------------------------------------------------ | |
| # Image-to-Image Editing | |
| # ------------------------------------------------------------------ | |
| def run_i2i( | |
| self, | |
| instruction: str, | |
| source_image: Optional[Image.Image], | |
| timesteps: int, | |
| temperature: float, | |
| guidance_scale: float, | |
| ) -> Tuple[Optional[Image.Image], str]: | |
| if source_image is None: | |
| return None, "Please upload a reference image." | |
| if not instruction or not instruction.strip(): | |
| return None, "Provide editing instructions for the image." | |
| try: | |
| input_tokens = self._prepare_image_tokens(source_image, resolution=self.image_resolution) | |
| except Exception as exc: | |
| return None, f"Failed to encode input image: {exc}" | |
| seq_len = int(input_tokens.shape[-1]) | |
| output_placeholder = torch.full( | |
| (1, seq_len), | |
| self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| input_ids, attention_mask = self.uni_prompting( | |
| ([instruction.strip()], input_tokens, output_placeholder), | |
| "i2i_gen", | |
| ) | |
| input_ids = input_ids.to(self.device) | |
| attention_mask = attention_mask.to(self.device) | |
| uncond_ids = None | |
| uncond_attn = None | |
| if guidance_scale > 0: | |
| uncond_ids, uncond_attn = self.uni_prompting( | |
| ([""], input_tokens.clone(), torch.full_like(output_placeholder, self.mask_token_id)), | |
| "i2i_gen", | |
| ) | |
| uncond_ids = uncond_ids.to(self.device) | |
| uncond_attn = uncond_attn.to(self.device) | |
| with torch.no_grad(): | |
| gen_tokens = self.model.i2i_generate( | |
| input_ids=input_ids, | |
| uncond_input_ids=uncond_ids, | |
| attention_mask=attention_mask, | |
| uncond_attention_mask=uncond_attn, | |
| temperature=float(temperature), | |
| timesteps=int(timesteps), | |
| guidance_scale=float(guidance_scale), | |
| noise_schedule=self.mask_schedule, | |
| noise_type=self.noise_type, | |
| seq_len=seq_len, | |
| mask_token_id=self.mask_token_id, | |
| codebook_size=self.codebook_size, | |
| uni_prompting=self.uni_prompting, | |
| config=self.train_cfg, | |
| ) | |
| if gen_tokens is None: | |
| return None, "Image editing failed." | |
| gen_tokens = torch.clamp(gen_tokens, min=0, max=self.codebook_size - 1) | |
| image = self._decode_image_tokens(gen_tokens[0]) | |
| return image, "Edited image generated." | |
| # ------------------------------------------------------------------ | |
| # Video-to-Speech | |
| # ------------------------------------------------------------------ | |
| def run_v2s( | |
| self, | |
| video_path: Any, | |
| message: Optional[str], | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| cfg_scale: float, | |
| ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: | |
| resolved_path, converted = self._prepare_video_path(video_path) | |
| if not resolved_path: | |
| return None, "Please upload or record a video first." | |
| try: | |
| video_tokens = self._extract_video_tokens(resolved_path, num_frame=self.num_frames_v2s) | |
| except Exception as exc: | |
| return None, f"Failed to process video: {exc}" | |
| prompt_body = message.strip() if message and message.strip() else random.choice(V2S_INSTRUCTION) | |
| prompt_text = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{prompt_body}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| gen_len = max(1, int(max_new_tokens)) | |
| gen_len = min(gen_len, self.max_audio_len_short) | |
| gen_len, steps, block_length = self._prepare_block_schedule( | |
| gen_len, | |
| steps, | |
| block_length, | |
| ) | |
| audio_placeholder = torch.full( | |
| (1, gen_len), | |
| self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| try: | |
| seq_ids, attn_mask = self.uni_prompting( | |
| (video_tokens, [prompt_text], [audio_placeholder]), | |
| "v2s_gen", | |
| ) | |
| except Exception as exc: | |
| return None, f"Failed to build V2S prompt: {exc}" | |
| input_ids = seq_ids.to(self.device) | |
| attn_mask = attn_mask.to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.t2s_generate_mmu_like( | |
| input_ids=input_ids, | |
| max_new_tokens=gen_len, | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| temperature=float(temperature), | |
| cfg_scale=float(cfg_scale), | |
| mask_token_id=self.mask_token_id, | |
| attention_mask=attn_mask, | |
| uni_prompting=self.uni_prompting, | |
| codebook_size=self.codebook_size, | |
| audio_codebook_size=self.audio_codebook_size, | |
| ) | |
| if not outputs: | |
| return None, "Audio generation produced no tokens." | |
| generated = outputs[0] | |
| if isinstance(generated, torch.Tensor): | |
| generated = generated.detach().cpu() | |
| eoa_token_id = int(self.uni_prompting.sptids_dict['<|eoa|>'][0].item()) | |
| mask_token_id = int(self.mask_token_id) | |
| token_list = [] | |
| for tok in generated.tolist(): | |
| tok = int(tok) | |
| if tok < 0: | |
| continue | |
| if tok == eoa_token_id: | |
| break | |
| if tok == mask_token_id: | |
| continue | |
| if tok >= self.audio_codebook_size: | |
| continue | |
| token_list.append(tok) | |
| if not token_list: | |
| return None, "Generated sequence was empty after decoding." | |
| speech_units = "".join(f"<|speech_{tok}|>" for tok in token_list) | |
| fd, temp_path = tempfile.mkstemp(prefix="omada_v2s_", suffix=".wav") | |
| os.close(fd) | |
| condition = self._resolve_condition(self.audio_condition_default) | |
| try: | |
| wav = self.vq_audio.decode( | |
| speech_units, | |
| condition=condition, | |
| output_wav_file=temp_path, | |
| ) | |
| except Exception as exc: | |
| return None, f"Failed to decode speech: {exc}" | |
| finally: | |
| try: | |
| os.remove(temp_path) | |
| except OSError: | |
| pass | |
| status = "Speech generated from video." | |
| if converted: | |
| status += " (Webcam recording converted to MP4.)" | |
| status += f" (voice: {condition})" | |
| return (self.sample_rate, wav.astype(np.float32)), status | |
| # ------------------------------------------------------------------ | |
| # Image-to-Speech (MMU captioning followed by T2S) | |
| # ------------------------------------------------------------------ | |
| def run_i2s( | |
| self, | |
| image: Optional[Image.Image], | |
| message: Optional[str], | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| cfg_scale: float, | |
| ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: | |
| if image is None: | |
| return None, "Please upload an image first." | |
| question = (message or "").strip() or "Please describe the image in spoken form." | |
| caption, status = self._mmu_answer([image], question) | |
| if not caption: | |
| return None, status | |
| speech_len, steps, block_length = self._prepare_block_schedule( | |
| max_new_tokens, | |
| steps, | |
| block_length, | |
| ) | |
| audio_result, speech_status = self.run_t2s( | |
| caption, | |
| speech_len, | |
| steps, | |
| block_length, | |
| temperature, | |
| cfg_scale, | |
| 'random', | |
| 'random', | |
| 'random', | |
| 'random', | |
| ) | |
| if audio_result is None: | |
| return None, speech_status | |
| combined_status = f"{status} {speech_status}".strip() | |
| return audio_result, combined_status or "Spoken description generated." | |
| # ------------------------------------------------------------------ | |
| # Chat (Text Generation) | |
| # ------------------------------------------------------------------ | |
| def run_chat( | |
| self, | |
| message: str, | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| ) -> Tuple[str, str]: | |
| content = (message or "").strip() | |
| if not content: | |
| return "", "Type a message to start chatting." | |
| prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{content}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| tokenizer = self.uni_prompting.text_tokenizer | |
| tokenizer.padding_side = "left" | |
| if tokenizer.pad_token_id is None: | |
| tokenizer.pad_token_id = tokenizer.eos_token_id | |
| tokens = tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| truncation=True, | |
| padding=True, | |
| ) | |
| input_ids = tokens["input_ids"].to(self.device) | |
| attn_mask = tokens.get("attention_mask") | |
| if attn_mask is not None: | |
| attn_mask = attn_mask.to(self.device) | |
| with torch.no_grad(): | |
| output_ids = self._generate_text_tokens( | |
| input_ids, | |
| max_new_tokens=int(max_new_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| temperature=float(temperature), | |
| cfg_scale=0.0, | |
| attention_mask=attn_mask, | |
| ) | |
| decoded = tokenizer.batch_decode( | |
| output_ids[:, input_ids.shape[1]:], | |
| skip_special_tokens=True, | |
| )[0] | |
| return decoded.strip(), "Assistant reply generated." | |
| # ------------------------------------------------------------------ | |
| # General MMU (N Images → Text) | |
| # ------------------------------------------------------------------ | |
| def run_mmu( | |
| self, | |
| images: Union[Optional[Image.Image], Sequence[Optional[Image.Image]]], | |
| message: str, | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| ) -> Tuple[str, str]: | |
| """ | |
| MMU demo now consumes exactly one image. If callers pass a list (for | |
| backwards compatibility), we keep only the first valid image. | |
| """ | |
| if isinstance(images, Image.Image): | |
| normalized: List[Image.Image] = [images] | |
| elif images is None: | |
| normalized = [] | |
| else: | |
| normalized = [img for img in images if img is not None] | |
| if not normalized: | |
| return "", "Please provide an image for MMU reasoning." | |
| primary_image = normalized[0] | |
| reply, status = self._mmu_answer( | |
| [primary_image], | |
| message, | |
| max_new_tokens=max_new_tokens, | |
| steps=steps, | |
| block_length=block_length, | |
| temperature=temperature, | |
| ) | |
| return reply, status | |
| # ------------------------------------------------------------------ | |
| # Helpers | |
| # ------------------------------------------------------------------ | |
| def _resolve_choice(self, choice: Optional[str], options): | |
| if choice is None or choice == 'random': | |
| return random.choice(options) | |
| return choice | |
| def _sample_condition(self) -> str: | |
| return ( | |
| f"gender-{random.choice(self.genders)}" | |
| f"_emotion-{random.choice(self.emotions)}" | |
| f"_speed-{random.choice(self.speeds)}" | |
| f"_pitch-{random.choice(self.pitches)}" | |
| ) | |
| def _resolve_condition(self, preferred: Optional[str] = None) -> str: | |
| if preferred and preferred != "random": | |
| if not self._valid_conditions or preferred in self._valid_conditions: | |
| return preferred | |
| if self._valid_conditions: | |
| for _ in range(8): | |
| candidate = self._sample_condition() | |
| if candidate in self._valid_conditions: | |
| return candidate | |
| if self.audio_condition_default in self._valid_conditions: | |
| return self.audio_condition_default | |
| return next(iter(self._valid_conditions)) | |
| if preferred and preferred != "random": | |
| return preferred | |
| return self.audio_condition_default | |
| def _format_chat_prompt(self, content: str) -> str: | |
| clean = (content or "").strip() | |
| if not clean: | |
| clean = "Please describe the visual content." | |
| return ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{clean}\n" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| def _prepare_block_schedule( | |
| self, | |
| total_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| ) -> Tuple[int, int, int]: | |
| total = max(1, int(total_tokens)) | |
| blk = max(1, min(int(block_length), total)) | |
| if total % blk != 0: | |
| blk = math.gcd(total, blk) | |
| blk = blk if blk > 0 else total | |
| if total % blk != 0: | |
| blk = total | |
| num_blocks = max(1, total // blk) | |
| steps = max(num_blocks, int(steps)) | |
| if steps % num_blocks != 0: | |
| steps = num_blocks * math.ceil(steps / num_blocks) | |
| return total, steps, blk | |
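| # Worked examples of the adjustment above: | |
| #   self._prepare_block_schedule(256, 30, 64) -> (256, 32, 64)   # 4 blocks, steps rounded up to 32 | |
| #   self._prepare_block_schedule(250, 30, 64) -> (250, 125, 2)   # 64 does not divide 250; gcd(250, 64) = 2 | |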
| def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]: | |
| """Normalize Gradio video inputs (upload/webcam) to an MP4 filepath.""" | |
| candidate = None | |
| if isinstance(video_input, str): | |
| candidate = video_input | |
| elif isinstance(video_input, dict): | |
| candidate = ( | |
| video_input.get("video") | |
| or video_input.get("name") | |
| or video_input.get("path") | |
| ) | |
| elif isinstance(video_input, (list, tuple)) and video_input: | |
| candidate = str(video_input[0]) | |
| if not candidate: | |
| return None, False | |
| candidate = str(candidate) | |
| if not self._ensure_file_ready(candidate): | |
| return None, False | |
| if candidate.lower().endswith(".mp4"): | |
| return candidate, False | |
| converted = self._convert_to_mp4(candidate) | |
| if converted: | |
| return converted, True | |
| suffix = Path(candidate).suffix or ".webm" | |
| fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) | |
| os.close(fd) | |
| try: | |
| shutil.copy2(candidate, tmp_path) | |
| self._temp_video_files.append(tmp_path) | |
| return tmp_path, False | |
| except OSError: | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return candidate, False | |
| if candidate.lower().endswith(".mp4"): | |
| return candidate, False | |
| converted = self._convert_to_mp4(candidate) | |
| if converted: | |
| return converted, True | |
| suffix = Path(candidate).suffix or ".webm" | |
| fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) | |
| os.close(fd) | |
| try: | |
| shutil.copy2(candidate, tmp_path) | |
| self._temp_video_files.append(tmp_path) | |
| return tmp_path, False | |
| except OSError: | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return candidate, False | |
| def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool: | |
| """Ensure the uploaded/recorded file is fully written before processing.""" | |
| prev_size = -1 | |
| for _ in range(retries): | |
| try: | |
| size = os.path.getsize(path) | |
| except OSError: | |
| size = -1 | |
| if size <= 0: | |
| time.sleep(delay) | |
| continue | |
| if size == prev_size: | |
| return True | |
| prev_size = size | |
| time.sleep(delay) | |
| return prev_size > 0 | |
| def _convert_to_mp4(self, src_path: str) -> Optional[str]: | |
| """Convert arbitrary video file to MP4 using OpenCV (drops audio).""" | |
| cap = cv2.VideoCapture(src_path) | |
| if not cap.isOpened(): | |
| return None | |
| width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0) | |
| height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0) | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| if width <= 0 or height <= 0: | |
| cap.release() | |
| return None | |
| if not fps or np.isnan(fps) or fps <= 0: | |
| fps = 24.0 | |
| fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4") | |
| os.close(fd) | |
| writer = cv2.VideoWriter( | |
| tmp_path, | |
| cv2.VideoWriter_fourcc(*"mp4v"), | |
| float(fps), | |
| (width, height), | |
| ) | |
| if not writer.isOpened(): | |
| cap.release() | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return None | |
| frame_count = 0 | |
| try: | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| writer.write(frame) | |
| frame_count += 1 | |
| finally: | |
| cap.release() | |
| writer.release() | |
| if frame_count == 0: | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return None | |
| self._temp_video_files.append(tmp_path) | |
| return tmp_path | |
| def _extract_video_tokens(self, video_path: str, num_frame: int = 8) -> torch.Tensor: | |
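| """Uniformly sample `num_frame` frames, VQ-encode them, and offset the codes past the text vocabulary.""" | |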
| cap = cv2.VideoCapture(video_path) | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| if total_frames <= 0: | |
| cap.release() | |
| raise RuntimeError(f"No readable frames in {video_path}") | |
| indices = np.linspace(0, total_frames - 1, num_frame, dtype=int) | |
| frames = [] | |
| for idx in range(total_frames): | |
| ret, frame = cap.read() | |
| if idx in indices and ret: | |
| rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| pil = Image.fromarray(rgb) | |
| frames.append(image_transform(pil, resolution=self.video_resolution)) | |
| cap.release() | |
| if len(frames) == 0: | |
| raise RuntimeError("Failed to sample frames for V2T inference.") | |
| video_tensor = torch.stack(frames).to(self.device) | |
| video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size | |
| return video_tokens.long().to(self.device).view(1, -1) | |
| def _prepare_image_tokens(self, image: Image.Image, resolution: Optional[int] = None) -> torch.LongTensor: | |
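| """Transform a PIL image to a tensor and encode it to VQ codes offset past the text vocabulary.""" | |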
| if image is None: | |
| raise ValueError("Image input is required.") | |
| target_resolution = int(resolution or self.image_resolution) | |
| tensor = image_transform(image, resolution=target_resolution) | |
| tensor = tensor.unsqueeze(0).to(self.device) | |
| codes = self.vq_image.get_code(tensor) + self.text_vocab_size | |
| return codes.long().to(self.device) | |
| def _decode_image_tokens(self, tokens: torch.Tensor) -> Image.Image: | |
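| """Clamp codes to the codebook range, decode them to pixels, and return a PIL image.""" | |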
| codes = tokens.view(1, -1).clamp(min=0, max=self.codebook_size - 1).to(self.device) | |
| with torch.no_grad(): | |
| image_tensor = self.vq_image.decode_code(codes) | |
| image_tensor = image_tensor.squeeze(0).cpu() | |
| image_tensor = torch.clamp((image_tensor.float() + 1.0) / 2.0, min=0.0, max=1.0) | |
| array = (image_tensor.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) | |
| return Image.fromarray(array) | |
| def _mmu_answer( | |
| self, | |
| images: List[Image.Image], | |
| question: str, | |
| max_new_tokens: Optional[int] = None, | |
| steps: Optional[int] = None, | |
| block_length: Optional[int] = None, | |
| temperature: Optional[float] = None, | |
| ) -> Tuple[str, str]: | |
| if not images: | |
| return "", "Please provide at least one image." | |
| encoded_images: List[torch.Tensor] = [] | |
| for image in images: | |
| if image is None: | |
| continue | |
| try: | |
| tokens = self._prepare_image_tokens( | |
| image, | |
| resolution=480 | |
| ).view(-1) | |
| encoded_images.append(tokens) | |
| except Exception: | |
| continue | |
| if not encoded_images: | |
| return "", "Failed to encode the provided image(s)." | |
| question = (question or "").strip() or "Describe the visual content." | |
| prompt = self._format_chat_prompt(question) | |
| try: | |
| tokenized = self.uni_prompting.text_tokenizer( | |
| [prompt], | |
| add_special_tokens=False, | |
| )["input_ids"][0] | |
| except Exception as exc: | |
| return "", f"Failed to tokenize question: {exc}" | |
| try: | |
| mmu_input_ids, prompt_masks, _ = self.uni_prompting.mmu_mult_prompt( | |
| batch_image_ids_list=[encoded_images], | |
| batch_text_ids=[tokenized], | |
| ) | |
| except Exception as exc: | |
| return "", f"Failed to construct MMU prompt: {exc}" | |
| mmu_input_ids = mmu_input_ids.to(self.device) | |
| prompt_masks = prompt_masks.to(self.device) | |
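| # Positions where prompt_masks == 0 are the reserved answer slots; their count sets the default token budget. | |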
| answer_tokens = int((prompt_masks == 0).sum(dim=1).max().item()) | |
| default_budget = max(1, answer_tokens) if answer_tokens > 0 else min(self.max_text_len, 256) | |
| gen_tokens = int(max_new_tokens or default_budget) | |
| requested_steps = steps if steps is not None else gen_tokens | |
| requested_block = block_length if block_length is not None else max(1, gen_tokens // 2) | |
| gen_tokens, steps, block_length = self._prepare_block_schedule( | |
| gen_tokens, | |
| requested_steps, | |
| requested_block, | |
| ) | |
| temperature = float(temperature if temperature is not None else 0.7) | |
| if gen_tokens > 0: | |
| mask_block = torch.full( | |
| (mmu_input_ids.size(0), gen_tokens), | |
| self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| mmu_input_ids = torch.cat([mmu_input_ids, mask_block], dim=1) | |
| with torch.no_grad(): | |
| output_ids = self.model.mmu_generate( | |
| mmu_input_ids, | |
| max_new_tokens=int(gen_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| temperature=temperature, | |
| remasking="low_confidence", | |
| mask_id=self.mask_token_id, | |
| ) | |
| decoded = self.uni_prompting.text_tokenizer.batch_decode( | |
| output_ids[:, mmu_input_ids.shape[1]:], | |
| skip_special_tokens=True, | |
| )[0].strip() | |
| if not decoded: | |
| return "", "MMU response was empty." | |
| return decoded, "Image understanding succeeded." | |
| def _generate_text_tokens( | |
| self, | |
| prompt_ids: torch.Tensor, | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| cfg_scale: float = 0.0, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| remasking: str = "low_confidence", | |
| ) -> torch.Tensor: | |
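| """Block-wise masked-diffusion decoding: every generation position starts as a mask token and the highest-confidence predictions are committed a few at a time, block by block.""" | |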
| prompt_ids = prompt_ids.to(self.device) | |
| batch_size, prompt_len = prompt_ids.shape | |
| gen_len, steps, block_length = self._prepare_block_schedule( | |
| max_new_tokens, | |
| steps, | |
| block_length, | |
| ) | |
| work = torch.full( | |
| (batch_size, prompt_len + gen_len), | |
| self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| work[:, :prompt_len] = prompt_ids | |
| prompt_index = work != self.mask_token_id | |
| attention_bias = None | |
| if attention_mask is not None and (attention_mask == 0).any(): | |
| attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) | |
| num_blocks = max(1, gen_len // block_length) | |
| inner_steps = max(1, steps // num_blocks) | |
| for block_idx in range(num_blocks): | |
| block_slice = slice(prompt_len + block_idx * block_length, prompt_len + (block_idx + 1) * block_length) | |
| block_mask_index = work[:, block_slice] == self.mask_token_id | |
| num_transfer_tokens = get_num_transfer_tokens(block_mask_index, inner_steps) | |
| for inner_step in range(inner_steps): | |
| mask_index = work == self.mask_token_id | |
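| # Classifier-free guidance: a second pass with the prompt masked out yields unconditional logits, which are extrapolated toward the conditional ones. | |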
| if cfg_scale > 0.0: | |
| unconditional = work.clone() | |
| unconditional[prompt_index] = self.mask_token_id | |
| model_input = torch.cat([work, unconditional], dim=0) | |
| logits = self.model(model_input).logits | |
| cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) | |
| logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits) | |
| else: | |
| logits = self.model(work, attention_bias=attention_bias).logits | |
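| # Gumbel noise on the logits turns the argmax into a temperature-controlled sample. | |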
| logits_with_noise = add_gumbel_noise(logits, temperature=temperature) | |
| x0 = torch.argmax(logits_with_noise, dim=-1) | |
| if remasking == "low_confidence": | |
| probs = F.softmax(logits.to(torch.float64), dim=-1) | |
| x0_p = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) | |
| elif remasking == "random": | |
| x0_p = torch.rand_like(x0, dtype=torch.float64) | |
| else: | |
| raise NotImplementedError(remasking) | |
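| # Confidence beyond the current block is forced to -inf so only tokens inside this block can be committed. | |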
| x0_p[:, prompt_len + (block_idx + 1) * block_length :] = -np.inf | |
| x0 = torch.where(mask_index, x0, work) | |
| confidence = torch.where(mask_index, x0_p, torch.full_like(x0_p, float('-inf'))) | |
| transfer_index = torch.zeros_like(work, dtype=torch.bool) | |
| for b in range(batch_size): | |
| k = int(num_transfer_tokens[b, inner_step].item()) | |
| if k <= 0: | |
| continue | |
| values, select_idx = torch.topk(confidence[b], k=k) | |
| transfer_index[b, select_idx] = values != float('-inf') | |
| work[transfer_index] = x0[transfer_index] | |
| return work | |
| def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]): | |
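| """Assemble the Gradio Blocks UI and wire every task panel to the shared chat handler.""" | |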
| theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray") | |
| with gr.Blocks(title="AIDAS Lab @ SNU", css=CUSTOM_CSS, theme=theme, js=FORCE_LIGHT_MODE_JS) as demo: | |
| with gr.Column(elem_classes=["omada-header"]): | |
| if LOGO_DATA_URI: | |
| gr.HTML( | |
| f"<div class='omada-logo-wrapper'><img src=\"{LOGO_DATA_URI}\" alt=\"AIDAS Lab\" class=\"omada-logo-img\" /></div>" | |
| ) | |
| elif LOGO_PATH.exists(): | |
| gr.Image( | |
| value=str(LOGO_PATH), | |
| show_label=False, | |
| height=140, | |
| interactive=False, | |
| elem_classes=["omada-logo"], | |
| ) | |
| gr.Markdown( | |
| "## Omni-modal Diffusion Foundation Model\n### Pretrained Demo", | |
| elem_classes=["omada-title"], | |
| ) | |
| gr.HTML( | |
| "<p class='omada-tagline'>" | |
| "<span class='tagline-speech'>Create speech</span>, " | |
| "<span class='tagline-audio'>transcribe audio</span>, " | |
| "<span class='tagline-video'>describe video</span>, " | |
| "<span class='tagline-text'>chat with text</span>, and " | |
| "<span class='tagline-image'>generate or edit images</span> — all from a single model. " | |
| "Use the advanced sections when you want tighter control." | |
| "</p>") | |
| group_to_modes = { | |
| "Speech": ["Text → Speech", "Speech → Speech", "Speech → Text"], | |
| "Video": ["Video → Text", "Video → Speech"], | |
| "Image": ["Text → Image", "Image Editing"], | |
| "Multi-Modal": ["MMU (Image → Text)"], | |
| "Text": ["Text"], | |
| } | |
| default_group = "Speech" | |
| default_mode = group_to_modes[default_group][0] | |
| placeholder_map = { | |
| "Text → Speech": "Type the speech you want to generate...", | |
| "Speech → Speech": "Optionally add context for the reply...", | |
| "Speech → Text": "Upload audio on the right, then leave notes here if needed.", | |
| "Video → Speech": "Upload video on the right. Optionally provide guidance here.", | |
| "Video → Text": "Upload video on the right, then leave notes here if needed.", | |
| "Text": "Ask anything and the assistant will reply with text.", | |
| "MMU (Image → Text)": "Ask a question about the uploaded image.", | |
| "Text → Image": "Describe the image you want to generate...", | |
| "Image Editing": "Describe how you want to edit the uploaded image...", | |
| } | |
| with gr.Row(elem_classes=["omada-layout"], equal_height=False): | |
| with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]): | |
| chatbox = gr.Chatbot(label="Session", height=760, sanitize_html=False) | |
| chat_input = gr.Textbox( | |
| label="Message", | |
| placeholder=placeholder_map[default_mode], | |
| lines=3, | |
| ) | |
| with gr.Row(): | |
| send_button = gr.Button("Send", variant="primary") | |
| clear_button = gr.Button("Clear", variant="secondary") | |
| with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]): | |
| task_group_selector = gr.Radio( | |
| list(group_to_modes.keys()), | |
| value=default_group, | |
| label="Task Group", | |
| elem_classes=["omada-task-selector", "omada-task-group-buttons"], | |
| ) | |
| submode_selector = gr.Radio( | |
| group_to_modes[default_group], | |
| value=default_mode, | |
| label="Task Mode", | |
| elem_classes=["omada-task-selector", "omada-task-mode-buttons"], | |
| ) | |
| with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Text-to-Speech Controls") | |
| with gr.Group(elem_classes=["omada-advanced"]): | |
| gr.Markdown("**Generation**") | |
| with gr.Row(): | |
| t2s_max_tokens = gr.Slider(2, 512, value=384, label="Speech token length", step=2) | |
| t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2) | |
| with gr.Row(): | |
| t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) | |
| t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1) | |
| t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) | |
| with gr.Group(elem_classes=["omada-advanced"]): | |
| gr.Markdown("**Voice styling**") | |
| with gr.Row(): | |
| t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender") | |
| t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion") | |
| with gr.Row(): | |
| t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed") | |
| t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch") | |
| if T2S_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample prompts**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=T2S_EXAMPLES, | |
| inputs=[chat_input], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2s_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Speech-to-Speech Controls") | |
| s2s_audio = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"]) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| s2s_max_tokens = gr.Slider(2, 512, value=512, label="Reply token length", step=2) | |
| with gr.Row(): | |
| s2s_steps = gr.Slider(2, 512, value=512, label="Refinement steps", step=2) | |
| s2s_block = gr.Slider(2, 512, value=512, label="Block length", step=2) | |
| with gr.Row(): | |
| s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, label="Sampling temperature", step=0.05) | |
| s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, label="CFG scale", step=0.1) | |
| if S2S_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample S2S clips**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=S2S_EXAMPLES, | |
| inputs=[s2s_audio], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Speech-to-Text Controls") | |
| s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"]) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| with gr.Row(): | |
| s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) | |
| s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) | |
| s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) | |
| s2t_remasking = gr.Dropdown( | |
| choices=["low_confidence", "random"], | |
| value="low_confidence", | |
| label="Remasking strategy", | |
| ) | |
| if S2T_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample clips**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=S2T_EXAMPLES, | |
| inputs=[s2t_audio], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Video-to-Text Controls") | |
| v2t_video = gr.Video( | |
| label="Upload or record video", | |
| format=None, | |
| height=256, | |
| sources=["upload", "webcam"], | |
| ) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| with gr.Row(): | |
| v2t_steps = gr.Slider(2, 512, value=256, label="Denoising steps", step=2) | |
| v2t_block = gr.Slider(2, 512, value=256, label="Block length", step=2) | |
| v2t_max_tokens = gr.Slider(2, 512, value=256, label="Max tokens", step=2) | |
| if V2T_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample videos**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=V2T_EXAMPLES, | |
| inputs=[v2t_video], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2s_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Video-to-Speech Controls") | |
| v2s_video = gr.Video( | |
| label="Upload or record video", | |
| format=None, | |
| height=256, | |
| sources=["upload", "webcam"], | |
| ) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| v2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2) | |
| with gr.Row(): | |
| v2s_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2) | |
| v2s_block = gr.Slider(2, 512, value=256, label="Block length", step=2) | |
| with gr.Row(): | |
| v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) | |
| v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, label="CFG scale", step=0.1) | |
| if V2S_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample videos**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=V2S_EXAMPLES, | |
| inputs=[v2s_video], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as i2s_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Image-to-Speech Controls") | |
| i2s_image = gr.Image(type="pil", label="Upload image", sources=["upload"]) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| i2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2) | |
| with gr.Row(): | |
| i2s_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2) | |
| i2s_block = gr.Slider(2, 512, value=256, label="Block length", step=2) | |
| with gr.Row(): | |
| i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) | |
| i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, label="CFG scale", step=0.1) | |
| if I2S_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample images**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=I2S_EXAMPLES, | |
| inputs=[i2s_image], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as image_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Image Tasks") | |
| image_mode_selector = gr.Radio( | |
| ["Generation", "Editing"], | |
| value="Generation", | |
| label="Sub-mode", | |
| ) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings: | |
| t2i_timesteps = gr.Slider(4, 128, value=max(4, min(128, app.t2i_default_timesteps)), label="Timesteps", step=2) | |
| t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) | |
| t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1) | |
| with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings: | |
| i2i_image = gr.Image(type="pil", label="Reference image", sources=["upload"]) | |
| i2i_timesteps = gr.Slider(4, 128, value=max(4, min(128, app.i2i_default_timesteps)), label="Timesteps", step=2) | |
| i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) | |
| i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1) | |
| if T2I_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample prompts**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=T2I_EXAMPLES, | |
| inputs=[chat_input], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as chat_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Chat Controls") | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| chat_max_tokens = gr.Slider(2, 512, value=64, label="Reply max tokens", step=2) | |
| with gr.Row(): | |
| chat_steps = gr.Slider(2, 512, value=64, label="Refinement steps", step=2) | |
| chat_block = gr.Slider(2, 512, value=64, label="Block length", step=2) | |
| chat_temperature = gr.Slider(0.0, 2.0, value=0.8, label="Sampling temperature", step=0.05) | |
| if CHAT_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample prompts**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=CHAT_EXAMPLES, | |
| inputs=[chat_input], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as mmu_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Image Reasoning") | |
| mmu_image = gr.Image(type="pil", label="Image", sources=["upload"]) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| mmu_max_tokens = gr.Slider(2, 512, value=256, label="Answer max tokens", step=2) | |
| with gr.Row(): | |
| mmu_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2) | |
| mmu_block = gr.Slider(2, 512, value=128, label="Block length", step=2) | |
| mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, label="Sampling temperature", step=0.05) | |
| if MMU_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample prompts**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=MMU_EXAMPLES, | |
| inputs=[mmu_image, chat_input], | |
| examples_per_page=1, | |
| ) | |
| def _panel_updates(group: str, mode: str): | |
| show_t2s = group == "Speech" and mode == "Text → Speech" | |
| show_s2s = group == "Speech" and mode == "Speech → Speech" | |
| show_v2s = group == "Video" and mode == "Video → Speech" | |
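| # Image → Speech is wired into the UI but not yet exposed as a selectable mode, so its panel stays hidden. | |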
| show_i2s = False | |
| show_s2t = group == "Speech" and mode == "Speech → Text" | |
| show_v2t = group == "Video" and mode == "Video → Text" | |
| show_chat = group == "Text" and mode == "Text" | |
| show_mmu = group == "Multi-Modal" and mode == "MMU (Image → Text)" | |
| show_image = group == "Image" and mode in ("Text → Image", "Image Editing") | |
| placeholder = placeholder_map.get(mode, chat_input.placeholder) | |
| image_mode_value = "Generation" if mode == "Text → Image" else "Editing" | |
| t2i_visible = show_image and mode == "Text → Image" | |
| i2i_visible = show_image and mode == "Image Editing" | |
| image_mode_update = gr.update(value=image_mode_value) if show_image else gr.update() | |
| return ( | |
| gr.update(placeholder=placeholder), | |
| gr.update(visible=show_t2s), | |
| gr.update(visible=show_s2s), | |
| gr.update(visible=show_v2s), | |
| gr.update(visible=show_i2s), | |
| gr.update(visible=show_s2t), | |
| gr.update(visible=show_v2t), | |
| gr.update(visible=show_chat), | |
| gr.update(visible=show_mmu), | |
| gr.update(visible=show_image), | |
| image_mode_update, | |
| gr.update(visible=t2i_visible), | |
| gr.update(visible=i2i_visible), | |
| ) | |
| def _on_group_change(group: str): | |
| default_mode_local = group_to_modes[group][0] | |
| submode_update = gr.update(choices=group_to_modes[group], value=default_mode_local) | |
| panel_updates = _panel_updates(group, default_mode_local) | |
| return (submode_update, *panel_updates) | |
| def _on_submode_change(mode: str, group: str): | |
| return _panel_updates(group, mode) | |
| task_group_selector.change( | |
| _on_group_change, | |
| inputs=[task_group_selector], | |
| outputs=[ | |
| submode_selector, | |
| chat_input, | |
| t2s_panel, | |
| s2s_panel, | |
| v2s_panel, | |
| i2s_panel, | |
| s2t_panel, | |
| v2t_panel, | |
| chat_panel, | |
| mmu_panel, | |
| image_panel, | |
| image_mode_selector, | |
| t2i_settings, | |
| i2i_settings, | |
| ], | |
| ) | |
| submode_selector.change( | |
| _on_submode_change, | |
| inputs=[submode_selector, task_group_selector], | |
| outputs=[ | |
| chat_input, | |
| t2s_panel, | |
| s2s_panel, | |
| v2s_panel, | |
| i2s_panel, | |
| s2t_panel, | |
| v2t_panel, | |
| chat_panel, | |
| mmu_panel, | |
| image_panel, | |
| image_mode_selector, | |
| t2i_settings, | |
| i2i_settings, | |
| ], | |
| ) | |
| def _toggle_image_task(task: str): | |
| return ( | |
| gr.update(visible=task == "Generation"), | |
| gr.update(visible=task == "Editing"), | |
| ) | |
| image_mode_selector.change( | |
| _toggle_image_task, | |
| inputs=[image_mode_selector], | |
| outputs=[t2i_settings, i2i_settings], | |
| ) | |
| def _chat_handler( | |
| history, | |
| message, | |
| group, | |
| mode, | |
| s2t_audio_path, | |
| v2t_video_path, | |
| t2s_max_tokens, | |
| t2s_steps, | |
| t2s_block, | |
| t2s_temperature, | |
| t2s_cfg, | |
| t2s_gender, | |
| t2s_emotion, | |
| t2s_speed, | |
| t2s_pitch, | |
| s2t_steps, | |
| s2t_block, | |
| s2t_max_tokens, | |
| s2t_remasking, | |
| v2t_steps, | |
| v2t_block, | |
| v2t_max_tokens, | |
| s2s_audio_path, | |
| s2s_max_tokens, | |
| s2s_steps, | |
| s2s_block, | |
| s2s_temperature, | |
| s2s_cfg, | |
| i2s_image, | |
| i2s_max_tokens, | |
| i2s_steps, | |
| i2s_block, | |
| i2s_temperature, | |
| i2s_cfg, | |
| image_mode, | |
| t2i_timesteps, | |
| t2i_temperature, | |
| t2i_guidance, | |
| i2i_image, | |
| i2i_timesteps, | |
| i2i_temperature, | |
| i2i_guidance, | |
| v2s_video_path, | |
| v2s_max_tokens, | |
| v2s_steps, | |
| v2s_block, | |
| v2s_temperature, | |
| v2s_cfg, | |
| chat_max_tokens, | |
| chat_steps, | |
| chat_block, | |
| chat_temperature, | |
| mmu_image, | |
| mmu_max_tokens, | |
| mmu_steps, | |
| mmu_block, | |
| mmu_temperature, | |
| ): | |
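| # Route the turn to the runner matching the selected mode; each branch renders its own HTML message. | |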
| history = history or [] | |
| message = (message or "").strip() | |
| response = "" | |
| if mode == "Text → Speech": | |
| if not message: | |
| status = "Please type some text for speech synthesis." | |
| response = _render_text_message(status, "") | |
| else: | |
| audio_result, status = app.run_t2s( | |
| message, | |
| t2s_max_tokens, | |
| t2s_steps, | |
| t2s_block, | |
| t2s_temperature, | |
| t2s_cfg, | |
| t2s_gender, | |
| t2s_emotion, | |
| t2s_speed, | |
| t2s_pitch, | |
| ) | |
| response = _render_audio_message(status, audio_result) | |
| display_user_raw = message or "[Speech request]" | |
| elif mode == "Speech → Speech": | |
| audio_result, status = app.run_s2s( | |
| s2s_audio_path, | |
| s2s_max_tokens, | |
| s2s_steps, | |
| s2s_block, | |
| s2s_temperature, | |
| s2s_cfg, | |
| ) | |
| response = _render_audio_message(status, audio_result) | |
| display_user_raw = message or "[Speech-to-speech request]" | |
| elif mode == "Video → Speech": | |
| audio_result, status = app.run_v2s( | |
| v2s_video_path, | |
| message, | |
| v2s_max_tokens, | |
| v2s_steps, | |
| v2s_block, | |
| v2s_temperature, | |
| v2s_cfg, | |
| ) | |
| response = _render_audio_message(status, audio_result) | |
| display_user_raw = message or "[Video-to-speech request]" | |
| elif mode == "Speech → Text": | |
| if not s2t_audio_path: | |
| status = "Please upload or record an audio clip first." | |
| response = _render_text_message(status, "") | |
| else: | |
| transcript, status = app.run_s2t( | |
| s2t_audio_path, | |
| s2t_steps, | |
| s2t_block, | |
| s2t_max_tokens, | |
| s2t_remasking, | |
| ) | |
| response = _render_text_message(status, transcript) | |
| display_user_raw = message or "[Audio transcription request]" | |
| elif mode == "Video → Text": | |
| if not v2t_video_path: | |
| status = "Please upload or record a video first." | |
| response = _render_text_message(status, "") | |
| else: | |
| caption, status = app.run_v2t( | |
| v2t_video_path, | |
| v2t_steps, | |
| v2t_block, | |
| v2t_max_tokens, | |
| ) | |
| response = _render_text_message(status, caption) | |
| display_user_raw = message or "[Video caption request]" | |
| elif mode == "Text": | |
| reply, status = app.run_chat( | |
| message, | |
| chat_max_tokens, | |
| chat_steps, | |
| chat_block, | |
| chat_temperature, | |
| ) | |
| response = _render_text_message(status, reply) | |
| display_user_raw = message or "[Text request]" | |
| elif mode == "MMU (Image → Text)": | |
| reply, status = app.run_mmu( | |
| [mmu_image] if mmu_image is not None else [], | |
| message, | |
| mmu_max_tokens, | |
| mmu_steps, | |
| mmu_block, | |
| mmu_temperature, | |
| ) | |
| response = _render_text_message(status, reply) | |
| display_user_raw = message or "[Multi-image question]" | |
| elif mode == "Text → Image": | |
| if not message: | |
| status = "Please provide a prompt for image generation." | |
| response = _render_text_message(status, "") | |
| else: | |
| image_result, status = app.run_t2i( | |
| message, | |
| t2i_timesteps, | |
| t2i_temperature, | |
| t2i_guidance, | |
| ) | |
| response = _render_image_message(status, image_result) | |
| display_user_raw = message or "[Image generation request]" | |
| elif mode == "Image Editing": | |
| image_result, status = app.run_i2i( | |
| message, | |
| i2i_image, | |
| i2i_timesteps, | |
| i2i_temperature, | |
| i2i_guidance, | |
| ) | |
| response = _render_image_message(status, image_result) | |
| display_user_raw = message or "[Image editing request]" | |
| else: | |
| status = f"Mode '{mode}' is not supported." | |
| response = _render_text_message(status, "") | |
| display_user_raw = message or "[Unsupported mode]" | |
| display_user = _format_user_message(display_user_raw) | |
| history = history + [(display_user, response)] | |
| return history, "" | |
| submit_inputs = [ | |
| chatbox, | |
| chat_input, | |
| task_group_selector, | |
| submode_selector, | |
| s2t_audio, | |
| v2t_video, | |
| t2s_max_tokens, | |
| t2s_steps, | |
| t2s_block, | |
| t2s_temperature, | |
| t2s_cfg, | |
| t2s_gender, | |
| t2s_emotion, | |
| t2s_speed, | |
| t2s_pitch, | |
| s2t_steps, | |
| s2t_block, | |
| s2t_max_tokens, | |
| s2t_remasking, | |
| v2t_steps, | |
| v2t_block, | |
| v2t_max_tokens, | |
| s2s_audio, | |
| s2s_max_tokens, | |
| s2s_steps, | |
| s2s_block, | |
| s2s_temperature, | |
| s2s_cfg, | |
| i2s_image, | |
| i2s_max_tokens, | |
| i2s_steps, | |
| i2s_block, | |
| i2s_temperature, | |
| i2s_cfg, | |
| image_mode_selector, | |
| t2i_timesteps, | |
| t2i_temperature, | |
| t2i_guidance, | |
| i2i_image, | |
| i2i_timesteps, | |
| i2i_temperature, | |
| i2i_guidance, | |
| v2s_video, | |
| v2s_max_tokens, | |
| v2s_steps, | |
| v2s_block, | |
| v2s_temperature, | |
| v2s_cfg, | |
| chat_max_tokens, | |
| chat_steps, | |
| chat_block, | |
| chat_temperature, | |
| mmu_image, | |
| mmu_max_tokens, | |
| mmu_steps, | |
| mmu_block, | |
| mmu_temperature, | |
| ] | |
| submit_outputs = [chatbox, chat_input] | |
| chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) | |
| send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) | |
| def _clear_session(): | |
| return ( | |
| [], | |
| "", | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| ) | |
| clear_button.click( | |
| _clear_session, | |
| inputs=None, | |
| outputs=[ | |
| chatbox, | |
| chat_input, | |
| s2t_audio, | |
| v2t_video, | |
| s2s_audio, | |
| i2s_image, | |
| v2s_video, | |
| i2i_image, | |
| mmu_image, | |
| ], | |
| ) | |
| demo.launch( | |
| share=share, | |
| server_name=server_name, | |
| server_port=server_port, | |
| ) | |
| def parse_args(): | |
| parser = argparse.ArgumentParser(description="OMada Gradio demo for speech, video, image, and text tasks") | |
| parser.add_argument("--train-config", required=True, help="Path to the training YAML used to build tokenizer + VQ modules") | |
| parser.add_argument("--checkpoint", required=True, help="Path to an `unwrapped_model` directory") | |
| parser.add_argument("--device", default=None, help="Override device (e.g. cuda:0). Defaults to CUDA if available") | |
| parser.add_argument("--share", action="store_true", help="Enable public Gradio share link") | |
| parser.add_argument("--server-name", default="0.0.0.0", help="Host address for Blocks.launch") | |
| parser.add_argument("--server-port", type=int, default=None, help="Port for Blocks.launch") | |
| return parser.parse_args() | |
| def main(): | |
| args = parse_args() | |
| app = OmadaDemo(args.train_config, args.checkpoint, args.device) | |
| build_demo(app, args.share, args.server_name, args.server_port) | |
| if __name__ == "__main__": | |
| main() | |