#!/usr/bin/env python3
"""
Gradio demo for OMaDA Stage 1.3 checkpoints covering:
    * Text-to-Speech (T2S) and Speech-to-Speech (S2S)
    * Speech-to-Text (S2T) and Video-to-Text (V2T)
    * Video-to-Speech (V2S) and Image-to-Speech (I2S)
    * Text-to-Image (T2I) and instruction-based Image Editing (I2I)
    * Text chat and image understanding (MMU, image → text)

The implementation wraps the existing CLI inference helpers so that a single
checkpoint directory (…/checkpoint-XXXX/unwrapped_model) can be previewed
interactively.

Usage:
    python MMaDA/inference/gradio_multimodal_demo_inst.py --train-config /t1data/users/snu-lab-d/omada/OMaDA/MMaDA/inference/demo/demo.yaml --checkpoint ../ckpt/checkpoint-400000/unwrapped_model/ --server-port 7860

If you need remote access, pass `--share`, which simply forwards the flag to
`gradio.Blocks.launch`. For more reliable sharing, run the demo locally without
`--share` and tunnel the chosen port with a tool such as `ngrok http 7860` or
`cloudflared tunnel --url http://localhost:7860` instead of relying on Gradio's
temporary share links.
"""
import argparse
import base64
import html
import io
import os
import math
import random
import sys
import tempfile
import wave
from pathlib import Path
import shutil
import time
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch.nn.functional as F
from PIL import Image
CUSTOM_CSS = """
@import url('https://cdn.jsdelivr.net/gh/orioncactus/[email protected]/dist/web/variable/pretendardvariable-dynamic-subset.css');
:root {
--omada-font-ui: "Pretendard Variable", -apple-system, BlinkMacSystemFont,
"Segoe UI", system-ui, sans-serif;
}
/* Global default font */
html, body, .gradio-container {
font-family: var(--omada-font-ui) !important;
}
/* Force the same font across all Gradio components */
.gradio-container * {
font-family: var(--omada-font-ui) !important;
}
/* Safety net for components that set their own font */
.gradio-textbox textarea,
.gradio-dropdown select,
.gradio-radio label,
.gradio-button button,
.gradio-slider label,
.gradio-accordion *,
.gradio-chatbot * {
font-family: var(--omada-font-ui) !important;
}
:root {
--omada-primary: #1e3a8a;
--omada-accent: #1d4ed8;
--omada-surface: #f3f4f6;
--omada-surface-alt: #ffffff;
--omada-border: #d0d7e5;
--omada-text-primary: #111827;
--omada-text-muted: #374151;
color-scheme: light;
}
html, body, body.dark, html.dark {
background: var(--omada-surface) !important;
color: var(--omada-text-primary) !important;
}
.gradio-container {
background: var(--omada-surface);
color: var(--omada-text-primary);
}
.omada-page-heading {
margin-bottom: 0;
color: var(--omada-text-primary);
}
.omada-tab-intro p {
font-size: 0.95rem;
color: var(--omada-text-primary);
margin-top: 0;
opacity: 0.9;
}
.omada-card {
background: var(--omada-surface-alt);
border-radius: 16px !important;
padding: 18px !important;
box-shadow: none;
border: 1px solid var(--omada-border);
color: var(--omada-text-primary);
}
.omada-card .gradio-slider .wrap-inner {
gap: 6px;
}
.omada-card .gradio-slider input[type=range]::-webkit-slider-thumb {
background: var(--omada-primary);
}
.gradio-slider input[type=range]::-webkit-slider-runnable-track {
background: rgba(14, 33, 80, 0.2);
}
.omada-section-title p {
text-transform: uppercase;
font-size: 0.78rem;
letter-spacing: 0.14em;
color: rgba(30, 58, 138, 0.85);
margin: 0 0 12px 0;
}
.omada-output .gradio-audio, .omada-output .gradio-textbox {
margin-top: 12px;
}
.gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio {
color: var(--omada-text-primary) !important;
}
.gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label {
color: var(--omada-text-primary);
}
.gradio-dropdown .single-select, .gradio-textbox textarea {
background: #ffffff !important;
border: 1px solid var(--omada-border) !important;
color: var(--omada-text-primary) !important;
}
.gradio-dropdown .single-select select, .gradio-textbox textarea {
color: var(--omada-text-primary) !important;
}
.gradio-textbox textarea::placeholder {
color: rgba(148, 163, 184, 0.65);
}
.gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider {
background: #ffffff !important;
border: 1px solid var(--omada-border) !important;
border-radius: 12px !important;
}
.full-width-button button {
width: 100%;
background: var(--omada-primary) !important;
color: white !important;
border: none !important;
font-weight: 600;
transition: transform 0.2s ease, box-shadow 0.2s ease;
box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65);
}
.full-width-button button:hover {
transform: translateY(-1px);
box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75);
}
.omada-advanced .gr-accordion-header {
font-size: 0.85rem;
letter-spacing: 0.05em;
color: var(--omada-text-muted);
}
.omada-advanced .gr-accordion {
border: 1px solid var(--omada-border);
border-radius: 12px;
background: #ffffff;
}
.gradio-tabs {
background: transparent;
}
.gradio-tabs ul.tab-list {
background: transparent;
border-bottom: 1px solid var(--omada-border);
}
.gradio-tabs button {
color: var(--omada-text-primary);
}
.gradio-tabs button.selected {
color: var(--omada-text-primary);
background: rgba(14, 33, 80, 0.1);
border-bottom: 2px solid var(--omada-primary);
}
.gradio-container .label {
background: rgba(30, 58, 138, 0.1) !important;
color: var(--omada-primary) !important;
border: 1px solid rgba(30, 58, 138, 0.25) !important;
border-radius: 999px !important;
padding: 4px 12px !important;
}
.gradio-button.primary {
background: var(--omada-primary) !important;
color: #ffffff !important;
border: 1px solid var(--omada-primary) !important;
}
.gradio-accordion {
box-shadow: none;
}
.omada-layout {
gap: 20px !important;
}
.omada-chat-column {
gap: 12px !important;
}
.omada-chat-column .gradio-chatbot {
border-radius: 16px;
box-shadow: none;
border: 1px solid var(--omada-border);
background: #ffffff;
}
.omada-controls {
gap: 16px !important;
}
.omada-mode-panel {
display: flex;
flex-direction: column;
gap: 16px !important;
}
.omada-examples-card {
padding-top: 10px !important;
}
.omada-output-panel .gradio-audio,
.omada-output-panel .gradio-textbox {
margin-top: 8px;
}
.omada-response-container {
display: flex;
flex-direction: column;
gap: 10px;
}
.omada-response-status {
margin: 0;
font-weight: 600;
font-size: 0.95rem;
color: var(--omada-text-primary);
}
.omada-response-block,
.omada-audio-block {
background: rgba(30, 58, 138, 0.05);
border-radius: 12px;
padding: 12px 14px;
color: var(--omada-text-primary);
white-space: pre-wrap;
word-break: break-word;
}
.omada-audio-block audio {
width: 100%;
}
.omada-header {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 18px !important;
margin-bottom: 18px;
text-align: center;
}
.omada-header .gradio-image {
background: transparent !important;
border: none !important;
}
.omada-header img {
object-fit: contain;
display: block;
}
.omada-logo {
max-width: 180px;
padding: 0 !important;
}
.omada-examples {
margin-top: 8px;
padding-top: 4px;
}
.omada-examples .gradio-dataset {
width: 100% !important;
}
.omada-examples .samples-table {
width: 100% !important;
}
.omada-examples table {
width: 100% !important;
}
.omada-examples td {
width: 100% !important;
}
.omada-examples .sample-button,
.omada-examples button {
width: 100% !important;
white-space: pre-wrap !important;
word-wrap: break-word !important;
height: auto !important;
min-height: 40px !important;
text-align: left !important;
padding: 12px 16px !important;
line-height: 1.5 !important;
overflow: visible !important;
text-overflow: clip !important;
display: block !important;
}
.omada-examples button span {
white-space: pre-wrap !important;
word-wrap: break-word !important;
overflow: visible !important;
text-overflow: clip !important;
display: block !important;
width: 100% !important;
}
.omada-logo .gradio-image,
.omada-logo .gradio-image > div,
.omada-logo .gradio-image .container {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
display: flex;
justify-content: center;
align-items: center;
}
.omada-logo img {
width: 100%;
height: auto;
}
.omada-logo button {
display: none !important;
}
.gradio-container .gradio-component,
.gradio-container .gradio-panel,
.gradio-container .gradio-box {
background: transparent !important;
color: var(--omada-text-primary);
}
.dark .gradio-container,
.dark .gradio-interface,
.dark .gradio-container * {
background-color: inherit !important;
color: var(--omada-text-primary) !important;
}
.dark .gradio-container .gradio-chatbot,
.dark .gradio-container .gradio-dropdown,
.dark .gradio-container .gradio-textbox,
.dark .gradio-container .gradio-audio,
.dark .gradio-container .gradio-video,
.dark .gradio-container .gradio-slider,
.dark .gradio-container .gradio-accordion,
.dark .gradio-container .gradio-panel,
.dark .gradio-container .gradio-box {
background: #ffffff !important;
border-color: var(--omada-border) !important;
}
.omada-title h2 {
font-size: 2.4rem;
font-weight: 700;
color: var(--omada-text-primary);
margin: 0;
}
.omada-title h3 {
font-size: 1.25rem;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--omada-text-muted);
margin: 6px 0 0;
}
.omada-tagline p {
color: var(--omada-text-primary);
font-size: 1rem;
margin: 0;
opacity: 0.9;
line-height: 1.45;
}
.omada-tagline .tagline-speech { color: #f97316; font-weight: 600; }
.omada-tagline .tagline-audio { color: #ec4899; font-weight: 600; }
.omada-tagline .tagline-video { color: #0ea5e9; font-weight: 600; }
.omada-tagline .tagline-text { color: #a855f7; font-weight: 600; }
.omada-tagline .tagline-image { color: #22c55e; font-weight: 600; }
.gradio-container .prose :where(h1, h2, h3, h4, h5, h6) {
color: var(--omada-text-primary) !important;
}
.gradio-container .prose :where(p, li) {
color: var(--omada-text-muted) !important;
}
.gradio-container label, .gradio-container span, .gradio-container button {
color: var(--omada-text-primary);
}
.gradio-container .dark {
background: #ffffff !important;
color: var(--omada-text-primary) !important;
}
.omada-logo-img {
max-width: 250px;
width: 100%;
height: auto;
display: block;
margin: 0 auto;
}
.omada-logo-wrapper {
display: flex;
justify-content: center;
align-items: center;
}
"""
FORCE_LIGHT_MODE_JS = """
function() {
document.body.classList.remove('dark');
document.documentElement.classList.remove('dark');
const observer = new MutationObserver(function(mutations) {
document.body.classList.remove('dark');
document.documentElement.classList.remove('dark');
});
observer.observe(document.body, { attributes: true, attributeFilter: ['class'] });
observer.observe(document.documentElement, { attributes: true, attributeFilter: ['class'] });
}
"""
DEMO_ROOT = Path(__file__).resolve().parent / "demo"
LOGO_PATH = DEMO_ROOT / "logo.png"
T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt"
CHAT_TEXT_PATH = DEMO_ROOT / "chat" / "text.txt"
T2I_TEXT_PATH = DEMO_ROOT / "t2i" / "text.txt"
def _load_logo_data() -> Optional[str]:
    """Return the logo as a base64 data URI, falling back to the file path."""
    if not LOGO_PATH.exists():
        return None
    # `base64` is already imported at module level; only I/O errors need handling.
    try:
        encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8")
    except OSError:
        return str(LOGO_PATH)
    return f"data:image/png;base64,{encoded}"
def _load_text_examples(path: Path):
    """Return one Gradio example per non-empty line of `path` (or [] if missing)."""
    if not path.exists():
        return []
    lines = [
        line.strip()
        for line in path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    return [[line] for line in lines]
def _load_t2s_examples():
    return _load_text_examples(T2S_TEXT_PATH)
def _load_chat_examples():
    return _load_text_examples(CHAT_TEXT_PATH)
def _load_t2i_examples():
    return _load_text_examples(T2I_TEXT_PATH)
def _load_i2i_examples():
"""demo/i2i 안의 image*.jpg + text*.txt 쌍을 Examples로 묶어줌."""
d = DEMO_ROOT / "i2i"
if not d.exists():
return []
    # Image files (image1.jpeg, image2.png, ...)
image_files = sorted(
[p for p in d.iterdir() if p.is_file() and p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}]
)
    # Text files (text1.txt, text2.txt, ...)
text_files = sorted(
[p for p in d.iterdir() if p.is_file() and p.suffix.lower() == ".txt"]
)
n = min(len(image_files), len(text_files))
examples = []
for i in range(n):
img_path = image_files[i]
txt_path = text_files[i]
instruction = txt_path.read_text(encoding="utf-8").strip()
if not instruction:
continue
        # Gradio Examples format: [image, instruction_text]
examples.append([str(img_path), instruction])
return examples
def _load_media_examples(subdir: str, suffixes):
target_dir = DEMO_ROOT / subdir
if not target_dir.exists():
return []
examples = []
for path in sorted(target_dir.iterdir()):
if path.is_file() and path.suffix.lower() in suffixes:
examples.append([str(path)])
return examples
T2S_EXAMPLES = _load_t2s_examples()
CHAT_EXAMPLES = _load_chat_examples()
T2I_EXAMPLES = _load_t2i_examples()
I2I_EXAMPLES = _load_i2i_examples()
S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"})
if not S2S_EXAMPLES:
S2S_EXAMPLES = S2T_EXAMPLES[: min(4, len(S2T_EXAMPLES))]
V2S_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
if not V2S_EXAMPLES:
V2S_EXAMPLES = V2T_EXAMPLES[: min(4, len(V2T_EXAMPLES))]
I2S_EXAMPLES = _load_media_examples("i2s", {".png", ".jpg", ".jpeg", ".webp"})
LOGO_DATA_URI = _load_logo_data()
MMU_IMAGE = DEMO_ROOT / "mmu" / "1.jpg"
# MMU_IMAGE_ALT = DEMO_ROOT / "mmu" / "SD_IMG_00235_1.png"
if MMU_IMAGE.exists():
MMU_EXAMPLES = [
[
str(MMU_IMAGE),
"Describe the scene in this image in detail.",
]
]
else:
MMU_EXAMPLES = []
if not I2S_EXAMPLES and MMU_EXAMPLES:
I2S_EXAMPLES = [[example[0]] for example in MMU_EXAMPLES]
def _render_response(status: str, body_html: str = "") -> str:
safe_status = html.escape(status or "")
parts = []
if safe_status:
parts.append(f"<p class='omada-response-status'>{safe_status}</p>")
if body_html:
parts.append(body_html)
content = "".join(parts)
return f"<div class='omada-response-container'>{content}</div>"
def _render_text_message(status: str, content: Optional[str]) -> str:
# post-processing
if content:
remove_tokens = ["<|eot_id|>", "<|eot_id>", "<eot_id|>", "<eot_id>", "<eot>", "assistant"]
for token in remove_tokens:
content = content.replace(token, "")
content = (content or "").strip()
if not content:
return _render_response(status)
safe_content = html.escape(content).replace("\n", "<br>")
body = f"<div class='omada-response-block'>{safe_content}</div>"
return _render_response(status, body)
def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str:
"""Render an inline HTML audio player for chat responses."""
if not audio:
return _render_response(status)
sample_rate, data = audio
if data is None:
return _render_response(status)
waveform = np.asarray(data, dtype=np.float32)
if waveform.size == 0:
return _render_response(status)
if waveform.ndim == 1:
waveform = waveform[:, None]
channels = waveform.shape[1]
clipped = np.clip(waveform, -1.0, 1.0)
pcm16 = (clipped * 32767.0).astype(np.int16)
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_writer:
wav_writer.setnchannels(channels)
wav_writer.setsampwidth(2) # 16-bit PCM
wav_writer.setframerate(int(sample_rate))
wav_writer.writeframes(pcm16.tobytes())
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
audio_tag = (
"<div class='omada-audio-block'>"
"<audio controls preload='auto' playsinline>"
f"<source src='data:audio/wav;base64,{encoded}' type='audio/wav' /></audio>"
"</div>"
)
return _render_response(status, audio_tag)
def _render_image_message(status: str, image: Optional[Image.Image]) -> str:
if image is None:
return _render_response(status)
buffer = io.BytesIO()
try:
image.save(buffer, format="PNG")
except Exception:
return _render_response(status)
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
image_html = (
"<div class='omada-response-block'>"
"<img src='data:image/png;base64,"
f"{encoded}"
"' alt='Generated image' style='max-width:100%;border-radius:12px;' />"
"</div>"
)
return _render_response(status, image_html)
def _render_image_text_message(status: str, image: Optional[Image.Image], text: str) -> str:
"""Render combined text + image output for TI2TI."""
blocks = []
text_clean = (text or "").strip()
if text_clean:
safe_text = html.escape(text_clean).replace("\n", "<br>")
blocks.append(f"<div class='omada-response-block'>{safe_text}</div>")
if image is not None:
buffer = io.BytesIO()
try:
image.save(buffer, format="PNG")
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
blocks.append(
"<div class='omada-response-block'>"
"<img src='data:image/png;base64,"
f"{encoded}"
"' alt='Generated image' style='max-width:100%;border-radius:12px;' />"
"</div>"
)
except Exception:
pass
body = "".join(blocks)
return _render_response(status, body if body else None)
def _format_user_message(message: str) -> str:
clean = html.escape(message or "")
return clean.replace("\n", "<br>")
# Ensure project modules (models, training, inference.common, …) are importable when the
# script is launched directly via `python MMaDA/inference/gradio_multimodal_demo_inst.py`.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
import cv2
import gradio as gr
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from inference.common import (
build_uni_prompting,
get_vq_model_audio,
get_vq_model_image,
load_omada_from_checkpoint,
load_train_config,
)
from models import get_mask_schedule
from models.modeling_omada import add_gumbel_noise, get_num_transfer_tokens
from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION, V2S_INSTRUCTION
from training.utils import image_transform
def _cfg_get(cfg, key, default=None):
"""Lightweight helper to read DictConfig/dict/objects safely."""
if cfg is None:
return default
if isinstance(cfg, dict):
return cfg.get(key, default)
try:
value = getattr(cfg, key)
except AttributeError:
value = None
if value is not None:
return value
getter = getattr(cfg, "get", None)
if callable(getter):
try:
value = getter(key)
except TypeError:
try:
value = getter(key, default)
except Exception:
value = None
if value is not None:
return value
return default
def _resolve_noise_schedule(train_cfg) -> callable:
"""Return the diffusion noise schedule used for T2S sampling."""
schedule_cfg = getattr(train_cfg, "mask_schedule", None)
if schedule_cfg and hasattr(schedule_cfg, "schedule"):
schedule_name = schedule_cfg.schedule
schedule_kwargs = schedule_cfg.get("params", {})
return get_mask_schedule(schedule_name, **schedule_kwargs)
schedule_name = train_cfg.training.get("mask_schedule", "cosine")
return get_mask_schedule(schedule_name)
def _resolve_mask_schedule(cfg) -> callable:
"""Mirror training-time mask schedule resolution for image tasks."""
schedule_cfg = getattr(cfg, "mask_schedule", None)
if isinstance(schedule_cfg, DictConfig):
schedule_name = getattr(schedule_cfg, "schedule", None)
params_cfg = getattr(schedule_cfg, "params", None)
elif isinstance(schedule_cfg, dict):
schedule_name = schedule_cfg.get("schedule")
params_cfg = schedule_cfg.get("params")
else:
schedule_name = None
params_cfg = None
training_cfg = getattr(cfg, "training", None)
if schedule_name is None:
if training_cfg is None:
schedule_name = "cosine"
else:
schedule_name = getattr(training_cfg, "mask_schedule", None)
if schedule_name is None and isinstance(training_cfg, dict):
schedule_name = training_cfg.get("mask_schedule")
if schedule_name is None:
schedule_name = "cosine"
params = {}
if params_cfg is not None:
if isinstance(params_cfg, DictConfig):
params = OmegaConf.to_container(params_cfg, resolve=True) or {}
elif isinstance(params_cfg, dict):
params = dict(params_cfg)
else:
params = params_cfg
if not isinstance(params, dict):
params = {}
return get_mask_schedule(schedule_name, **params)
class OmadaDemo:
"""Lightweight container that loads all inference assets once."""
def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None):
ckpt_path = Path(checkpoint)
if ckpt_path.name != "unwrapped_model":
raise ValueError(
"`--checkpoint` must point to an `unwrapped_model` directory. "
f"Received: {checkpoint}"
)
self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
self.train_cfg = load_train_config(train_config)
self.uni_prompting, _ = build_uni_prompting(self.train_cfg)
# Core models
self.model = load_omada_from_checkpoint(str(ckpt_path), self.device)
self.vq_audio = get_vq_model_audio(self.train_cfg, self.device)
self.vq_image = get_vq_model_image(self.train_cfg, self.device)
self.model.eval()
self.vq_audio.eval()
self.vq_image.eval()
# Cached constants
self.mask_token_id = int(self.model.config.mask_token_id)
self.noise_schedule = _resolve_noise_schedule(self.train_cfg)
self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050))
dataset_cfg = getattr(self.train_cfg, "dataset", None)
dataset_params = getattr(dataset_cfg, "params", None) if dataset_cfg else None
preprocess_cfg = getattr(dataset_cfg, "preprocessing", None) if dataset_cfg else None
self.image_resolution = int(_cfg_get(dataset_params, "resolution", 336))
self.t2i_resolution = int(_cfg_get(dataset_params, "t2i_resolution", self.image_resolution))
mmu_cfg = _cfg_get(dataset_params, "mmu_interleaved")
self.mmu_image_resolution = int(_cfg_get(mmu_cfg, "resolution", self.image_resolution))
self.video_resolution = 224
patch_size = int(_cfg_get(dataset_params, "vq_patch_size", 16))
self.image_seq_len = int(
_cfg_get(
dataset_params,
"i2i_seq_len",
max(1, (self.image_resolution // patch_size) ** 2),
)
)
self.t2i_seq_len = int(
_cfg_get(
dataset_params,
"t2i_seq_len",
max(1, (self.t2i_resolution // patch_size) ** 2),
)
)
speech_cfg = _cfg_get(dataset_params, "video_speech_dataset")
self.num_frames_v2t = max(1, int(_cfg_get(speech_cfg, "num_frames_v2t", 5)))
self.num_frames_v2s = max(1, int(_cfg_get(speech_cfg, "num_frames_v2s", 4)))
self.genders = ['female', 'male']
self.emotions = ['angry', 'happy', 'neutral', 'sad']
self.speeds = ['normal', 'fast', 'slow']
self.pitches = ['normal', 'high', 'low']
# Pre-computed offsets reused across calls
self.text_vocab_size = len(self.uni_prompting.text_tokenizer)
# The current checkpoints assume a fixed 8k vision codebook and 4k speech codebook.
self.codebook_size = 8192
self.speech_codebook = self.codebook_size
self.audio_codebook_size = 4096
self.max_audio_len_short = int(
getattr(
self.uni_prompting,
"max_audio_len_short",
getattr(self.train_cfg.dataset.preprocessing, "max_aud_length_short", 256),
)
)
self.max_text_len = int(getattr(self.train_cfg.dataset.preprocessing, "max_seq_length", 1024))
model_seq_len = getattr(self.model.config, "num_vq_tokens", None)
if model_seq_len is None:
model_seq_len = getattr(getattr(self.train_cfg.model, "omada", None), "num_vq_tokens", None)
if model_seq_len is None:
model_seq_len = getattr(getattr(self.train_cfg.model, "vq_model_image", None), "num_vq_tokens", None)
if model_seq_len is not None:
self.image_seq_len = min(self.image_seq_len, int(model_seq_len))
self.t2i_seq_len = min(self.t2i_seq_len, int(model_seq_len))
        print(f"[OmadaDemo] image_seq_len={self.image_seq_len}, t2i_seq_len={self.t2i_seq_len}")
self.image_noise_schedule = _resolve_noise_schedule(self.train_cfg)
self.mask_schedule = _resolve_mask_schedule(self.train_cfg)
training_cfg = getattr(self.train_cfg, "training", None)
self.noise_type = _cfg_get(training_cfg, "noise_type", "mask")
self.predict_all_tokens = bool(_cfg_get(training_cfg, "predict_all_tokens", False))
self.t2i_default_timesteps = int(_cfg_get(training_cfg, "generation_timesteps", 20))
self.i2i_default_timesteps = int(_cfg_get(training_cfg, "i2i_eval_timesteps", 24))
self.audio_condition_default = "gender-female_emotion-neutral_speed-normal_pitch-normal"
style_map = getattr(getattr(self.vq_audio, "config", None), "u2s_style2idx", None)
if isinstance(style_map, dict):
self._valid_conditions = set(style_map.keys())
if self._valid_conditions and self.audio_condition_default not in self._valid_conditions:
# Ensure the default condition is valid for the tokenizer.
self.audio_condition_default = next(iter(self._valid_conditions))
else:
self._valid_conditions = set()
self._temp_video_files = []
# ------------------------------------------------------------------
# Text-to-Speech
# ------------------------------------------------------------------
def run_t2s(
self,
text: str,
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
cfg_scale: float,
gender_choice: str,
emotion_choice: str,
speed_choice: str,
pitch_choice: str,
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
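        """Synthesize speech for `text`.

        Returns a `(sample_rate, waveform)` tuple suitable for `gr.Audio` plus a
        status message; voice-style choices set to 'random' are resolved to a
        random option from the supported lists.
        """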
if text is None or not text.strip():
return None, "Please provide text to synthesize."
speech_len, steps, block_length = self._prepare_block_schedule(
max_new_tokens,
steps,
block_length,
)
gender = self._resolve_choice(gender_choice, self.genders)
emotion = self._resolve_choice(emotion_choice, self.emotions)
speed = self._resolve_choice(speed_choice, self.speeds)
pitch = self._resolve_choice(pitch_choice, self.pitches)
text = text.strip().upper()
prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{random.choice(T2S_INSTRUCTION)}\n{text}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
audio_tokens = torch.full(
(1, speech_len),
fill_value=self.mask_token_id,
dtype=torch.long,
device=self.device,
)
input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen")
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
with torch.no_grad():
outputs = self.model.t2s_generate_mmu_like(
input_ids=input_ids,
max_new_tokens=int(speech_len),
steps=int(steps),
block_length=int(block_length),
temperature=float(temperature),
cfg_scale=float(cfg_scale),
mask_token_id=self.mask_token_id,
attention_mask=attention_mask,
uni_prompting=self.uni_prompting,
codebook_size=self.codebook_size,
)
if not outputs:
return None, "Generation produced no speech tokens."
rel = outputs[0]
if isinstance(rel, torch.Tensor):
rel_ids = rel.detach().cpu().tolist()
else:
rel_ids = list(rel)
if not rel_ids:
return None, "Generation produced no speech tokens."
speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids)
condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}"
wav = self.vq_audio.decode(
speech_units,
condition=condition,
output_wav_file=os.path.join("/tmp", "omada_t2s.wav"),
)
audio = (self.sample_rate, wav.astype(np.float32))
status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})."
return audio, status
# ------------------------------------------------------------------
# Speech-to-Speech
# ------------------------------------------------------------------
def run_s2s(
self,
audio_path: Optional[str],
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
cfg_scale: float,
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
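        """Generate a spoken reply to an uploaded speech clip.

        The input audio is tokenized with the speech VQ model, a masked reply of
        up to `max_new_tokens` units is sampled, and the result is decoded back
        to a waveform with the resolved default voice condition.
        """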
if not audio_path:
return None, "Please upload source speech first."
try:
user_tokens = self.vq_audio.encode(audio_path)
except Exception as exc:
return None, f"Failed to encode input audio: {exc}"
if not isinstance(user_tokens, torch.Tensor):
user_tokens = torch.tensor(user_tokens)
if user_tokens.dim() == 1:
user_tokens = user_tokens.unsqueeze(0)
user_tokens = user_tokens.to(self.device, dtype=torch.long)
if user_tokens.numel() == 0:
return None, "Uploaded speech clip produced no tokens."
gen_len = max(1, int(max_new_tokens))
gen_len = min(gen_len, self.max_audio_len_short)
gen_len, steps, block_length = self._prepare_block_schedule(
gen_len,
steps,
block_length,
)
offset = self.text_vocab_size + self.codebook_size
user_shifted = user_tokens + offset
assistant_placeholder = torch.full(
(1, gen_len),
self.mask_token_id,
dtype=torch.long,
device=self.device,
)
input_ids, attention_mask = self.uni_prompting(
([user_shifted], [assistant_placeholder]),
"s2s_gen",
)
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
with torch.no_grad():
outputs = self.model.t2s_generate_mmu_like(
input_ids=input_ids,
max_new_tokens=gen_len,
steps=int(steps),
block_length=int(block_length),
temperature=float(temperature),
cfg_scale=float(cfg_scale),
mask_token_id=self.mask_token_id,
attention_mask=attention_mask,
uni_prompting=self.uni_prompting,
codebook_size=self.codebook_size,
audio_codebook_size=self.audio_codebook_size,
)
if not outputs:
return None, "Generation returned no tokens."
generated = outputs[0]
if isinstance(generated, torch.Tensor):
generated = generated.detach().cpu()
eoa_token_id = int(self.uni_prompting.sptids_dict['<|eoa|>'][0].item())
mask_token_id = int(self.mask_token_id)
token_list = []
for tok in generated.tolist():
tok = int(tok)
if tok < 0:
continue
if tok == eoa_token_id:
break
if tok == mask_token_id:
continue
if tok >= self.audio_codebook_size:
continue
token_list.append(tok)
if not token_list:
return None, "Generated sequence was empty after post-processing."
speech_units = "".join(f"<|speech_{tok}|>" for tok in token_list)
condition = self._resolve_condition(self.audio_condition_default)
fd, temp_path = tempfile.mkstemp(prefix="omada_s2s_", suffix=".wav")
os.close(fd)
try:
wav = self.vq_audio.decode(
speech_units,
condition=condition,
output_wav_file=temp_path,
)
finally:
try:
os.remove(temp_path)
except OSError:
pass
audio = (self.sample_rate, wav.astype(np.float32))
return audio, f"Speech response generated successfully. (voice: {condition})"
# ------------------------------------------------------------------
# Speech-to-Text
# ------------------------------------------------------------------
def run_s2t(
self,
audio_path: Optional[str],
steps: int,
block_length: int,
max_new_tokens: int,
remasking: str,
) -> Tuple[str, str]:
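        """Transcribe the uploaded audio clip and return (transcript, status)."""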
if not audio_path:
return "", "Please upload an audio file first."
tokens = self.vq_audio.encode(audio_path).to(self.device)
offset = self.text_vocab_size + self.speech_codebook
tokens = tokens + offset
spt = self.uni_prompting.sptids_dict
audio_block = torch.cat(
[
spt['<|s2t|>'].to(self.device).unsqueeze(0),
spt['<|soa|>'].to(self.device).unsqueeze(0),
tokens.to(self.device),
spt['<|eoa|>'].to(self.device).unsqueeze(0),
],
dim=1,
)
prompt_text = random.choice(S2T_INSTRUCTION)
chat_prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{prompt_text}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
prompt_tensor = self.uni_prompting.text_tokenizer(
chat_prompt,
return_tensors="pt",
).input_ids.to(self.device)
input_ids = torch.cat([audio_block, prompt_tensor], dim=1)
with torch.no_grad():
output_ids = self.model.mmu_generate(
input_ids,
max_new_tokens=int(max_new_tokens),
steps=int(steps),
block_length=int(block_length),
remasking=str(remasking),
)
decoded = self.uni_prompting.text_tokenizer.batch_decode(
output_ids[:, input_ids.shape[1]:],
skip_special_tokens=True,
)[0]
return decoded.strip(), "Transcription generated successfully."
# ------------------------------------------------------------------
# Video-to-Text
# ------------------------------------------------------------------
def run_v2t(
self,
video_path: Any,
steps: int,
block_length: int,
max_new_tokens: int,
) -> Tuple[str, str]:
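        """Generate a detailed text description of the uploaded/recorded video."""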
resolved_path, converted = self._prepare_video_path(video_path)
if not resolved_path:
return "", "Please upload or record a video file first."
try:
video_tokens = self._extract_video_tokens(resolved_path, num_frame=self.num_frames_v2t)
except Exception as exc:
return "", f"Failed to process video: {exc}"
spt = self.uni_prompting.sptids_dict
# Match training eval: fixed detailed-description question
question = "Please provide a detailed description of the video."
prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{question}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
prompt_ids = self.uni_prompting.text_tokenizer(
prompt,
return_tensors="pt",
).input_ids.to(self.device)
input_ids = torch.cat(
[
spt['<|v2t|>'].to(self.device).unsqueeze(0),
spt['<|soi|>'].to(self.device).unsqueeze(0),
video_tokens,
spt['<|eoi|>'].to(self.device).unsqueeze(0),
spt['<|sot|>'].to(self.device).unsqueeze(0),
prompt_ids,
],
dim=1,
).long()
with torch.no_grad():
output_ids = self.model.mmu_generate(
input_ids,
max_new_tokens=int(max_new_tokens),
steps=int(steps),
block_length=int(block_length),
)
print("[V2T] input shape:", input_ids.shape)
print("[V2T] output shape:", output_ids.shape)
raw_all = self.uni_prompting.text_tokenizer.decode(output_ids[0], skip_special_tokens=False)
print("[V2T] RAW ALL:", repr(raw_all))
decoded = self.uni_prompting.text_tokenizer.batch_decode(
output_ids[:, input_ids.shape[1]:],
skip_special_tokens=True,
)[0]
print("[V2T] DECODED SLICE:", repr(decoded))
return decoded.strip(), "Video caption generated successfully."
# ------------------------------------------------------------------
# Text-to-Image
# ------------------------------------------------------------------
def run_t2i(
self,
prompt: str,
timesteps: int,
temperature: float,
guidance_scale: float,
) -> Tuple[Optional[Image.Image], str]:
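        """Generate an image from the text prompt via iterative mask-token refinement."""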
if not prompt or not prompt.strip():
return None, "Please provide a text prompt."
        # NOTE: T2I uses a fixed 1024-token image canvas rather than self.t2i_seq_len.
        image_seq_len = 1024
image_tokens = torch.full(
(1, image_seq_len),
self.mask_token_id,
dtype=torch.long,
device=self.device,
)
input_ids, attention_mask = self.uni_prompting(([prompt.strip()], image_tokens), "t2i_gen")
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
if guidance_scale > 0:
uncond_ids, uncond_mask = self.uni_prompting(([""], image_tokens.clone()), "t2i_gen")
uncond_ids = uncond_ids.to(self.device)
uncond_mask = uncond_mask.to(self.device)
else:
uncond_ids = None
uncond_mask = None
with torch.no_grad():
gen_tokens = self.model.t2i_generate(
input_ids=input_ids,
uncond_input_ids=uncond_ids,
attention_mask=attention_mask,
uncond_attention_mask=uncond_mask,
guidance_scale=float(guidance_scale),
temperature=float(temperature),
timesteps=int(timesteps),
noise_schedule=self.mask_schedule,
noise_type=self.noise_type,
predict_all_tokens=self.predict_all_tokens,
seq_len=image_seq_len,
mask_token_id=self.mask_token_id,
codebook_size=self.codebook_size,
uni_prompting=self.uni_prompting,
config=self.train_cfg,
)
if gen_tokens is None:
return None, "Image generation failed."
gen_tokens = torch.clamp(gen_tokens, min=0, max=self.codebook_size - 1)
image = self._decode_image_tokens(gen_tokens[0])
return image, "Image generated from text prompt."
# ------------------------------------------------------------------
# Image-to-Image Editing
# ------------------------------------------------------------------
def run_i2i(
self,
instruction: str,
source_image: Optional[Image.Image],
timesteps: int,
temperature: float,
guidance_scale: float,
) -> Tuple[Optional[Image.Image], str]:
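        """Edit the reference image according to the text instruction and return (image, status)."""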
if source_image is None:
return None, "Please upload a reference image."
if not instruction or not instruction.strip():
return None, "Provide editing instructions for the image."
try:
input_tokens = self._prepare_image_tokens(source_image, resolution=self.image_resolution)
except Exception as exc:
return None, f"Failed to encode input image: {exc}"
seq_len = int(input_tokens.shape[-1])
output_placeholder = torch.full(
(1, seq_len),
self.mask_token_id,
dtype=torch.long,
device=self.device,
)
input_ids, attention_mask = self.uni_prompting(
([instruction.strip()], input_tokens, output_placeholder),
"i2i_gen",
)
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
uncond_ids = None
uncond_attn = None
if guidance_scale > 0:
uncond_ids, uncond_attn = self.uni_prompting(
([""], input_tokens.clone(), torch.full_like(output_placeholder, self.mask_token_id)),
"i2i_gen",
)
uncond_ids = uncond_ids.to(self.device)
uncond_attn = uncond_attn.to(self.device)
with torch.no_grad():
gen_tokens = self.model.i2i_generate(
input_ids=input_ids,
uncond_input_ids=uncond_ids,
attention_mask=attention_mask,
uncond_attention_mask=uncond_attn,
temperature=float(temperature),
timesteps=int(timesteps),
guidance_scale=float(guidance_scale),
noise_schedule=self.mask_schedule,
noise_type=self.noise_type,
seq_len=seq_len,
mask_token_id=self.mask_token_id,
codebook_size=self.codebook_size,
uni_prompting=self.uni_prompting,
config=self.train_cfg,
)
if gen_tokens is None:
return None, "Image editing failed."
gen_tokens = torch.clamp(gen_tokens, min=0, max=self.codebook_size - 1)
image = self._decode_image_tokens(gen_tokens[0])
return image, "Edited image generated."
# ------------------------------------------------------------------
# Video-to-Speech
# ------------------------------------------------------------------
def run_v2s(
self,
video_path: Any,
message: Optional[str],
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
cfg_scale: float,
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
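        """Generate a spoken response for a video; `message`, if given, replaces
        the default V2S instruction."""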
resolved_path, converted = self._prepare_video_path(video_path)
if not resolved_path:
return None, "Please upload or record a video first."
try:
video_tokens = self._extract_video_tokens(resolved_path, num_frame=self.num_frames_v2s)
except Exception as exc:
return None, f"Failed to process video: {exc}"
prompt_body = message.strip() if message and message.strip() else random.choice(V2S_INSTRUCTION)
prompt_text = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{prompt_body}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
gen_len = max(1, int(max_new_tokens))
gen_len = min(gen_len, self.max_audio_len_short)
gen_len, steps, block_length = self._prepare_block_schedule(
gen_len,
steps,
block_length,
)
audio_placeholder = torch.full(
(1, gen_len),
self.mask_token_id,
dtype=torch.long,
device=self.device,
)
try:
seq_ids, attn_mask = self.uni_prompting(
(video_tokens, [prompt_text], [audio_placeholder]),
"v2s_gen",
)
except Exception as exc:
return None, f"Failed to build V2S prompt: {exc}"
input_ids = seq_ids.to(self.device)
attn_mask = attn_mask.to(self.device)
with torch.no_grad():
outputs = self.model.t2s_generate_mmu_like(
input_ids=input_ids,
max_new_tokens=gen_len,
steps=int(steps),
block_length=int(block_length),
temperature=float(temperature),
cfg_scale=float(cfg_scale),
mask_token_id=self.mask_token_id,
attention_mask=attn_mask,
uni_prompting=self.uni_prompting,
codebook_size=self.codebook_size,
audio_codebook_size=self.audio_codebook_size,
)
if not outputs:
return None, "Audio generation produced no tokens."
generated = outputs[0]
if isinstance(generated, torch.Tensor):
generated = generated.detach().cpu()
eoa_token_id = int(self.uni_prompting.sptids_dict['<|eoa|>'][0].item())
mask_token_id = int(self.mask_token_id)
token_list = []
for tok in generated.tolist():
tok = int(tok)
if tok < 0:
continue
if tok == eoa_token_id:
break
if tok == mask_token_id:
continue
if tok >= self.audio_codebook_size:
continue
token_list.append(tok)
if not token_list:
return None, "Generated sequence was empty after decoding."
speech_units = "".join(f"<|speech_{tok}|>" for tok in token_list)
fd, temp_path = tempfile.mkstemp(prefix="omada_v2s_", suffix=".wav")
os.close(fd)
condition = self._resolve_condition(self.audio_condition_default)
try:
wav = self.vq_audio.decode(
speech_units,
condition=condition,
output_wav_file=temp_path,
)
except Exception as exc:
return None, f"Failed to decode speech: {exc}"
finally:
try:
os.remove(temp_path)
except OSError:
pass
status = "Speech generated from video."
if converted:
status += " (Webcam recording converted to MP4.)"
status += f" (voice: {condition})"
return (self.sample_rate, wav.astype(np.float32)), status
# ------------------------------------------------------------------
    # Image-to-Speech (implemented as MMU captioning followed by T2S synthesis)
# ------------------------------------------------------------------
def run_i2s(
self,
image: Optional[Image.Image],
message: Optional[str],
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
cfg_scale: float,
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
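        """Describe an image in speech by captioning it with `_mmu_answer` and
        synthesizing the caption via `run_t2s`."""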
if image is None:
return None, "Please upload an image first."
question = (message or "").strip() or "Please describe the image in spoken form."
caption, status = self._mmu_answer([image], question)
if not caption:
return None, status
speech_len, steps, block_length = self._prepare_block_schedule(
max_new_tokens,
steps,
block_length,
)
audio_result, speech_status = self.run_t2s(
caption,
speech_len,
steps,
block_length,
temperature,
cfg_scale,
'random',
'random',
'random',
'random',
)
if audio_result is None:
return None, speech_status
combined_status = f"{status} {speech_status}".strip()
return audio_result, combined_status or "Spoken description generated."
# ------------------------------------------------------------------
# Chat (Text Generation)
# ------------------------------------------------------------------
def run_chat(
self,
message: str,
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
) -> Tuple[str, str]:
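        """Generate a plain-text assistant reply using the block-wise masked-token sampler."""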
content = (message or "").strip()
if not content:
return "", "Type a message to start chatting."
prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{content}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
tokenizer = self.uni_prompting.text_tokenizer
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokens = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
padding=True,
)
input_ids = tokens["input_ids"].to(self.device)
attn_mask = tokens.get("attention_mask")
if attn_mask is not None:
attn_mask = attn_mask.to(self.device)
with torch.no_grad():
output_ids = self._generate_text_tokens(
input_ids,
max_new_tokens=int(max_new_tokens),
steps=int(steps),
block_length=int(block_length),
temperature=float(temperature),
cfg_scale=0.0,
attention_mask=attn_mask,
)
decoded = tokenizer.batch_decode(
output_ids[:, input_ids.shape[1]:],
skip_special_tokens=True,
)[0]
return decoded.strip(), "Assistant reply generated."
# ------------------------------------------------------------------
# General MMU (N Images → Text)
# ------------------------------------------------------------------
def run_mmu(
self,
images: Union[Optional[Image.Image], Sequence[Optional[Image.Image]]],
message: str,
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
) -> Tuple[str, str]:
"""
MMU demo now consumes exactly one image. If callers pass a list (for
backwards compatibility), we keep only the first valid image.
"""
if isinstance(images, Image.Image):
normalized: List[Image.Image] = [images]
elif images is None:
normalized = []
else:
normalized = [img for img in images if img is not None]
if not normalized:
return "", "Please provide an image for MMU reasoning."
primary_image = normalized[0]
reply, status = self._mmu_answer(
[primary_image],
message,
max_new_tokens=max_new_tokens,
steps=steps,
block_length=block_length,
temperature=temperature,
)
return reply, status
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_choice(self, choice: Optional[str], options):
if choice is None or choice == 'random':
return random.choice(options)
return choice
def _sample_condition(self) -> str:
return (
f"gender-{random.choice(self.genders)}"
f"_emotion-{random.choice(self.emotions)}"
f"_speed-{random.choice(self.speeds)}"
f"_pitch-{random.choice(self.pitches)}"
)
def _resolve_condition(self, preferred: Optional[str] = None) -> str:
if preferred and preferred != "random":
if not self._valid_conditions or preferred in self._valid_conditions:
return preferred
if self._valid_conditions:
for _ in range(8):
candidate = self._sample_condition()
if candidate in self._valid_conditions:
return candidate
if self.audio_condition_default in self._valid_conditions:
return self.audio_condition_default
return next(iter(self._valid_conditions))
if preferred and preferred != "random":
return preferred
return self.audio_condition_default
def _format_chat_prompt(self, content: str) -> str:
clean = (content or "").strip()
if not clean:
clean = "Please describe the visual content."
return (
"<|start_header_id|>user<|end_header_id|>\n"
f"{clean}\n"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
def _prepare_block_schedule(
self,
total_tokens: int,
steps: int,
block_length: int,
) -> Tuple[int, int, int]:
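        """Adjust the sampling schedule so that `block_length` evenly divides the
        token budget and `steps` is a positive multiple of the number of blocks.

        Returns the adjusted (total_tokens, steps, block_length) triple.
        """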
total = max(1, int(total_tokens))
blk = max(1, min(int(block_length), total))
if total % blk != 0:
blk = math.gcd(total, blk)
blk = blk if blk > 0 else total
if total % blk != 0:
blk = total
num_blocks = max(1, total // blk)
steps = max(num_blocks, int(steps))
if steps % num_blocks != 0:
steps = num_blocks * math.ceil(steps / num_blocks)
return total, steps, blk
def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]:
"""Normalize Gradio video inputs (upload/webcam) to an MP4 filepath."""
candidate = None
if isinstance(video_input, str):
candidate = video_input
elif isinstance(video_input, dict):
candidate = (
video_input.get("video")
or video_input.get("name")
or video_input.get("path")
)
elif isinstance(video_input, (list, tuple)) and video_input:
candidate = str(video_input[0])
if not candidate:
return None, False
candidate = str(candidate)
if not self._ensure_file_ready(candidate):
return None, False
if candidate.lower().endswith(".mp4"):
return candidate, False
converted = self._convert_to_mp4(candidate)
if converted:
return converted, True
suffix = Path(candidate).suffix or ".webm"
fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix)
os.close(fd)
try:
shutil.copy2(candidate, tmp_path)
self._temp_video_files.append(tmp_path)
return tmp_path, False
except OSError:
try:
os.remove(tmp_path)
except OSError:
pass
return candidate, False
def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool:
"""Ensure the uploaded/recorded file is fully written before processing."""
prev_size = -1
for _ in range(retries):
try:
size = os.path.getsize(path)
except OSError:
size = -1
if size <= 0:
time.sleep(delay)
continue
if size == prev_size:
return True
prev_size = size
time.sleep(delay)
return prev_size > 0
def _convert_to_mp4(self, src_path: str) -> Optional[str]:
"""Convert arbitrary video file to MP4 using OpenCV (drops audio)."""
cap = cv2.VideoCapture(src_path)
if not cap.isOpened():
return None
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
fps = cap.get(cv2.CAP_PROP_FPS)
if width <= 0 or height <= 0:
cap.release()
return None
if not fps or np.isnan(fps) or fps <= 0:
fps = 24.0
fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4")
os.close(fd)
writer = cv2.VideoWriter(
tmp_path,
cv2.VideoWriter_fourcc(*"mp4v"),
float(fps),
(width, height),
)
if not writer.isOpened():
cap.release()
try:
os.remove(tmp_path)
except OSError:
pass
return None
frame_count = 0
try:
while True:
ret, frame = cap.read()
if not ret:
break
writer.write(frame)
frame_count += 1
finally:
cap.release()
writer.release()
if frame_count == 0:
try:
os.remove(tmp_path)
except OSError:
pass
return None
self._temp_video_files.append(tmp_path)
return tmp_path
def _extract_video_tokens(self, video_path: str, num_frame=8) -> torch.Tensor:
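        """Uniformly sample `num_frame` frames, VQ-encode them, and return their
        token ids (shifted by the text vocab size) as a (1, N) tensor."""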
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames <= 0:
cap.release()
raise RuntimeError(f"No readable frames in {video_path}")
indices = np.linspace(0, total_frames - 1, num_frame, dtype=int)
frames = []
for idx in range(total_frames):
ret, frame = cap.read()
if idx in indices and ret:
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil = Image.fromarray(rgb)
frames.append(image_transform(pil, resolution=self.video_resolution))
cap.release()
if len(frames) == 0:
raise RuntimeError("Failed to sample frames for V2T inference.")
video_tensor = torch.stack(frames).to(self.device)
video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size
return video_tokens.long().to(self.device).view(1, -1)
def _prepare_image_tokens(self, image: Image.Image, resolution: Optional[int] = None) -> torch.LongTensor:
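        """VQ-encode a PIL image at the requested resolution and shift the codes
        into the unified vocabulary (text vocab offset applied)."""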
if image is None:
raise ValueError("Image input is required.")
target_resolution = int(resolution or self.image_resolution)
tensor = image_transform(image, resolution=target_resolution)
tensor = tensor.unsqueeze(0).to(self.device)
codes = self.vq_image.get_code(tensor) + self.text_vocab_size
return codes.long().to(self.device)
def _decode_image_tokens(self, tokens: torch.Tensor) -> Image.Image:
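        """Decode a sequence of VQ codes back to a PIL RGB image."""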
codes = tokens.view(1, -1).clamp(min=0, max=self.codebook_size - 1).to(self.device)
with torch.no_grad():
image_tensor = self.vq_image.decode_code(codes)
image_tensor = image_tensor.squeeze(0).cpu()
image_tensor = torch.clamp((image_tensor.float() + 1.0) / 2.0, min=0.0, max=1.0)
array = (image_tensor.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
return Image.fromarray(array)
def _mmu_answer(
self,
images: List[Image.Image],
question: str,
max_new_tokens: Optional[int] = None,
steps: Optional[int] = None,
block_length: Optional[int] = None,
temperature: Optional[float] = None,
) -> Tuple[str, str]:
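        """Answer a text question about one or more images.

        Images are VQ-encoded, combined with the chat-formatted question via
        `mmu_mult_prompt`, padded with mask tokens for the answer budget, and
        decoded with `mmu_generate`. Returns (answer, status).
        """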
if not images:
return "", "Please provide at least one image."
encoded_images: List[torch.Tensor] = []
for image in images:
if image is None:
continue
try:
tokens = self._prepare_image_tokens(
image,
resolution=480
).view(-1)
encoded_images.append(tokens)
except Exception:
continue
if not encoded_images:
return "", "Failed to encode the provided image(s)."
question = (question or "").strip() or "Describe the visual content."
prompt = self._format_chat_prompt(question)
try:
tokenized = self.uni_prompting.text_tokenizer(
[prompt],
add_special_tokens=False,
)["input_ids"][0]
except Exception as exc:
return "", f"Failed to tokenize question: {exc}"
try:
mmu_input_ids, prompt_masks, _ = self.uni_prompting.mmu_mult_prompt(
batch_image_ids_list=[encoded_images],
batch_text_ids=[tokenized],
)
except Exception as exc:
return "", f"Failed to construct MMU prompt: {exc}"
mmu_input_ids = mmu_input_ids.to(self.device)
prompt_masks = prompt_masks.to(self.device)
answer_tokens = int((prompt_masks == 0).sum(dim=1).max().item())
default_budget = max(1, answer_tokens) if answer_tokens > 0 else min(self.max_text_len, 256)
gen_tokens = int(max_new_tokens or default_budget)
requested_steps = steps if steps is not None else gen_tokens
requested_block = block_length if block_length is not None else max(1, gen_tokens // 2)
gen_tokens, steps, block_length = self._prepare_block_schedule(
gen_tokens,
requested_steps,
requested_block,
)
temperature = float(temperature if temperature is not None else 0.7)
if gen_tokens > 0:
mask_block = torch.full(
(mmu_input_ids.size(0), gen_tokens),
self.mask_token_id,
dtype=torch.long,
device=self.device,
)
mmu_input_ids = torch.cat([mmu_input_ids, mask_block], dim=1)
with torch.no_grad():
output_ids = self.model.mmu_generate(
mmu_input_ids,
max_new_tokens=int(gen_tokens),
steps=int(steps),
block_length=int(block_length),
temperature=temperature,
remasking="low_confidence",
mask_id=self.mask_token_id,
)
decoded = self.uni_prompting.text_tokenizer.batch_decode(
output_ids[:, mmu_input_ids.shape[1]:],
skip_special_tokens=True,
)[0].strip()
if not decoded:
return "", "MMU response was empty."
return decoded, "Image understanding succeeded."
def _generate_text_tokens(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
cfg_scale: float = 0.0,
attention_mask: Optional[torch.Tensor] = None,
remasking: str = "low_confidence",
) -> torch.Tensor:
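        """Block-wise masked-token text sampler.

        The continuation after the prompt starts fully masked; each block is
        refined over `steps // num_blocks` passes, committing the highest-
        confidence predictions (or random ones, per `remasking`) at each pass.
        Returns the full prompt + generated token ids.
        """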
prompt_ids = prompt_ids.to(self.device)
batch_size, prompt_len = prompt_ids.shape
gen_len, steps, block_length = self._prepare_block_schedule(
max_new_tokens,
steps,
block_length,
)
work = torch.full(
(batch_size, prompt_len + gen_len),
self.mask_token_id,
dtype=torch.long,
device=self.device,
)
work[:, :prompt_len] = prompt_ids
prompt_index = work != self.mask_token_id
attention_bias = None
if attention_mask is not None and (attention_mask == 0).any():
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
num_blocks = max(1, gen_len // block_length)
inner_steps = max(1, steps // num_blocks)
for block_idx in range(num_blocks):
block_slice = slice(prompt_len + block_idx * block_length, prompt_len + (block_idx + 1) * block_length)
block_mask_index = work[:, block_slice] == self.mask_token_id
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, inner_steps)
for inner_step in range(inner_steps):
mask_index = work == self.mask_token_id
if cfg_scale > 0.0:
unconditional = work.clone()
unconditional[prompt_index] = self.mask_token_id
model_input = torch.cat([work, unconditional], dim=0)
logits = self.model(model_input).logits
cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
else:
logits = self.model(work, attention_bias=attention_bias).logits
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(logits_with_noise, dim=-1)
if remasking == "low_confidence":
probs = F.softmax(logits.to(torch.float64), dim=-1)
x0_p = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
elif remasking == "random":
x0_p = torch.rand_like(x0, dtype=torch.float64)
else:
raise NotImplementedError(remasking)
x0_p[:, prompt_len + (block_idx + 1) * block_length :] = -np.inf
x0 = torch.where(mask_index, x0, work)
confidence = torch.where(mask_index, x0_p, torch.full_like(x0_p, float('-inf')))
transfer_index = torch.zeros_like(work, dtype=torch.bool)
for b in range(batch_size):
k = int(num_transfer_tokens[b, inner_step].item())
if k <= 0:
continue
values, select_idx = torch.topk(confidence[b], k=k)
transfer_index[b, select_idx] = values != float('-inf')
work[transfer_index] = x0[transfer_index]
return work
def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]):
theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
with gr.Blocks(title="AIDAS Lab @ SNU", css=CUSTOM_CSS, theme=theme, js=FORCE_LIGHT_MODE_JS) as demo:
with gr.Column(elem_classes=["omada-header"]):
if LOGO_DATA_URI:
gr.HTML(
f"<div class='omada-logo-wrapper'><img src=\"{LOGO_DATA_URI}\" alt=\"AIDAS Lab\" class=\"omada-logo-img\" /></div>"
)
elif LOGO_PATH.exists():
gr.Image(
value=str(LOGO_PATH),
show_label=False,
height=140,
interactive=False,
elem_classes=["omada-logo"],
)
gr.Markdown(
"## Omni-modal Diffusion Foundation Model\n### Pretrained Demo",
elem_classes=["omada-title"],
)
gr.HTML(
"<p class='omada-tagline'>"
"<span class='tagline-speech'>Create speech</span>, "
"<span class='tagline-audio'>transcribe audio</span>, "
"<span class='tagline-video'>describe video</span>, "
"<span class='tagline-text'>chat with text</span>, and "
"<span class='tagline-image'>generate or edit images</span> — all from a single model. "
"Use the advanced sections when you want tighter control."
"</p>")
group_to_modes = {
"Speech": ["Text → Speech", "Speech → Speech", "Speech → Text"],
"Video": ["Video → Text", "Video → Speech"],
"Image": ["Text → Image", "Image Editing"],
"Multi-Modal": ["MMU (Image → Text)"],
"Text": ["Text"],
}
default_group = "Speech"
default_mode = group_to_modes[default_group][0]
placeholder_map = {
"Text → Speech": "Type the speech you want to generate...",
"Speech → Speech": "Optionally add context for the reply...",
"Speech → Text": "Upload audio on the right, then leave notes here if needed.",
"Video → Speech": "Upload video on the right. Optionally provide guidance here.",
"Video → Text": "Upload video on the right, then leave notes here if needed.",
"Text": "Ask anything and the assistant will reply with text.",
"MMU (Image → Text)": "Ask a question about the uploaded image.",
"Text → Image": "Describe the image you want to generate...",
"Image Editing": "Describe how you want to edit the uploaded image...",
}
with gr.Row(elem_classes=["omada-layout"], equal_height=False):
with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]):
chatbox = gr.Chatbot(label="Session", height=760, sanitize_html=False)
chat_input = gr.Textbox(
label="Message",
placeholder=placeholder_map[default_mode],
lines=3,
)
with gr.Row():
send_button = gr.Button("Send", variant="primary")
clear_button = gr.Button("Clear", variant="secondary")
with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]):
task_group_selector = gr.Radio(
list(group_to_modes.keys()),
value=default_group,
label="Task Group",
elem_classes=["omada-task-selector", "omada-task-group-buttons"],
)
submode_selector = gr.Radio(
group_to_modes[default_group],
value=default_mode,
label="Task Mode",
elem_classes=["omada-task-selector", "omada-task-mode-buttons"],
)
with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Text-to-Speech Controls")
with gr.Group(elem_classes=["omada-advanced"]):
gr.Markdown("**Generation**")
with gr.Row():
t2s_max_tokens = gr.Slider(2, 512, value=384, label="Speech token length", step=2)
t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2)
with gr.Row():
t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1)
t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
with gr.Group(elem_classes=["omada-advanced"]):
gr.Markdown("**Voice styling**")
with gr.Row():
t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender")
t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion")
with gr.Row():
t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed")
t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch")
if T2S_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample prompts**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=T2S_EXAMPLES,
inputs=[chat_input],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2s_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Speech-to-Speech Controls")
s2s_audio = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
s2s_max_tokens = gr.Slider(2, 512, value=512, label="Reply token length", step=2)
with gr.Row():
s2s_steps = gr.Slider(2, 512, value=512, label="Refinement steps", step=2)
s2s_block = gr.Slider(2, 512, value=512, label="Block length", step=2)
with gr.Row():
s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, label="Sampling temperature", step=0.05)
s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, label="CFG scale", step=0.1)
if S2S_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample S2S clips**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=S2S_EXAMPLES,
inputs=[s2s_audio],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Speech-to-Text Controls")
s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
with gr.Row():
s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2)
s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2)
s2t_remasking = gr.Dropdown(
choices=["low_confidence", "random"],
value="low_confidence",
label="Remasking strategy",
)
if S2T_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample clips**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=S2T_EXAMPLES,
inputs=[s2t_audio],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Video-to-Text Controls")
v2t_video = gr.Video(
label="Upload or record video",
format=None,
height=256,
sources=["upload", "webcam"],
)
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
with gr.Row():
v2t_steps = gr.Slider(2, 512, value=256, label="Denoising steps", step=2)
v2t_block = gr.Slider(2, 512, value=256, label="Block length", step=2)
v2t_max_tokens = gr.Slider(2, 512, value=256, label="Max tokens", step=2)
if V2T_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample videos**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=V2T_EXAMPLES,
inputs=[v2t_video],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2s_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Video-to-Speech Controls")
v2s_video = gr.Video(
label="Upload or record video",
format=None,
height=256,
sources=["upload", "webcam"],
)
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
v2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2)
with gr.Row():
v2s_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2)
v2s_block = gr.Slider(2, 512, value=256, label="Block length", step=2)
with gr.Row():
v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, label="CFG scale", step=0.1)
if V2S_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample videos**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=V2S_EXAMPLES,
inputs=[v2s_video],
examples_per_page=4,
)
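                # NOTE: this Image-to-Speech panel is currently unreachable; _panel_updates
                # hard-codes show_i2s = False and no task group exposes an Image → Speech mode.
                # It is kept in place for a future mode.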
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as i2s_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Image-to-Speech Controls")
i2s_image = gr.Image(type="pil", label="Upload image", sources=["upload"])
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
i2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2)
with gr.Row():
i2s_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2)
i2s_block = gr.Slider(2, 512, value=256, label="Block length", step=2)
with gr.Row():
i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, label="CFG scale", step=0.1)
if I2S_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample images**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=I2S_EXAMPLES,
inputs=[i2s_image],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as image_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Image Tasks")
image_mode_selector = gr.Radio(
["Generation", "Editing"],
value="Generation",
label="Sub-mode",
)
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings:
t2i_timesteps = gr.Slider(4, 128, value=max(4, min(128, app.t2i_default_timesteps)), label="Timesteps", step=2)
t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings:
i2i_image = gr.Image(type="pil", label="Reference image", sources=["upload"])
i2i_timesteps = gr.Slider(4, 128, value=max(4, min(128, app.i2i_default_timesteps)), label="Timesteps", step=2)
i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
if T2I_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample prompts**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=T2I_EXAMPLES,
inputs=[chat_input],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as chat_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Chat Controls")
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
chat_max_tokens = gr.Slider(2, 512, value=64, label="Reply max tokens", step=2)
with gr.Row():
chat_steps = gr.Slider(2, 512, value=64, label="Refinement steps", step=2)
chat_block = gr.Slider(2, 512, value=64, label="Block length", step=2)
chat_temperature = gr.Slider(0.0, 2.0, value=0.8, label="Sampling temperature", step=0.05)
if CHAT_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample prompts**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=CHAT_EXAMPLES,
inputs=[chat_input],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as mmu_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Image Reasoning")
mmu_image = gr.Image(type="pil", label="Image", sources=["upload"])
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
mmu_max_tokens = gr.Slider(2, 512, value=256, label="Answer max tokens", step=2)
with gr.Row():
mmu_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2)
mmu_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, label="Sampling temperature", step=0.05)
if MMU_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample prompts**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=MMU_EXAMPLES,
inputs=[mmu_image, chat_input],
examples_per_page=1,
)
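
        # Returns one gr.update per wired UI element; keep the order in sync with the
        # `outputs` lists of the task_group_selector / submode_selector change handlers below.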
def _panel_updates(group: str, mode: str):
show_t2s = group == "Speech" and mode == "Text → Speech"
show_s2s = group == "Speech" and mode == "Speech → Speech"
show_v2s = group == "Video" and mode == "Video → Speech"
show_i2s = False
show_s2t = group == "Speech" and mode == "Speech → Text"
show_v2t = group == "Video" and mode == "Video → Text"
show_chat = group == "Text" and mode == "Text"
show_mmu = group == "Multi-Modal" and mode == "MMU (Image → Text)"
show_image = group == "Image" and mode in ("Text → Image", "Image Editing")
placeholder = placeholder_map.get(mode, chat_input.placeholder)
image_mode_value = "Generation" if mode == "Text → Image" else "Editing"
t2i_visible = show_image and mode == "Text → Image"
i2i_visible = show_image and mode == "Image Editing"
image_mode_update = gr.update(value=image_mode_value) if show_image else gr.update()
return (
gr.update(placeholder=placeholder),
gr.update(visible=show_t2s),
gr.update(visible=show_s2s),
gr.update(visible=show_v2s),
gr.update(visible=show_i2s),
gr.update(visible=show_s2t),
gr.update(visible=show_v2t),
gr.update(visible=show_chat),
gr.update(visible=show_mmu),
gr.update(visible=show_image),
image_mode_update,
gr.update(visible=t2i_visible),
gr.update(visible=i2i_visible),
)
def _on_group_change(group: str):
default_mode_local = group_to_modes[group][0]
submode_update = gr.update(choices=group_to_modes[group], value=default_mode_local)
panel_updates = _panel_updates(group, default_mode_local)
return (submode_update, *panel_updates)
def _on_submode_change(mode: str, group: str):
return _panel_updates(group, mode)
task_group_selector.change(
_on_group_change,
inputs=[task_group_selector],
outputs=[
submode_selector,
chat_input,
t2s_panel,
s2s_panel,
v2s_panel,
i2s_panel,
s2t_panel,
v2t_panel,
chat_panel,
mmu_panel,
image_panel,
image_mode_selector,
t2i_settings,
i2i_settings,
],
)
submode_selector.change(
_on_submode_change,
inputs=[submode_selector, task_group_selector],
outputs=[
chat_input,
t2s_panel,
s2s_panel,
v2s_panel,
i2s_panel,
s2t_panel,
v2t_panel,
chat_panel,
mmu_panel,
image_panel,
image_mode_selector,
t2i_settings,
i2i_settings,
],
)
def _toggle_image_task(task: str):
return (
gr.update(visible=task == "Generation"),
gr.update(visible=task == "Editing"),
)
image_mode_selector.change(
_toggle_image_task,
inputs=[image_mode_selector],
outputs=[t2i_settings, i2i_settings],
)
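
        # The parameter order below must mirror `submit_inputs` (defined after this handler)
        # one-to-one, since Gradio passes component values positionally.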
def _chat_handler(
history,
message,
group,
mode,
s2t_audio_path,
v2t_video_path,
t2s_max_tokens,
t2s_steps,
t2s_block,
t2s_temperature,
t2s_cfg,
t2s_gender,
t2s_emotion,
t2s_speed,
t2s_pitch,
s2t_steps,
s2t_block,
s2t_max_tokens,
s2t_remasking,
v2t_steps,
v2t_block,
v2t_max_tokens,
s2s_audio_path,
s2s_max_tokens,
s2s_steps,
s2s_block,
s2s_temperature,
s2s_cfg,
i2s_image,
i2s_max_tokens,
i2s_steps,
i2s_block,
i2s_temperature,
i2s_cfg,
image_mode,
t2i_timesteps,
t2i_temperature,
t2i_guidance,
i2i_image,
i2i_timesteps,
i2i_temperature,
i2i_guidance,
v2s_video_path,
v2s_max_tokens,
v2s_steps,
v2s_block,
v2s_temperature,
v2s_cfg,
chat_max_tokens,
chat_steps,
chat_block,
chat_temperature,
mmu_image,
mmu_max_tokens,
mmu_steps,
mmu_block,
mmu_temperature,
):
history = history or []
message = (message or "").strip()
response = ""
if mode == "Text → Speech":
if not message:
status = "Please type some text for speech synthesis."
response = _render_text_message(status, "")
else:
audio_result, status = app.run_t2s(
message,
t2s_max_tokens,
t2s_steps,
t2s_block,
t2s_temperature,
t2s_cfg,
t2s_gender,
t2s_emotion,
t2s_speed,
t2s_pitch,
)
response = _render_audio_message(status, audio_result)
display_user_raw = message or "[Speech request]"
elif mode == "Speech → Speech":
audio_result, status = app.run_s2s(
s2s_audio_path,
s2s_max_tokens,
s2s_steps,
s2s_block,
s2s_temperature,
s2s_cfg,
)
response = _render_audio_message(status, audio_result)
display_user_raw = message or "[Speech-to-speech request]"
elif mode == "Video → Speech":
audio_result, status = app.run_v2s(
v2s_video_path,
message,
v2s_max_tokens,
v2s_steps,
v2s_block,
v2s_temperature,
v2s_cfg,
)
response = _render_audio_message(status, audio_result)
display_user_raw = message or "[Video-to-speech request]"
elif mode == "Speech → Text":
if not s2t_audio_path:
status = "Please upload or record an audio clip first."
response = _render_text_message(status, "")
else:
transcript, status = app.run_s2t(
s2t_audio_path,
s2t_steps,
s2t_block,
s2t_max_tokens,
s2t_remasking,
)
response = _render_text_message(status, transcript)
display_user_raw = message or "[Audio transcription request]"
elif mode == "Video → Text":
if not v2t_video_path:
status = "Please upload or record a video first."
response = _render_text_message(status, "")
else:
caption, status = app.run_v2t(
v2t_video_path,
v2t_steps,
v2t_block,
v2t_max_tokens,
)
response = _render_text_message(status, caption)
display_user_raw = message or "[Video caption request]"
elif mode == "Text":
reply, status = app.run_chat(
message,
chat_max_tokens,
chat_steps,
chat_block,
chat_temperature,
)
response = _render_text_message(status, reply)
display_user_raw = message or "[Text request]"
elif mode == "MMU (Image → Text)":
reply, status = app.run_mmu(
[mmu_image] if mmu_image is not None else [],
message,
mmu_max_tokens,
mmu_steps,
mmu_block,
mmu_temperature,
)
response = _render_text_message(status, reply)
                display_user_raw = message or "[Image question]"
elif mode == "Text → Image":
if not message:
status = "Please provide a prompt for image generation."
response = _render_text_message(status, "")
else:
image_result, status = app.run_t2i(
message,
t2i_timesteps,
t2i_temperature,
t2i_guidance,
)
response = _render_image_message(status, image_result)
display_user_raw = message or "[Image generation request]"
elif mode == "Image Editing":
image_result, status = app.run_i2i(
message,
i2i_image,
i2i_timesteps,
i2i_temperature,
i2i_guidance,
)
response = _render_image_message(status, image_result)
display_user_raw = message or "[Image editing request]"
else:
status = f"Mode '{mode}' is not supported."
response = _render_text_message(status, "")
display_user_raw = message or "[Unsupported mode]"
display_user = _format_user_message(display_user_raw)
history = history + [(display_user, response)]
return history, ""
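
        # Component order here must stay in lockstep with _chat_handler's signature above.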
submit_inputs = [
chatbox,
chat_input,
task_group_selector,
submode_selector,
s2t_audio,
v2t_video,
t2s_max_tokens,
t2s_steps,
t2s_block,
t2s_temperature,
t2s_cfg,
t2s_gender,
t2s_emotion,
t2s_speed,
t2s_pitch,
s2t_steps,
s2t_block,
s2t_max_tokens,
s2t_remasking,
v2t_steps,
v2t_block,
v2t_max_tokens,
s2s_audio,
s2s_max_tokens,
s2s_steps,
s2s_block,
s2s_temperature,
s2s_cfg,
i2s_image,
i2s_max_tokens,
i2s_steps,
i2s_block,
i2s_temperature,
i2s_cfg,
image_mode_selector,
t2i_timesteps,
t2i_temperature,
t2i_guidance,
i2i_image,
i2i_timesteps,
i2i_temperature,
i2i_guidance,
v2s_video,
v2s_max_tokens,
v2s_steps,
v2s_block,
v2s_temperature,
v2s_cfg,
chat_max_tokens,
chat_steps,
chat_block,
chat_temperature,
mmu_image,
mmu_max_tokens,
mmu_steps,
mmu_block,
mmu_temperature,
]
submit_outputs = [chatbox, chat_input]
chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
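
        # Resets the conversation plus every media input; the returned tuple must match
        # the `outputs` list of clear_button.click below.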
def _clear_session():
return (
[],
"",
gr.update(value=None),
gr.update(value=None),
gr.update(value=None),
gr.update(value=None),
gr.update(value=None),
gr.update(value=None),
gr.update(value=None),
)
clear_button.click(
_clear_session,
inputs=None,
outputs=[
chatbox,
chat_input,
s2t_audio,
v2t_video,
s2s_audio,
i2s_image,
v2s_video,
i2i_image,
mmu_image,
],
)
demo.launch(
share=share,
server_name=server_name,
server_port=server_port,
)


def parse_args():
    parser = argparse.ArgumentParser(description="OMada Gradio demo for speech, video, image, and text tasks")
parser.add_argument("--train-config", required=True, help="Path to the training YAML used to build tokenizer + VQ modules")
parser.add_argument("--checkpoint", required=True, help="Path to an `unwrapped_model` directory")
parser.add_argument("--device", default=None, help="Override device (e.g. cuda:0). Defaults to CUDA if available")
parser.add_argument("--share", action="store_true", help="Enable public Gradio share link")
parser.add_argument("--server-name", default="0.0.0.0", help="Host address for Blocks.launch")
parser.add_argument("--server-port", type=int, default=None, help="Port for Blocks.launch")
return parser.parse_args()


def main():
args = parse_args()
app = OmadaDemo(args.train_config, args.checkpoint, args.device)
build_demo(app, args.share, args.server_name, args.server_port)


if __name__ == "__main__":
main()