#!/usr/bin/env python3
"""
Gradio demo for OMada Stage 1.3 checkpoints covering:

* Text-to-Speech (T2S)
* Speech-to-Text (S2T)
* Video-to-Text (V2T)

The implementation wraps the existing CLI inference helpers so that a single
checkpoint directory (…/checkpoint-XXXX/unwrapped_model) can be previewed
interactively.

Usage:

    python MMaDA/inference/gradio_multimodal_demo2.py \
        --train-config MMaDA/configs/omada_pretraining_stage1-3.yaml \
        --checkpoint ../ckpt/checkpoint-315000/unwrapped_model/ \
        --share

If you need remote access, pass `--share`, which simply forwards the flag to
`gradio.Blocks.launch`.
"""
import argparse
import base64
import html
import io
import os
import random
import sys
import tempfile
import wave
from pathlib import Path
import shutil
import time
from typing import Any, Callable, Optional, Tuple
import numpy as np
CUSTOM_CSS = """
:root {
--omada-primary: #1e3a8a;
--omada-accent: #1d4ed8;
--omada-surface: #f3f4f6;
--omada-surface-alt: #ffffff;
--omada-border: #d0d7e5;
--omada-text-primary: #111827;
--omada-text-muted: #374151;
color-scheme: light;
}
html, body, body.dark, html.dark {
background: var(--omada-surface) !important;
color: var(--omada-text-primary) !important;
}
.gradio-container {
background: var(--omada-surface);
color: var(--omada-text-primary);
}
.omada-page-heading {
margin-bottom: 0;
color: var(--omada-text-primary);
}
.omada-tab-intro p {
font-size: 0.95rem;
color: var(--omada-text-primary);
margin-top: 0;
opacity: 0.9;
}
.omada-card {
background: var(--omada-surface-alt);
border-radius: 16px !important;
padding: 18px !important;
box-shadow: none;
border: 1px solid var(--omada-border);
color: var(--omada-text-primary);
}
.omada-card .gradio-slider .wrap-inner {
gap: 6px;
}
.omada-card .gradio-slider input[type=range]::-webkit-slider-thumb {
background: var(--omada-primary);
}
.gradio-slider input[type=range]::-webkit-slider-runnable-track {
background: rgba(14, 33, 80, 0.2);
}
.omada-section-title p {
text-transform: uppercase;
font-size: 0.78rem;
letter-spacing: 0.14em;
color: rgba(30, 58, 138, 0.85);
margin: 0 0 12px 0;
}
.omada-output .gradio-audio, .omada-output .gradio-textbox {
margin-top: 12px;
}
.gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio {
color: var(--omada-text-primary) !important;
}
.gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label {
color: var(--omada-text-primary);
}
.gradio-dropdown .single-select, .gradio-textbox textarea {
background: #ffffff !important;
border: 1px solid var(--omada-border) !important;
color: var(--omada-text-primary) !important;
}
.gradio-dropdown .single-select select, .gradio-textbox textarea {
color: var(--omada-text-primary) !important;
}
.gradio-textbox textarea::placeholder {
color: rgba(148, 163, 184, 0.65);
}
.gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider {
background: #ffffff !important;
border: 1px solid var(--omada-border) !important;
border-radius: 12px !important;
}
.full-width-button button {
width: 100%;
background: var(--omada-primary) !important;
color: white !important;
border: none !important;
font-weight: 600;
transition: transform 0.2s ease, box-shadow 0.2s ease;
box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65);
}
.full-width-button button:hover {
transform: translateY(-1px);
box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75);
}
.omada-advanced .gr-accordion-header {
font-size: 0.85rem;
letter-spacing: 0.05em;
color: var(--omada-text-muted);
}
.omada-advanced .gr-accordion {
border: 1px solid var(--omada-border);
border-radius: 12px;
background: #ffffff;
}
.gradio-tabs {
background: transparent;
}
.gradio-tabs ul.tab-list {
background: transparent;
border-bottom: 1px solid var(--omada-border);
}
.gradio-tabs button {
color: var(--omada-text-primary);
}
.gradio-tabs button.selected {
color: var(--omada-text-primary);
background: rgba(14, 33, 80, 0.1);
border-bottom: 2px solid var(--omada-primary);
}
.gradio-container .label {
background: rgba(30, 58, 138, 0.1) !important;
color: var(--omada-primary) !important;
border: 1px solid rgba(30, 58, 138, 0.25) !important;
border-radius: 999px !important;
padding: 4px 12px !important;
}
.gradio-button.primary {
background: var(--omada-primary) !important;
color: #ffffff !important;
border: 1px solid var(--omada-primary) !important;
}
.gradio-accordion {
box-shadow: none;
}
.omada-layout {
gap: 20px !important;
}
.omada-chat-column {
gap: 12px !important;
}
.omada-chat-column .gradio-chatbot {
border-radius: 16px;
box-shadow: none;
border: 1px solid var(--omada-border);
background: #ffffff;
}
.omada-controls {
gap: 16px !important;
}
.omada-mode-panel {
display: flex;
flex-direction: column;
gap: 16px !important;
}
.omada-examples-card {
padding-top: 10px !important;
}
.omada-output-panel .gradio-audio,
.omada-output-panel .gradio-textbox {
margin-top: 8px;
}
.omada-response-container {
display: flex;
flex-direction: column;
gap: 10px;
}
.omada-response-status {
margin: 0;
font-weight: 600;
font-size: 0.95rem;
color: var(--omada-text-primary);
}
.omada-response-block,
.omada-audio-block {
background: rgba(30, 58, 138, 0.05);
border-radius: 12px;
padding: 12px 14px;
color: var(--omada-text-primary);
white-space: pre-wrap;
word-break: break-word;
}
.omada-audio-block audio {
width: 100%;
}
.omada-header {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 18px !important;
margin-bottom: 18px;
text-align: center;
}
.omada-header .gradio-image {
background: transparent !important;
border: none !important;
}
.omada-header img {
object-fit: contain;
display: block;
}
.omada-logo {
max-width: 180px;
padding: 0 !important;
}
.omada-examples {
margin-top: 8px;
padding-top: 4px;
}
.omada-logo .gradio-image,
.omada-logo .gradio-image > div,
.omada-logo .gradio-image .container {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
display: flex;
justify-content: center;
align-items: center;
}
.omada-logo img {
width: 100%;
height: auto;
}
.omada-logo button {
display: none !important;
}
.gradio-container .gradio-component,
.gradio-container .gradio-panel,
.gradio-container .gradio-box {
background: transparent !important;
color: var(--omada-text-primary);
}
.dark .gradio-container,
.dark .gradio-interface,
.dark .gradio-container * {
background-color: inherit !important;
color: var(--omada-text-primary) !important;
}
.dark .gradio-container .gradio-chatbot,
.dark .gradio-container .gradio-dropdown,
.dark .gradio-container .gradio-textbox,
.dark .gradio-container .gradio-audio,
.dark .gradio-container .gradio-video,
.dark .gradio-container .gradio-slider,
.dark .gradio-container .gradio-accordion,
.dark .gradio-container .gradio-panel,
.dark .gradio-container .gradio-box {
background: #ffffff !important;
border-color: var(--omada-border) !important;
}
.omada-title h2 {
font-size: 2.4rem;
font-weight: 700;
color: var(--omada-text-primary);
margin: 0;
}
.omada-title h3 {
font-size: 1.25rem;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--omada-text-muted);
margin: 6px 0 0;
}
.omada-tagline p {
color: var(--omada-text-primary);
font-size: 1rem;
margin: 0;
opacity: 0.9;
}
.gradio-container .prose :where(h1, h2, h3, h4, h5, h6) {
color: var(--omada-text-primary) !important;
}
.gradio-container .prose :where(p, li) {
color: var(--omada-text-muted) !important;
}
.gradio-container label, .gradio-container span, .gradio-container button {
color: var(--omada-text-primary);
}
.gradio-container .dark {
background: #ffffff !important;
color: var(--omada-text-primary) !important;
}
.omada-logo-img {
max-width: 250px;
width: 100%;
height: auto;
display: block;
margin: 0 auto;
}
.omada-logo-wrapper {
display: flex;
justify-content: center;
align-items: center;
}
"""
DEMO_ROOT = Path(__file__).resolve().parent / "demo"
LOGO_PATH = DEMO_ROOT / "logo.png"
T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt"
def _load_logo_data() -> Optional[str]:
if not LOGO_PATH.exists():
return None
try:
encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8")
except OSError:
return str(LOGO_PATH)
return f"data:image/png;base64,{encoded}"
def _load_t2s_examples():
if not T2S_TEXT_PATH.exists():
return []
lines = [
line.strip()
for line in T2S_TEXT_PATH.read_text(encoding="utf-8").splitlines()
if line.strip()
]
return [[line] for line in lines]
def _load_media_examples(subdir: str, suffixes):
target_dir = DEMO_ROOT / subdir
if not target_dir.exists():
return []
examples = []
for path in sorted(target_dir.iterdir()):
if path.is_file() and path.suffix.lower() in suffixes:
examples.append([str(path)])
return examples
T2S_EXAMPLES = _load_t2s_examples()
S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
LOGO_DATA_URI = _load_logo_data()
def _render_response(status: str, body_html: str = "") -> str:
safe_status = html.escape(status or "")
parts = []
if safe_status:
parts.append(f"<p class='omada-response-status'>{safe_status}</p>")
if body_html:
parts.append(body_html)
content = "".join(parts)
return f"<div class='omada-response-container'>{content}</div>"
def _render_text_message(status: str, content: Optional[str]) -> str:
content = (content or "").strip()
if not content:
return _render_response(status)
safe_content = html.escape(content).replace("\n", "<br>")
body = f"<div class='omada-response-block'>{safe_content}</div>"
return _render_response(status, body)
def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str:
"""Render an inline HTML audio player for chat responses."""
if not audio:
return _render_response(status)
sample_rate, data = audio
if data is None:
return _render_response(status)
waveform = np.asarray(data, dtype=np.float32)
if waveform.size == 0:
return _render_response(status)
if waveform.ndim == 1:
waveform = waveform[:, None]
channels = waveform.shape[1]
clipped = np.clip(waveform, -1.0, 1.0)
pcm16 = (clipped * 32767.0).astype(np.int16)
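    # Serialize to 16-bit PCM WAV in memory so the clip can be embedded as a
    # base64 data URI inside the chat HTML.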
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_writer:
wav_writer.setnchannels(channels)
wav_writer.setsampwidth(2) # 16-bit PCM
wav_writer.setframerate(int(sample_rate))
wav_writer.writeframes(pcm16.tobytes())
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
audio_tag = (
"<div class='omada-audio-block'>"
"<audio controls preload='auto' playsinline>"
f"<source src='data:audio/wav;base64,{encoded}' type='audio/wav' /></audio>"
"</div>"
)
return _render_response(status, audio_tag)
def _format_user_message(message: str) -> str:
clean = html.escape(message or "")
return clean.replace("\n", "<br>")
# Ensure project modules (models, training, inference.common, …) are importable when the
# script is launched directly via `python MMaDA/inference/gradio_multimodal_demo2.py`.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from inference.common import (
build_uni_prompting,
get_vq_model_audio,
get_vq_model_image,
load_omada_from_checkpoint,
load_train_config,
)
from models import get_mask_schedule
from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION
from training.utils import image_transform
def _resolve_noise_schedule(train_cfg) -> Callable:
"""Return the diffusion noise schedule used for T2S sampling."""
schedule_cfg = getattr(train_cfg, "mask_schedule", None)
if schedule_cfg and hasattr(schedule_cfg, "schedule"):
schedule_name = schedule_cfg.schedule
schedule_kwargs = schedule_cfg.get("params", {})
return get_mask_schedule(schedule_name, **schedule_kwargs)
schedule_name = train_cfg.training.get("mask_schedule", "cosine")
return get_mask_schedule(schedule_name)
class OmadaDemo:
"""Lightweight container that loads all inference assets once."""
def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None):
ckpt_path = Path(checkpoint)
if ckpt_path.name != "unwrapped_model":
raise ValueError(
"`--checkpoint` must point to an `unwrapped_model` directory. "
f"Received: {checkpoint}"
)
self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
self.train_cfg = load_train_config(train_config)
self.uni_prompting, _ = build_uni_prompting(self.train_cfg)
# Core models
self.model = load_omada_from_checkpoint(str(ckpt_path), self.device)
self.vq_audio = get_vq_model_audio(self.train_cfg, self.device)
self.vq_image = get_vq_model_image(self.train_cfg, self.device)
self.model.eval()
self.vq_audio.eval()
self.vq_image.eval()
# Cached constants
self.mask_token_id = int(self.model.config.mask_token_id)
self.noise_schedule = _resolve_noise_schedule(self.train_cfg)
self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050))
self.genders = ['female', 'male']
self.emotions = ['angry', 'happy', 'neutral', 'sad']
self.speeds = ['normal', 'fast', 'slow']
self.pitches = ['normal', 'high', 'low']
# Pre-computed offsets reused across calls
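        # Unified token-space layout: text ids come first, then the image codes
        # (codebook_size entries), then the speech codes; see the offsets
        # applied in run_s2t and run_v2t below.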
self.text_vocab_size = len(self.uni_prompting.text_tokenizer)
self.codebook_size = int(getattr(self.train_cfg.model.omada, "codebook_size", 8192))
self.speech_codebook = self.codebook_size
self._temp_video_files = []
# ------------------------------------------------------------------
# Text-to-Speech
# ------------------------------------------------------------------
def run_t2s(
self,
text: str,
max_new_tokens: int,
steps: int,
block_length: int,
temperature: float,
cfg_scale: float,
gender_choice: str,
emotion_choice: str,
speed_choice: str,
pitch_choice: str,
) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
if text is None or not text.strip():
return None, "Please provide text to synthesize."
speech_len = int(max_new_tokens)
if speech_len <= 0:
return None, "Speech token length must be positive."
gender = self._resolve_choice(gender_choice, self.genders)
emotion = self._resolve_choice(emotion_choice, self.emotions)
speed = self._resolve_choice(speed_choice, self.speeds)
pitch = self._resolve_choice(pitch_choice, self.pitches)
text = text.strip().upper()
prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{random.choice(T2S_INSTRUCTION)}\n{text}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
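        # Start from a fully masked speech-token canvas; the sampler fills it
        # in over `steps` refinement steps, processed in blocks of
        # `block_length` tokens.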
audio_tokens = torch.full(
(1, speech_len),
fill_value=self.mask_token_id,
dtype=torch.long,
device=self.device,
)
input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen")
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
with torch.no_grad():
outputs = self.model.t2s_generate_mmu_like(
input_ids=input_ids,
max_new_tokens=int(max_new_tokens),
steps=int(steps),
block_length=int(block_length),
temperature=float(temperature),
cfg_scale=float(cfg_scale),
mask_token_id=self.mask_token_id,
attention_mask=attention_mask,
uni_prompting=self.uni_prompting,
codebook_size=self.codebook_size,
)
if not outputs:
return None, "Generation produced no speech tokens."
rel = outputs[0]
if isinstance(rel, torch.Tensor):
rel_ids = rel.detach().cpu().tolist()
else:
rel_ids = list(rel)
if not rel_ids:
return None, "Generation produced no speech tokens."
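        # Convert the relative speech-code ids into `<|speech_N|>` unit strings
        # and decode them to a waveform with the requested voice condition.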
speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids)
condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}"
wav = self.vq_audio.decode(
speech_units,
condition=condition,
output_wav_file=os.path.join("/tmp", "omada_t2s.wav"),
)
audio = (self.sample_rate, wav.astype(np.float32))
status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})."
return audio, status
# ------------------------------------------------------------------
# Speech-to-Text
# ------------------------------------------------------------------
def run_s2t(
self,
audio_path: Optional[str],
steps: int,
block_length: int,
max_new_tokens: int,
remasking: str,
) -> Tuple[str, str]:
if not audio_path:
return "", "Please upload an audio file first."
tokens = self.vq_audio.encode(audio_path).to(self.device)
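        # Shift the raw speech codes into the unified vocabulary: speech ids
        # sit after the text tokens and the image codebook.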
offset = self.text_vocab_size + self.speech_codebook
tokens = tokens + offset
spt = self.uni_prompting.sptids_dict
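        # Wrap the speech tokens with the S2T task marker and audio delimiters.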
audio_block = torch.cat(
[
spt['<|s2t|>'].to(self.device).unsqueeze(0),
spt['<|soa|>'].to(self.device).unsqueeze(0),
tokens.to(self.device),
spt['<|eoa|>'].to(self.device).unsqueeze(0),
],
dim=1,
)
prompt_text = random.choice(S2T_INSTRUCTION)
chat_prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{prompt_text}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
prompt_tensor = self.uni_prompting.text_tokenizer(
chat_prompt,
return_tensors="pt",
).input_ids.to(self.device)
input_ids = torch.cat([audio_block, prompt_tensor], dim=1)
with torch.no_grad():
output_ids = self.model.mmu_generate(
input_ids,
max_new_tokens=int(max_new_tokens),
steps=int(steps),
block_length=int(block_length),
remasking=str(remasking),
)
decoded = self.uni_prompting.text_tokenizer.batch_decode(
output_ids[:, input_ids.shape[1]:],
skip_special_tokens=True,
)[0]
return decoded.strip(), "Transcription generated successfully."
# ------------------------------------------------------------------
# Video-to-Text
# ------------------------------------------------------------------
def run_v2t(
self,
video_path: Any,
steps: int,
block_length: int,
max_new_tokens: int,
) -> Tuple[str, str]:
resolved_path, converted = self._prepare_video_path(video_path)
if not resolved_path:
return "", "Please upload or record a video file first."
try:
video_tokens = self._extract_video_tokens(resolved_path)
except Exception as exc:
return "", f"Failed to process video: {exc}"
spt = self.uni_prompting.sptids_dict
question = random.choice(V2T_INSTRUCTION)
prompt = (
"<|start_header_id|>user<|end_header_id|>\n"
f"{question}"
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
prompt_ids = self.uni_prompting.text_tokenizer(
prompt,
return_tensors="pt",
).input_ids.to(self.device)
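        # Assemble the V2T sequence: task marker, image-delimited video tokens,
        # then the chat-formatted text prompt.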
input_ids = torch.cat(
[
spt['<|v2t|>'].to(self.device).unsqueeze(0),
spt['<|soi|>'].to(self.device).unsqueeze(0),
video_tokens,
spt['<|eoi|>'].to(self.device).unsqueeze(0),
spt['<|sot|>'].to(self.device).unsqueeze(0),
prompt_ids,
],
dim=1,
).long()
with torch.no_grad():
output_ids = self.model.mmu_generate(
input_ids,
max_new_tokens=int(max_new_tokens),
steps=int(steps),
block_length=int(block_length),
)
decoded = self.uni_prompting.text_tokenizer.batch_decode(
output_ids[:, input_ids.shape[1]:],
skip_special_tokens=True,
)[0]
status_msg = "Video caption generated successfully."
if converted:
status_msg += " (Webcam recording converted to MP4.)"
return decoded.strip(), status_msg
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _resolve_choice(self, choice: Optional[str], options):
if choice is None or choice == 'random':
return random.choice(options)
return choice
def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]:
"""Normalize Gradio video inputs (upload/webcam) to an MP4 filepath."""
candidate = None
if isinstance(video_input, str):
candidate = video_input
elif isinstance(video_input, dict):
candidate = (
video_input.get("video")
or video_input.get("name")
or video_input.get("path")
)
elif isinstance(video_input, (list, tuple)) and video_input:
candidate = str(video_input[0])
if not candidate:
return None, False
candidate = str(candidate)
if not self._ensure_file_ready(candidate):
return None, False
if candidate.lower().endswith(".mp4"):
return candidate, False
converted = self._convert_to_mp4(candidate)
if converted:
return converted, True
suffix = Path(candidate).suffix or ".webm"
fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix)
os.close(fd)
try:
shutil.copy2(candidate, tmp_path)
self._temp_video_files.append(tmp_path)
return tmp_path, False
except OSError:
try:
os.remove(tmp_path)
except OSError:
pass
return candidate, False
def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool:
"""Ensure the uploaded/recorded file is fully written before processing."""
prev_size = -1
for _ in range(retries):
try:
size = os.path.getsize(path)
except OSError:
size = -1
if size <= 0:
time.sleep(delay)
continue
if size == prev_size:
return True
prev_size = size
time.sleep(delay)
return prev_size > 0
def _convert_to_mp4(self, src_path: str) -> Optional[str]:
"""Convert arbitrary video file to MP4 using OpenCV (drops audio)."""
cap = cv2.VideoCapture(src_path)
if not cap.isOpened():
return None
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
fps = cap.get(cv2.CAP_PROP_FPS)
if width <= 0 or height <= 0:
cap.release()
return None
if not fps or np.isnan(fps) or fps <= 0:
fps = 24.0
fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4")
os.close(fd)
writer = cv2.VideoWriter(
tmp_path,
cv2.VideoWriter_fourcc(*"mp4v"),
float(fps),
(width, height),
)
if not writer.isOpened():
cap.release()
try:
os.remove(tmp_path)
except OSError:
pass
return None
frame_count = 0
try:
while True:
ret, frame = cap.read()
if not ret:
break
writer.write(frame)
frame_count += 1
finally:
cap.release()
writer.release()
if frame_count == 0:
try:
os.remove(tmp_path)
except OSError:
pass
return None
self._temp_video_files.append(tmp_path)
return tmp_path
def _extract_video_tokens(self, video_path: str) -> torch.Tensor:
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames <= 0:
cap.release()
raise RuntimeError(f"No readable frames in {video_path}")
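        # Uniformly sample 8 frames across the clip; each kept frame is
        # preprocessed at the training resolution before VQ encoding.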
indices = np.linspace(0, total_frames - 1, 8, dtype=int)
frames = []
for idx in range(total_frames):
ret, frame = cap.read()
if idx in indices and ret:
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil = Image.fromarray(rgb)
frames.append(image_transform(pil, resolution=self.train_cfg.dataset.preprocessing.resolution))
cap.release()
if len(frames) == 0:
raise RuntimeError("Failed to sample frames for V2T inference.")
video_tensor = torch.stack(frames).to(self.device)
video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size
return video_tokens.long().to(self.device).view(1, -1)
def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]):
theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
    with gr.Blocks(title="OMada Stage 1.3 Audio/Video Demo", css=CUSTOM_CSS, theme=theme) as demo:
with gr.Column(elem_classes=["omada-header"]):
if LOGO_DATA_URI:
gr.HTML(
f"<div class='omada-logo-wrapper'><img src=\"{LOGO_DATA_URI}\" alt=\"AIDAS Lab\" class=\"omada-logo-img\" /></div>"
)
elif LOGO_PATH.exists():
gr.Image(
value=str(LOGO_PATH),
show_label=False,
height=140,
interactive=False,
elem_classes=["omada-logo"],
)
gr.Markdown(
"## Omni-modal Diffusion Foundation Model\n### Pretrained Demo",
elem_classes=["omada-title"],
)
gr.Markdown(
"Create speech, transcribe audio, and describe video from a single model. "
"Use the advanced sections when you want tighter control.",
elem_classes=["omada-tagline"],
)
with gr.Row(elem_classes=["omada-layout"], equal_height=False):
with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]):
chatbox = gr.Chatbot(label="Session", height=420, sanitize_html=False)
placeholder_map = {
"Text β†’ Speech": "Type the speech you want to generate...",
"Speech β†’ Text": "Upload audio on the right, then leave notes here if needed.",
"Video β†’ Text": "Upload video on the right, then leave notes here if needed.",
}
chat_input = gr.Textbox(
label="Message",
placeholder=placeholder_map["Text β†’ Speech"],
lines=3,
)
with gr.Row():
send_button = gr.Button("Send", variant="primary")
clear_button = gr.Button("Clear", variant="secondary")
with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]):
mode_selector = gr.Dropdown(
["Text β†’ Speech", "Speech β†’ Text", "Video β†’ Text"],
value="Text β†’ Speech",
label="Mode",
)
with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Text-to-Speech Controls")
with gr.Group(elem_classes=["omada-advanced"]):
gr.Markdown("**Generation**")
with gr.Row():
t2s_max_tokens = gr.Slider(2, 512, value=128, label="Speech token length", step=2)
t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2)
with gr.Row():
t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1)
t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
with gr.Group(elem_classes=["omada-advanced"]):
gr.Markdown("**Voice styling**")
with gr.Row():
t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender")
t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion")
with gr.Row():
t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed")
t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch")
if T2S_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample prompts**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=T2S_EXAMPLES,
inputs=[chat_input],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Speech-to-Text Controls")
s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
with gr.Row():
s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2)
s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2)
s2t_remasking = gr.Dropdown(
choices=["low_confidence", "random"],
value="low_confidence",
label="Remasking strategy",
)
if S2T_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample clips**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=S2T_EXAMPLES,
inputs=[s2t_audio],
examples_per_page=4,
)
with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel:
with gr.Group(elem_classes=["omada-card"]):
gr.Markdown("### Video-to-Text Controls")
v2t_video = gr.Video(
label="Upload or record video",
format=None,
height=256,
sources=["upload", "webcam"],
)
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
with gr.Row():
v2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2)
v2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
v2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2)
if V2T_EXAMPLES:
with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
gr.Markdown("**Sample videos**")
with gr.Column(elem_classes=["omada-examples"]):
gr.Examples(
examples=V2T_EXAMPLES,
inputs=[v2t_video],
examples_per_page=4,
)
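        # Show only the control panel matching the selected mode and swap the
        # message placeholder to match.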
def _toggle_controls(mode: str):
return (
gr.update(visible=mode == "Text β†’ Speech"),
gr.update(visible=mode == "Speech β†’ Text"),
gr.update(visible=mode == "Video β†’ Text"),
gr.update(placeholder=placeholder_map.get(mode, chat_input.placeholder)),
)
mode_selector.change(
_toggle_controls,
inputs=[mode_selector],
outputs=[t2s_panel, s2t_panel, v2t_panel, chat_input],
)
def _chat_handler(
history,
message,
mode,
audio_path,
video_path,
t2s_max_tokens,
t2s_steps,
t2s_block,
t2s_temperature,
t2s_cfg,
t2s_gender,
t2s_emotion,
t2s_speed,
t2s_pitch,
s2t_steps,
s2t_block,
s2t_max_tokens,
s2t_remasking,
v2t_steps,
v2t_block,
v2t_max_tokens,
):
history = history or []
message = (message or "").strip()
response = ""
if mode == "Text β†’ Speech":
if not message:
status = "Please type some text for speech synthesis."
response = _render_text_message(status, "")
else:
audio_result, status = app.run_t2s(
message,
t2s_max_tokens,
t2s_steps,
t2s_block,
t2s_temperature,
t2s_cfg,
t2s_gender,
t2s_emotion,
t2s_speed,
t2s_pitch,
)
response = _render_audio_message(status, audio_result)
display_user_raw = message or "[Speech request]"
elif mode == "Speech β†’ Text":
if not audio_path:
status = "Please upload or record an audio clip first."
response = _render_text_message(status, "")
else:
transcript, status = app.run_s2t(
audio_path,
s2t_steps,
s2t_block,
s2t_max_tokens,
s2t_remasking,
)
response = _render_text_message(status, transcript)
display_user_raw = message or "[Audio transcription request]"
else: # Video β†’ Text
if not video_path:
status = "Please upload or record a video first."
response = _render_text_message(status, "")
else:
caption, status = app.run_v2t(
video_path,
v2t_steps,
v2t_block,
v2t_max_tokens,
)
response = _render_text_message(status, caption)
display_user_raw = message or "[Video caption request]"
display_user = _format_user_message(display_user_raw)
history = history + [(display_user, response)]
return history, ""
submit_inputs = [
chatbox,
chat_input,
mode_selector,
s2t_audio,
v2t_video,
t2s_max_tokens,
t2s_steps,
t2s_block,
t2s_temperature,
t2s_cfg,
t2s_gender,
t2s_emotion,
t2s_speed,
t2s_pitch,
s2t_steps,
s2t_block,
s2t_max_tokens,
s2t_remasking,
v2t_steps,
v2t_block,
v2t_max_tokens,
]
submit_outputs = [chatbox, chat_input]
chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
def _clear_session():
return (
[],
"",
gr.update(value=None),
gr.update(value=None),
)
clear_button.click(
_clear_session,
inputs=None,
outputs=[chatbox, chat_input, s2t_audio, v2t_video],
)
demo.launch(
share=share,
server_name=server_name,
server_port=server_port,
)
def parse_args():
parser = argparse.ArgumentParser(description="OMada Gradio demo for audio/video tasks")
parser.add_argument("--train-config", required=True, help="Path to the training YAML used to build tokenizer + VQ modules")
parser.add_argument("--checkpoint", required=True, help="Path to an `unwrapped_model` directory")
parser.add_argument("--device", default=None, help="Override device (e.g. cuda:0). Defaults to CUDA if available")
parser.add_argument("--share", action="store_true", help="Enable public Gradio share link")
parser.add_argument("--server-name", default="0.0.0.0", help="Host address for Blocks.launch")
parser.add_argument("--server-port", type=int, default=None, help="Port for Blocks.launch")
return parser.parse_args()
def main():
args = parse_args()
app = OmadaDemo(args.train_config, args.checkpoint, args.device)
build_demo(app, args.share, args.server_name, args.server_port)
if __name__ == "__main__":
main()