#!/usr/bin/env python3
"""Gradio demo for OMada Stage 1.3 checkpoints covering:

* Text-to-Speech (T2S)
* Speech-to-Text (S2T)
* Video-to-Text (V2T)

The implementation wraps the existing CLI inference helpers so that a single
checkpoint directory (…/checkpoint-XXXX/unwrapped_model) can be previewed
interactively.

Usage:
    python MMaDA/inference/gradio_multimodal_demo2.py \
        --train-config MMaDA/configs/omada_pretraining_stage1-3.yaml \
        --checkpoint ../ckpt/checkpoint-315000/unwrapped_model/ \
        --share

If you need remote access, pass `--share`, which simply forwards the flag to
`gradio.Blocks.launch`.
"""
import argparse
import base64
import html
import io
import os
import random
import shutil
import sys
import tempfile
import time
import wave
from pathlib import Path
from typing import Any, Optional, Tuple

import numpy as np

CUSTOM_CSS = """
:root { --omada-primary: #1e3a8a; --omada-accent: #1d4ed8; --omada-surface: #f3f4f6; --omada-surface-alt: #ffffff; --omada-border: #d0d7e5; --omada-text-primary: #111827; --omada-text-muted: #374151; color-scheme: light; }
html, body, body.dark, html.dark { background: var(--omada-surface) !important; color: var(--omada-text-primary) !important; }
.gradio-container { background: var(--omada-surface); color: var(--omada-text-primary); }
.omada-page-heading { margin-bottom: 0; color: var(--omada-text-primary); }
.omada-tab-intro p { font-size: 0.95rem; color: var(--omada-text-primary); margin-top: 0; opacity: 0.9; }
.omada-card { background: var(--omada-surface-alt); border-radius: 16px !important; padding: 18px !important; box-shadow: none; border: 1px solid var(--omada-border); color: var(--omada-text-primary); }
.omada-card .gradio-slider .wrap-inner { gap: 6px; }
.omada-card .gradio-slider input[type=range]::-webkit-slider-thumb { background: var(--omada-primary); }
.gradio-slider input[type=range]::-webkit-slider-runnable-track { background: rgba(14, 33, 80, 0.2); }
.omada-section-title p { text-transform: uppercase; font-size: 0.78rem; letter-spacing: 0.14em; color: rgba(30, 58, 138, 0.85); margin: 0 0 12px 0; }
.omada-output .gradio-audio, .omada-output .gradio-textbox { margin-top: 12px; }
.gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio { color: var(--omada-text-primary) !important; }
.gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label { color: var(--omada-text-primary); }
.gradio-dropdown .single-select, .gradio-textbox textarea { background: #ffffff !important; border: 1px solid var(--omada-border) !important; color: var(--omada-text-primary) !important; }
.gradio-dropdown .single-select select, .gradio-textbox textarea { color: var(--omada-text-primary) !important; }
.gradio-textbox textarea::placeholder { color: rgba(148, 163, 184, 0.65); }
.gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider { background: #ffffff !important; border: 1px solid var(--omada-border) !important; border-radius: 12px !important; }
.full-width-button button { width: 100%; background: var(--omada-primary) !important; color: white !important; border: none !important; font-weight: 600; transition: transform 0.2s ease, box-shadow 0.2s ease; box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65); }
.full-width-button button:hover { transform: translateY(-1px); box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75); }
.omada-advanced .gr-accordion-header { font-size: 0.85rem; letter-spacing: 0.05em; color: var(--omada-text-muted); }
.omada-advanced .gr-accordion { border: 1px solid var(--omada-border); border-radius: 12px; background: #ffffff; }
.gradio-tabs { background: transparent; }
.gradio-tabs ul.tab-list { background: transparent; border-bottom: 1px solid var(--omada-border); }
.gradio-tabs button { color: var(--omada-text-primary); }
.gradio-tabs button.selected { color: var(--omada-text-primary); background: rgba(14, 33, 80, 0.1); border-bottom: 2px solid var(--omada-primary); }
.gradio-container .label { background: rgba(30, 58, 138, 0.1) !important; color: var(--omada-primary) !important; border: 1px solid rgba(30, 58, 138, 0.25) !important; border-radius: 999px !important; padding: 4px 12px !important; }
.gradio-button.primary { background: var(--omada-primary) !important; color: #ffffff !important; border: 1px solid var(--omada-primary) !important; }
.gradio-accordion { box-shadow: none; }
.omada-layout { gap: 20px !important; }
.omada-chat-column { gap: 12px !important; }
.omada-chat-column .gradio-chatbot { border-radius: 16px; box-shadow: none; border: 1px solid var(--omada-border); background: #ffffff; }
.omada-controls { gap: 16px !important; }
.omada-mode-panel { display: flex; flex-direction: column; gap: 16px !important; }
.omada-examples-card { padding-top: 10px !important; }
.omada-output-panel .gradio-audio, .omada-output-panel .gradio-textbox { margin-top: 8px; }
.omada-response-container { display: flex; flex-direction: column; gap: 10px; }
.omada-response-status { margin: 0; font-weight: 600; font-size: 0.95rem; color: var(--omada-text-primary); }
.omada-response-block, .omada-audio-block { background: rgba(30, 58, 138, 0.05); border-radius: 12px; padding: 12px 14px; color: var(--omada-text-primary); white-space: pre-wrap; word-break: break-word; }
.omada-audio-block audio { width: 100%; }
.omada-header { display: flex; flex-direction: column; align-items: center; justify-content: center; gap: 18px !important; margin-bottom: 18px; text-align: center; }
.omada-header .gradio-image { background: transparent !important; border: none !important; }
.omada-header img { object-fit: contain; display: block; }
.omada-logo { max-width: 180px; padding: 0 !important; }
.omada-examples { margin-top: 8px; padding-top: 4px; }
.omada-logo .gradio-image, .omada-logo .gradio-image > div, .omada-logo .gradio-image .container { background: transparent !important; border: none !important; box-shadow: none !important; padding: 0 !important; display: flex; justify-content: center; align-items: center; }
.omada-logo img { width: 100%; height: auto; }
.omada-logo button { display: none !important; }
.gradio-container .gradio-component, .gradio-container .gradio-panel, .gradio-container .gradio-box { background: transparent !important; color: var(--omada-text-primary); }
.dark .gradio-container, .dark .gradio-interface, .dark .gradio-container * { background-color: inherit !important; color: var(--omada-text-primary) !important; }
.dark .gradio-container .gradio-chatbot, .dark .gradio-container .gradio-dropdown, .dark .gradio-container .gradio-textbox, .dark .gradio-container .gradio-audio, .dark .gradio-container .gradio-video, .dark .gradio-container .gradio-slider, .dark .gradio-container .gradio-accordion, .dark .gradio-container .gradio-panel, .dark .gradio-container .gradio-box { background: #ffffff !important; border-color: var(--omada-border) !important; }
.omada-title h2 { font-size: 2.4rem; font-weight: 700; color: var(--omada-text-primary); margin: 0; }
.omada-title h3 { font-size: 1.25rem; font-weight: 600; letter-spacing: 0.1em; text-transform: uppercase; color: var(--omada-text-muted); margin: 6px 0 0; }
.omada-tagline p { color: var(--omada-text-primary); font-size: 1rem; margin: 0; opacity: 0.9; }
.gradio-container .prose :where(h1, h2, h3, h4, h5, h6) { color: var(--omada-text-primary) !important; }
.gradio-container .prose :where(p, li) { color: var(--omada-text-muted) !important; }
.gradio-container label, .gradio-container span, .gradio-container button { color: var(--omada-text-primary); }
.gradio-container .dark { background: #ffffff !important; color: var(--omada-text-primary) !important; }
.omada-logo-img { max-width: 250px; width: 100%; height: auto; display: block; margin: 0 auto; }
.omada-logo-wrapper { display: flex; justify-content: center; align-items: center; }
"""

DEMO_ROOT = Path(__file__).resolve().parent / "demo"
LOGO_PATH = DEMO_ROOT / "logo.png"
T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt"


def _load_logo_data() -> Optional[str]:
    """Return the demo logo as a base64 data URI, or a plain path if encoding fails."""
    if not LOGO_PATH.exists():
        return None
    try:
        encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8")
    except OSError:
        return str(LOGO_PATH)
    return f"data:image/png;base64,{encoded}"


def _load_t2s_examples():
    if not T2S_TEXT_PATH.exists():
        return []
    lines = [
        line.strip()
        for line in T2S_TEXT_PATH.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    return [[line] for line in lines]


def _load_media_examples(subdir: str, suffixes):
    target_dir = DEMO_ROOT / subdir
    if not target_dir.exists():
        return []
    examples = []
    for path in sorted(target_dir.iterdir()):
        if path.is_file() and path.suffix.lower() in suffixes:
            examples.append([str(path)])
    return examples


T2S_EXAMPLES = _load_t2s_examples()
S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
LOGO_DATA_URI = _load_logo_data()


def _render_response(status: str, body_html: str = "") -> str:
    """Wrap an escaped status line and optional body HTML in the chat response container."""
    safe_status = html.escape(status or "")
    parts = []
    if safe_status:
        parts.append(f"<p class='omada-response-status'>{safe_status}</p>")
    if body_html:
        parts.append(body_html)
    content = "".join(parts)
    return f"<div class='omada-response-container'>{content}</div>"


def _render_text_message(status: str, content: Optional[str]) -> str:
    content = (content or "").strip()
    if not content:
        return _render_response(status)
    safe_content = html.escape(content).replace("\n", "<br>")
    body = f"<div class='omada-response-block'>{safe_content}</div>"
    return _render_response(status, body)


def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str:
    """Render an inline HTML audio player for chat responses."""
    if not audio:
        return _render_response(status)
    sample_rate, data = audio
    if data is None:
        return _render_response(status)
    waveform = np.asarray(data, dtype=np.float32)
    if waveform.size == 0:
        return _render_response(status)
    if waveform.ndim == 1:
        waveform = waveform[:, None]
    channels = waveform.shape[1]
    clipped = np.clip(waveform, -1.0, 1.0)
    pcm16 = (clipped * 32767.0).astype(np.int16)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_writer:
        wav_writer.setnchannels(channels)
        wav_writer.setsampwidth(2)  # 16-bit PCM
        wav_writer.setframerate(int(sample_rate))
        wav_writer.writeframes(pcm16.tobytes())
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    audio_tag = (
        "<div class='omada-audio-block'>"
        f"<audio controls src='data:audio/wav;base64,{encoded}'></audio>"
        "</div>"
    )
    return _render_response(status, audio_tag)


def _format_user_message(message: str) -> str:
    clean = html.escape(message or "")
    return clean.replace("\n", "<br>")


") # Ensure project modules (models, training, inference.common, …) are importable when the # script is launched directly via `python MMaDA/inference/gradio_multimodal_demo.py`. PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) import cv2 import gradio as gr import numpy as np import torch from PIL import Image from inference.common import ( build_uni_prompting, get_vq_model_audio, get_vq_model_image, load_omada_from_checkpoint, load_train_config, ) from models import get_mask_schedule from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION from training.utils import image_transform def _resolve_noise_schedule(train_cfg) -> callable: """Return the diffusion noise schedule used for T2S sampling.""" schedule_cfg = getattr(train_cfg, "mask_schedule", None) if schedule_cfg and hasattr(schedule_cfg, "schedule"): schedule_name = schedule_cfg.schedule schedule_kwargs = schedule_cfg.get("params", {}) return get_mask_schedule(schedule_name, **schedule_kwargs) schedule_name = train_cfg.training.get("mask_schedule", "cosine") return get_mask_schedule(schedule_name) class OmadaDemo: """Lightweight container that loads all inference assets once.""" def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None): ckpt_path = Path(checkpoint) if ckpt_path.name != "unwrapped_model": raise ValueError( "`--checkpoint` must point to an `unwrapped_model` directory. " f"Received: {checkpoint}" ) self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) self.train_cfg = load_train_config(train_config) self.uni_prompting, _ = build_uni_prompting(self.train_cfg) # Core models self.model = load_omada_from_checkpoint(str(ckpt_path), self.device) self.vq_audio = get_vq_model_audio(self.train_cfg, self.device) self.vq_image = get_vq_model_image(self.train_cfg, self.device) self.model.eval() self.vq_audio.eval() self.vq_image.eval() # Cached constants self.mask_token_id = int(self.model.config.mask_token_id) self.noise_schedule = _resolve_noise_schedule(self.train_cfg) self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050)) self.genders = ['female', 'male'] self.emotions = ['angry', 'happy', 'neutral', 'sad'] self.speeds = ['normal', 'fast', 'slow'] self.pitches = ['normal', 'high', 'low'] # Pre-computed offsets reused across calls self.text_vocab_size = len(self.uni_prompting.text_tokenizer) self.codebook_size = int(getattr(self.train_cfg.model.omada, "codebook_size", 8192)) self.speech_codebook = self.codebook_size self._temp_video_files = [] # ------------------------------------------------------------------ # Text-to-Speech # ------------------------------------------------------------------ def run_t2s( self, text: str, max_new_tokens: int, steps: int, block_length: int, temperature: float, cfg_scale: float, gender_choice: str, emotion_choice: str, speed_choice: str, pitch_choice: str, ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: if text is None or not text.strip(): return None, "Please provide text to synthesize." speech_len = int(max_new_tokens) if speech_len <= 0: return None, "Speech token length must be positive." 

        gender = self._resolve_choice(gender_choice, self.genders)
        emotion = self._resolve_choice(emotion_choice, self.emotions)
        speed = self._resolve_choice(speed_choice, self.speeds)
        pitch = self._resolve_choice(pitch_choice, self.pitches)

        text = text.strip().upper()
        prompt = (
            "<|start_header_id|>user<|end_header_id|>\n"
            f"{random.choice(T2S_INSTRUCTION)}\n{text}"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        audio_tokens = torch.full(
            (1, speech_len),
            fill_value=self.mask_token_id,
            dtype=torch.long,
            device=self.device,
        )
        input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen")
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            outputs = self.model.t2s_generate_mmu_like(
                input_ids=input_ids,
                max_new_tokens=int(max_new_tokens),
                steps=int(steps),
                block_length=int(block_length),
                temperature=float(temperature),
                cfg_scale=float(cfg_scale),
                mask_token_id=self.mask_token_id,
                attention_mask=attention_mask,
                uni_prompting=self.uni_prompting,
                codebook_size=self.codebook_size,
            )

        if not outputs:
            return None, "Generation produced no speech tokens."

        rel = outputs[0]
        if isinstance(rel, torch.Tensor):
            rel_ids = rel.detach().cpu().tolist()
        else:
            rel_ids = list(rel)
        if not rel_ids:
            return None, "Generation produced no speech tokens."

        speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids)
        condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}"
        wav = self.vq_audio.decode(
            speech_units,
            condition=condition,
            output_wav_file=os.path.join("/tmp", "omada_t2s.wav"),
        )
        audio = (self.sample_rate, wav.astype(np.float32))
        status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})."
        return audio, status

    # ------------------------------------------------------------------
    # Speech-to-Text
    # ------------------------------------------------------------------
    def run_s2t(
        self,
        audio_path: Optional[str],
        steps: int,
        block_length: int,
        max_new_tokens: int,
        remasking: str,
    ) -> Tuple[str, str]:
        if not audio_path:
            return "", "Please upload an audio file first."

        tokens = self.vq_audio.encode(audio_path).to(self.device)
        offset = self.text_vocab_size + self.speech_codebook
        tokens = tokens + offset

        spt = self.uni_prompting.sptids_dict
        audio_block = torch.cat(
            [
                spt['<|s2t|>'].to(self.device).unsqueeze(0),
                spt['<|soa|>'].to(self.device).unsqueeze(0),
                tokens.to(self.device),
                spt['<|eoa|>'].to(self.device).unsqueeze(0),
            ],
            dim=1,
        )

        prompt_text = random.choice(S2T_INSTRUCTION)
        chat_prompt = (
            "<|start_header_id|>user<|end_header_id|>\n"
            f"{prompt_text}"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        prompt_tensor = self.uni_prompting.text_tokenizer(
            chat_prompt,
            return_tensors="pt",
        ).input_ids.to(self.device)

        input_ids = torch.cat([audio_block, prompt_tensor], dim=1)

        with torch.no_grad():
            output_ids = self.model.mmu_generate(
                input_ids,
                max_new_tokens=int(max_new_tokens),
                steps=int(steps),
                block_length=int(block_length),
                remasking=str(remasking),
            )

        decoded = self.uni_prompting.text_tokenizer.batch_decode(
            output_ids[:, input_ids.shape[1]:],
            skip_special_tokens=True,
        )[0]
        return decoded.strip(), "Transcription generated successfully."
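
    # NOTE: token-layout sketch, inferred from the offsets used in this file rather
    # than from a published spec. The unified vocabulary appears to be ordered as
    #   [0, text_vocab_size)                               -> text tokens
    #   [text_vocab_size, text_vocab_size + codebook_size) -> image VQ codes (_extract_video_tokens)
    #   [text_vocab_size + codebook_size, ...)             -> speech VQ codes (run_s2t)
    # which is why image codes are shifted by `text_vocab_size` and speech codes by
    # `text_vocab_size + codebook_size` before being spliced into prompts.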

    # ------------------------------------------------------------------
    # Video-to-Text
    # ------------------------------------------------------------------
    def run_v2t(
        self,
        video_path: Any,
        steps: int,
        block_length: int,
        max_new_tokens: int,
    ) -> Tuple[str, str]:
        resolved_path, converted = self._prepare_video_path(video_path)
        if not resolved_path:
            return "", "Please upload or record a video file first."

        try:
            video_tokens = self._extract_video_tokens(resolved_path)
        except Exception as exc:
            return "", f"Failed to process video: {exc}"

        spt = self.uni_prompting.sptids_dict
        question = random.choice(V2T_INSTRUCTION)
        prompt = (
            "<|start_header_id|>user<|end_header_id|>\n"
            f"{question}"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        prompt_ids = self.uni_prompting.text_tokenizer(
            prompt,
            return_tensors="pt",
        ).input_ids.to(self.device)

        input_ids = torch.cat(
            [
                spt['<|v2t|>'].to(self.device).unsqueeze(0),
                spt['<|soi|>'].to(self.device).unsqueeze(0),
                video_tokens,
                spt['<|eoi|>'].to(self.device).unsqueeze(0),
                spt['<|sot|>'].to(self.device).unsqueeze(0),
                prompt_ids,
            ],
            dim=1,
        ).long()

        with torch.no_grad():
            output_ids = self.model.mmu_generate(
                input_ids,
                max_new_tokens=int(max_new_tokens),
                steps=int(steps),
                block_length=int(block_length),
            )

        decoded = self.uni_prompting.text_tokenizer.batch_decode(
            output_ids[:, input_ids.shape[1]:],
            skip_special_tokens=True,
        )[0]

        status_msg = "Video caption generated successfully."
        if converted:
            status_msg += " (Webcam recording converted to MP4.)"
        return decoded.strip(), status_msg

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _resolve_choice(self, choice: Optional[str], options):
        if choice is None or choice == 'random':
            return random.choice(options)
        return choice

    def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]:
        """Normalize Gradio video inputs (upload/webcam) to an MP4 filepath."""
        candidate = None
        if isinstance(video_input, str):
            candidate = video_input
        elif isinstance(video_input, dict):
            candidate = (
                video_input.get("video")
                or video_input.get("name")
                or video_input.get("path")
            )
        elif isinstance(video_input, (list, tuple)) and video_input:
            candidate = str(video_input[0])

        if not candidate:
            return None, False

        candidate = str(candidate)
        if not self._ensure_file_ready(candidate):
            return None, False

        if candidate.lower().endswith(".mp4"):
            return candidate, False

        converted = self._convert_to_mp4(candidate)
        if converted:
            return converted, True

        # Conversion failed: fall back to a copy of the original file with its own suffix.
        suffix = Path(candidate).suffix or ".webm"
        fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix)
        os.close(fd)
        try:
            shutil.copy2(candidate, tmp_path)
            self._temp_video_files.append(tmp_path)
            return tmp_path, False
        except OSError:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            return candidate, False

    def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool:
        """Ensure the uploaded/recorded file is fully written before processing."""
        prev_size = -1
        for _ in range(retries):
            try:
                size = os.path.getsize(path)
            except OSError:
                size = -1
            if size <= 0:
                time.sleep(delay)
                continue
            if size == prev_size:
                return True
            prev_size = size
            time.sleep(delay)
        return prev_size > 0

    def _convert_to_mp4(self, src_path: str) -> Optional[str]:
        """Convert arbitrary video file to MP4 using OpenCV (drops audio)."""
        cap = cv2.VideoCapture(src_path)
        if not cap.isOpened():
            return None

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if width <= 0 or height <= 0:
            cap.release()
            return None
        if not fps or np.isnan(fps) or fps <= 0:
            fps = 24.0

        fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4")
        os.close(fd)
        writer = cv2.VideoWriter(
            tmp_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(fps),
            (width, height),
        )
        if not writer.isOpened():
            cap.release()
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            return None

        frame_count = 0
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                writer.write(frame)
                frame_count += 1
        finally:
            cap.release()
            writer.release()

        if frame_count == 0:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            return None

        self._temp_video_files.append(tmp_path)
        return tmp_path

    def _extract_video_tokens(self, video_path: str) -> torch.Tensor:
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            cap.release()
            raise RuntimeError(f"No readable frames in {video_path}")

        indices = np.linspace(0, total_frames - 1, 8, dtype=int)
        frames = []
        for idx in range(total_frames):
            ret, frame = cap.read()
            if idx in indices and ret:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil = Image.fromarray(rgb)
                frames.append(image_transform(pil, resolution=self.train_cfg.dataset.preprocessing.resolution))
        cap.release()

        if len(frames) == 0:
            raise RuntimeError("Failed to sample frames for V2T inference.")

        video_tensor = torch.stack(frames).to(self.device)
        video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size
        return video_tokens.long().to(self.device).view(1, -1)


def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]):
    theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")

    with gr.Blocks(title="OMada Stage1.3 Audio/Video Demo", css=CUSTOM_CSS, theme=theme) as demo:
        with gr.Column(elem_classes=["omada-header"]):
            if LOGO_DATA_URI:
\"AIDAS
" ) elif LOGO_PATH.exists(): gr.Image( value=str(LOGO_PATH), show_label=False, height=140, interactive=False, elem_classes=["omada-logo"], ) gr.Markdown( "## Omni-modal Diffusion Foundation Model\n### Pretrained Demo", elem_classes=["omada-title"], ) gr.Markdown( "Create speech, transcribe audio, and describe video from a single model. " "Use the advanced sections when you want tighter control.", elem_classes=["omada-tagline"], ) with gr.Row(elem_classes=["omada-layout"], equal_height=False): with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]): chatbox = gr.Chatbot(label="Session", height=420, sanitize_html=False) placeholder_map = { "Text → Speech": "Type the speech you want to generate...", "Speech → Text": "Upload audio on the right, then leave notes here if needed.", "Video → Text": "Upload video on the right, then leave notes here if needed.", } chat_input = gr.Textbox( label="Message", placeholder=placeholder_map["Text → Speech"], lines=3, ) with gr.Row(): send_button = gr.Button("Send", variant="primary") clear_button = gr.Button("Clear", variant="secondary") with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]): mode_selector = gr.Dropdown( ["Text → Speech", "Speech → Text", "Video → Text"], value="Text → Speech", label="Mode", ) with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel: with gr.Group(elem_classes=["omada-card"]): gr.Markdown("### Text-to-Speech Controls") with gr.Group(elem_classes=["omada-advanced"]): gr.Markdown("**Generation**") with gr.Row(): t2s_max_tokens = gr.Slider(2, 512, value=128, label="Speech token length", step=2) t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2) with gr.Row(): t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1) t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) with gr.Group(elem_classes=["omada-advanced"]): gr.Markdown("**Voice styling**") with gr.Row(): t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender") t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion") with gr.Row(): t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed") t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch") if T2S_EXAMPLES: with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): gr.Markdown("**Sample prompts**") with gr.Column(elem_classes=["omada-examples"]): gr.Examples( examples=T2S_EXAMPLES, inputs=[chat_input], examples_per_page=4, ) with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel: with gr.Group(elem_classes=["omada-card"]): gr.Markdown("### Speech-to-Text Controls") s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"]) with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): with gr.Row(): s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) s2t_remasking = gr.Dropdown( choices=["low_confidence", "random"], value="low_confidence", label="Remasking strategy", ) if S2T_EXAMPLES: with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): gr.Markdown("**Sample clips**") with gr.Column(elem_classes=["omada-examples"]): gr.Examples( 
                                gr.Examples(
                                    examples=S2T_EXAMPLES,
                                    inputs=[s2t_audio],
                                    examples_per_page=4,
                                )

                with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel:
                    with gr.Group(elem_classes=["omada-card"]):
                        gr.Markdown("### Video-to-Text Controls")
                        v2t_video = gr.Video(
                            label="Upload or record video",
                            format=None,
                            height=256,
                            sources=["upload", "webcam"],
                        )
                        with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
                            with gr.Row():
                                v2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2)
                                v2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2)
                            v2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2)
                    if V2T_EXAMPLES:
                        with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
                            gr.Markdown("**Sample videos**")
                            with gr.Column(elem_classes=["omada-examples"]):
                                gr.Examples(
                                    examples=V2T_EXAMPLES,
                                    inputs=[v2t_video],
                                    examples_per_page=4,
                                )

        def _toggle_controls(mode: str):
            return (
                gr.update(visible=mode == "Text → Speech"),
                gr.update(visible=mode == "Speech → Text"),
                gr.update(visible=mode == "Video → Text"),
                gr.update(placeholder=placeholder_map.get(mode, chat_input.placeholder)),
            )

        mode_selector.change(
            _toggle_controls,
            inputs=[mode_selector],
            outputs=[t2s_panel, s2t_panel, v2t_panel, chat_input],
        )

        def _chat_handler(
            history,
            message,
            mode,
            audio_path,
            video_path,
            t2s_max_tokens,
            t2s_steps,
            t2s_block,
            t2s_temperature,
            t2s_cfg,
            t2s_gender,
            t2s_emotion,
            t2s_speed,
            t2s_pitch,
            s2t_steps,
            s2t_block,
            s2t_max_tokens,
            s2t_remasking,
            v2t_steps,
            v2t_block,
            v2t_max_tokens,
        ):
            history = history or []
            message = (message or "").strip()
            response = ""

            if mode == "Text → Speech":
                if not message:
                    status = "Please type some text for speech synthesis."
                    response = _render_text_message(status, "")
                else:
                    audio_result, status = app.run_t2s(
                        message,
                        t2s_max_tokens,
                        t2s_steps,
                        t2s_block,
                        t2s_temperature,
                        t2s_cfg,
                        t2s_gender,
                        t2s_emotion,
                        t2s_speed,
                        t2s_pitch,
                    )
                    response = _render_audio_message(status, audio_result)
                display_user_raw = message or "[Speech request]"
            elif mode == "Speech → Text":
                if not audio_path:
                    status = "Please upload or record an audio clip first."
                    response = _render_text_message(status, "")
                else:
                    transcript, status = app.run_s2t(
                        audio_path,
                        s2t_steps,
                        s2t_block,
                        s2t_max_tokens,
                        s2t_remasking,
                    )
                    response = _render_text_message(status, transcript)
                display_user_raw = message or "[Audio transcription request]"
            else:  # Video → Text
                if not video_path:
                    status = "Please upload or record a video first."
                    response = _render_text_message(status, "")
                else:
                    caption, status = app.run_v2t(
                        video_path,
                        v2t_steps,
                        v2t_block,
                        v2t_max_tokens,
                    )
                    response = _render_text_message(status, caption)
                display_user_raw = message or "[Video caption request]"

            display_user = _format_user_message(display_user_raw)
            history = history + [(display_user, response)]
            return history, ""

        submit_inputs = [
            chatbox,
            chat_input,
            mode_selector,
            s2t_audio,
            v2t_video,
            t2s_max_tokens,
            t2s_steps,
            t2s_block,
            t2s_temperature,
            t2s_cfg,
            t2s_gender,
            t2s_emotion,
            t2s_speed,
            t2s_pitch,
            s2t_steps,
            s2t_block,
            s2t_max_tokens,
            s2t_remasking,
            v2t_steps,
            v2t_block,
            v2t_max_tokens,
        ]
        submit_outputs = [chatbox, chat_input]

        chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)
        send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs)

        def _clear_session():
            return (
                [],
                "",
                gr.update(value=None),
                gr.update(value=None),
            )

        clear_button.click(
            _clear_session,
            inputs=None,
            outputs=[chatbox, chat_input, s2t_audio, v2t_video],
        )

    demo.launch(
        share=share,
        server_name=server_name,
        server_port=server_port,
    )


def parse_args():
    parser = argparse.ArgumentParser(description="OMada Gradio demo for audio/video tasks")
    parser.add_argument("--train-config", required=True,
                        help="Path to the training YAML used to build tokenizer + VQ modules")
    parser.add_argument("--checkpoint", required=True,
                        help="Path to an `unwrapped_model` directory")
    parser.add_argument("--device", default=None,
                        help="Override device (e.g. cuda:0). Defaults to CUDA if available")
    parser.add_argument("--share", action="store_true",
                        help="Enable public Gradio share link")
    parser.add_argument("--server-name", default="0.0.0.0",
                        help="Host address for Blocks.launch")
    parser.add_argument("--server-port", type=int, default=None,
                        help="Port for Blocks.launch")
    return parser.parse_args()


def main():
    args = parse_args()
    app = OmadaDemo(args.train_config, args.checkpoint, args.device)
    build_demo(app, args.share, args.server_name, args.server_port)


if __name__ == "__main__":
    main()
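
# ---------------------------------------------------------------------------
# Usage sketch (comments only, not executed): the OmadaDemo wrapper can also be
# exercised without the Gradio UI, e.g. for smoke-testing a checkpoint. The
# paths below are the placeholder paths from the module docstring and must be
# swapped for a real config/checkpoint pair; the audio path is hypothetical.
#
#   app = OmadaDemo(
#       train_config="MMaDA/configs/omada_pretraining_stage1-3.yaml",
#       checkpoint="../ckpt/checkpoint-315000/unwrapped_model/",
#   )
#   audio, status = app.run_t2s(
#       "Hello from OMada.",
#       max_new_tokens=128, steps=128, block_length=128,
#       temperature=1.0, cfg_scale=3.5,
#       gender_choice="random", emotion_choice="random",
#       speed_choice="random", pitch_choice="random",
#   )
#   transcript, status = app.run_s2t(
#       "demo/s2t/sample.wav", steps=128, block_length=128,
#       max_new_tokens=128, remasking="low_confidence",
#   )
# ---------------------------------------------------------------------------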