Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """ | |
| Gradio demo for OMada Stage 1.3 checkpoints covering: | |
| * Text-to-Speech (T2S) | |
| * Speech-to-Text (S2T) | |
| * Video-to-Text (V2T) | |
| The implementation wraps the existing CLI inference helpers so that a single | |
| checkpoint directory (β¦/checkpoint-XXXX/unwrapped_model) can be previewed | |
| interactively. Usage: | |
| python MMaDA/inference/gradio_multimodal_demo2.py --train-config MMaDA/configs/omada_pretraining_stage1-3.yaml --checkpoint ../ckpt/checkpoint-315000/unwrapped_model/ --share | |
| If you need remote access, pass `--share` which simply forwards the flag to | |
| `gradio.Blocks.launch`. | |
| """ | |
| import argparse | |
| import base64 | |
| import html | |
| import io | |
| import os | |
| import random | |
| import sys | |
| import tempfile | |
| import wave | |
| from pathlib import Path | |
| import shutil | |
| import time | |
| from typing import Any, Optional, Tuple | |
| import numpy as np | |
| CUSTOM_CSS = """ | |
| :root { | |
| --omada-primary: #1e3a8a; | |
| --omada-accent: #1d4ed8; | |
| --omada-surface: #f3f4f6; | |
| --omada-surface-alt: #ffffff; | |
| --omada-border: #d0d7e5; | |
| --omada-text-primary: #111827; | |
| --omada-text-muted: #374151; | |
| color-scheme: light; | |
| } | |
| html, body, body.dark, html.dark { | |
| background: var(--omada-surface) !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-container { | |
| background: var(--omada-surface); | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-page-heading { | |
| margin-bottom: 0; | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-tab-intro p { | |
| font-size: 0.95rem; | |
| color: var(--omada-text-primary); | |
| margin-top: 0; | |
| opacity: 0.9; | |
| } | |
| .omada-card { | |
| background: var(--omada-surface-alt); | |
| border-radius: 16px !important; | |
| padding: 18px !important; | |
| box-shadow: none; | |
| border: 1px solid var(--omada-border); | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-card .gradio-slider .wrap-inner { | |
| gap: 6px; | |
| } | |
| .omada-card .gradio-slider input[type=range]::-webkit-slider-thumb { | |
| background: var(--omada-primary); | |
| } | |
| .gradio-slider input[type=range]::-webkit-slider-runnable-track { | |
| background: rgba(14, 33, 80, 0.2); | |
| } | |
| .omada-section-title p { | |
| text-transform: uppercase; | |
| font-size: 0.78rem; | |
| letter-spacing: 0.14em; | |
| color: rgba(30, 58, 138, 0.85); | |
| margin: 0 0 12px 0; | |
| } | |
| .omada-output .gradio-audio, .omada-output .gradio-textbox { | |
| margin-top: 12px; | |
| } | |
| .gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio { | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label { | |
| color: var(--omada-text-primary); | |
| } | |
| .gradio-dropdown .single-select, .gradio-textbox textarea { | |
| background: #ffffff !important; | |
| border: 1px solid var(--omada-border) !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-dropdown .single-select select, .gradio-textbox textarea { | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-textbox textarea::placeholder { | |
| color: rgba(148, 163, 184, 0.65); | |
| } | |
| .gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider { | |
| background: #ffffff !important; | |
| border: 1px solid var(--omada-border) !important; | |
| border-radius: 12px !important; | |
| } | |
| .full-width-button button { | |
| width: 100%; | |
| background: var(--omada-primary) !important; | |
| color: white !important; | |
| border: none !important; | |
| font-weight: 600; | |
| transition: transform 0.2s ease, box-shadow 0.2s ease; | |
| box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65); | |
| } | |
| .full-width-button button:hover { | |
| transform: translateY(-1px); | |
| box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75); | |
| } | |
| .omada-advanced .gr-accordion-header { | |
| font-size: 0.85rem; | |
| letter-spacing: 0.05em; | |
| color: var(--omada-text-muted); | |
| } | |
| .omada-advanced .gr-accordion { | |
| border: 1px solid var(--omada-border); | |
| border-radius: 12px; | |
| background: #ffffff; | |
| } | |
| .gradio-tabs { | |
| background: transparent; | |
| } | |
| .gradio-tabs ul.tab-list { | |
| background: transparent; | |
| border-bottom: 1px solid var(--omada-border); | |
| } | |
| .gradio-tabs button { | |
| color: var(--omada-text-primary); | |
| } | |
| .gradio-tabs button.selected { | |
| color: var(--omada-text-primary); | |
| background: rgba(14, 33, 80, 0.1); | |
| border-bottom: 2px solid var(--omada-primary); | |
| } | |
| .gradio-container .label { | |
| background: rgba(30, 58, 138, 0.1) !important; | |
| color: var(--omada-primary) !important; | |
| border: 1px solid rgba(30, 58, 138, 0.25) !important; | |
| border-radius: 999px !important; | |
| padding: 4px 12px !important; | |
| } | |
| .gradio-button.primary { | |
| background: var(--omada-primary) !important; | |
| color: #ffffff !important; | |
| border: 1px solid var(--omada-primary) !important; | |
| } | |
| .gradio-accordion { | |
| box-shadow: none; | |
| } | |
| .omada-layout { | |
| gap: 20px !important; | |
| } | |
| .omada-chat-column { | |
| gap: 12px !important; | |
| } | |
| .omada-chat-column .gradio-chatbot { | |
| border-radius: 16px; | |
| box-shadow: none; | |
| border: 1px solid var(--omada-border); | |
| background: #ffffff; | |
| } | |
| .omada-controls { | |
| gap: 16px !important; | |
| } | |
| .omada-mode-panel { | |
| display: flex; | |
| flex-direction: column; | |
| gap: 16px !important; | |
| } | |
| .omada-examples-card { | |
| padding-top: 10px !important; | |
| } | |
| .omada-output-panel .gradio-audio, | |
| .omada-output-panel .gradio-textbox { | |
| margin-top: 8px; | |
| } | |
| .omada-response-container { | |
| display: flex; | |
| flex-direction: column; | |
| gap: 10px; | |
| } | |
| .omada-response-status { | |
| margin: 0; | |
| font-weight: 600; | |
| font-size: 0.95rem; | |
| color: var(--omada-text-primary); | |
| } | |
| .omada-response-block, | |
| .omada-audio-block { | |
| background: rgba(30, 58, 138, 0.05); | |
| border-radius: 12px; | |
| padding: 12px 14px; | |
| color: var(--omada-text-primary); | |
| white-space: pre-wrap; | |
| word-break: break-word; | |
| } | |
| .omada-audio-block audio { | |
| width: 100%; | |
| } | |
| .omada-header { | |
| display: flex; | |
| flex-direction: column; | |
| align-items: center; | |
| justify-content: center; | |
| gap: 18px !important; | |
| margin-bottom: 18px; | |
| text-align: center; | |
| } | |
| .omada-header .gradio-image { | |
| background: transparent !important; | |
| border: none !important; | |
| } | |
| .omada-header img { | |
| object-fit: contain; | |
| display: block; | |
| } | |
| .omada-logo { | |
| max-width: 180px; | |
| padding: 0 !important; | |
| } | |
| .omada-examples { | |
| margin-top: 8px; | |
| padding-top: 4px; | |
| } | |
| .omada-logo .gradio-image, | |
| .omada-logo .gradio-image > div, | |
| .omada-logo .gradio-image .container { | |
| background: transparent !important; | |
| border: none !important; | |
| box-shadow: none !important; | |
| padding: 0 !important; | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| } | |
| .omada-logo img { | |
| width: 100%; | |
| height: auto; | |
| } | |
| .omada-logo button { | |
| display: none !important; | |
| } | |
| .gradio-container .gradio-component, | |
| .gradio-container .gradio-panel, | |
| .gradio-container .gradio-box { | |
| background: transparent !important; | |
| color: var(--omada-text-primary); | |
| } | |
| .dark .gradio-container, | |
| .dark .gradio-interface, | |
| .dark .gradio-container * { | |
| background-color: inherit !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .dark .gradio-container .gradio-chatbot, | |
| .dark .gradio-container .gradio-dropdown, | |
| .dark .gradio-container .gradio-textbox, | |
| .dark .gradio-container .gradio-audio, | |
| .dark .gradio-container .gradio-video, | |
| .dark .gradio-container .gradio-slider, | |
| .dark .gradio-container .gradio-accordion, | |
| .dark .gradio-container .gradio-panel, | |
| .dark .gradio-container .gradio-box { | |
| background: #ffffff !important; | |
| border-color: var(--omada-border) !important; | |
| } | |
| .omada-title h2 { | |
| font-size: 2.4rem; | |
| font-weight: 700; | |
| color: var(--omada-text-primary); | |
| margin: 0; | |
| } | |
| .omada-title h3 { | |
| font-size: 1.25rem; | |
| font-weight: 600; | |
| letter-spacing: 0.1em; | |
| text-transform: uppercase; | |
| color: var(--omada-text-muted); | |
| margin: 6px 0 0; | |
| } | |
| .omada-tagline p { | |
| color: var(--omada-text-primary); | |
| font-size: 1rem; | |
| margin: 0; | |
| opacity: 0.9; | |
| } | |
| .gradio-container .prose :where(h1, h2, h3, h4, h5, h6) { | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .gradio-container .prose :where(p, li) { | |
| color: var(--omada-text-muted) !important; | |
| } | |
| .gradio-container label, .gradio-container span, .gradio-container button { | |
| color: var(--omada-text-primary); | |
| } | |
| .gradio-container .dark { | |
| background: #ffffff !important; | |
| color: var(--omada-text-primary) !important; | |
| } | |
| .omada-logo-img { | |
| max-width: 250px; | |
| width: 100%; | |
| height: auto; | |
| display: block; | |
| margin: 0 auto; | |
| } | |
| .omada-logo-wrapper { | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| } | |
| """ | |
| DEMO_ROOT = Path(__file__).resolve().parent / "demo" | |
| LOGO_PATH = DEMO_ROOT / "logo.png" | |
| T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt" | |
| def _load_logo_data() -> Optional[str]: | |
| if not LOGO_PATH.exists(): | |
| return None | |
| try: | |
| import base64 | |
| except ImportError: | |
| return str(LOGO_PATH) | |
| try: | |
| encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8") | |
| except OSError: | |
| return str(LOGO_PATH) | |
| return f"data:image/png;base64,{encoded}" | |
| def _load_t2s_examples(): | |
| if not T2S_TEXT_PATH.exists(): | |
| return [] | |
| lines = [ | |
| line.strip() | |
| for line in T2S_TEXT_PATH.read_text(encoding="utf-8").splitlines() | |
| if line.strip() | |
| ] | |
| return [[line] for line in lines] | |
| def _load_media_examples(subdir: str, suffixes): | |
| target_dir = DEMO_ROOT / subdir | |
| if not target_dir.exists(): | |
| return [] | |
| examples = [] | |
| for path in sorted(target_dir.iterdir()): | |
| if path.is_file() and path.suffix.lower() in suffixes: | |
| examples.append([str(path)]) | |
| return examples | |
| T2S_EXAMPLES = _load_t2s_examples() | |
| S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"}) | |
| V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) | |
| LOGO_DATA_URI = _load_logo_data() | |
| def _render_response(status: str, body_html: str = "") -> str: | |
| safe_status = html.escape(status or "") | |
| parts = [] | |
| if safe_status: | |
| parts.append(f"<p class='omada-response-status'>{safe_status}</p>") | |
| if body_html: | |
| parts.append(body_html) | |
| content = "".join(parts) | |
| return f"<div class='omada-response-container'>{content}</div>" | |
| def _render_text_message(status: str, content: Optional[str]) -> str: | |
| content = (content or "").strip() | |
| if not content: | |
| return _render_response(status) | |
| safe_content = html.escape(content).replace("\n", "<br>") | |
| body = f"<div class='omada-response-block'>{safe_content}</div>" | |
| return _render_response(status, body) | |
| def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str: | |
| """Render an inline HTML audio player for chat responses.""" | |
| if not audio: | |
| return _render_response(status) | |
| sample_rate, data = audio | |
| if data is None: | |
| return _render_response(status) | |
| waveform = np.asarray(data, dtype=np.float32) | |
| if waveform.size == 0: | |
| return _render_response(status) | |
| if waveform.ndim == 1: | |
| waveform = waveform[:, None] | |
| channels = waveform.shape[1] | |
| clipped = np.clip(waveform, -1.0, 1.0) | |
| pcm16 = (clipped * 32767.0).astype(np.int16) | |
| buffer = io.BytesIO() | |
| with wave.open(buffer, "wb") as wav_writer: | |
| wav_writer.setnchannels(channels) | |
| wav_writer.setsampwidth(2) # 16-bit PCM | |
| wav_writer.setframerate(int(sample_rate)) | |
| wav_writer.writeframes(pcm16.tobytes()) | |
| encoded = base64.b64encode(buffer.getvalue()).decode("ascii") | |
| audio_tag = ( | |
| "<div class='omada-audio-block'>" | |
| "<audio controls preload='auto' playsinline>" | |
| f"<source src='data:audio/wav;base64,{encoded}' type='audio/wav' /></audio>" | |
| "</div>" | |
| ) | |
| return _render_response(status, audio_tag) | |
| def _format_user_message(message: str) -> str: | |
| clean = html.escape(message or "") | |
| return clean.replace("\n", "<br>") | |
| # Ensure project modules (models, training, inference.common, β¦) are importable when the | |
| # script is launched directly via `python MMaDA/inference/gradio_multimodal_demo.py`. | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from inference.common import ( | |
| build_uni_prompting, | |
| get_vq_model_audio, | |
| get_vq_model_image, | |
| load_omada_from_checkpoint, | |
| load_train_config, | |
| ) | |
| from models import get_mask_schedule | |
| from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION | |
| from training.utils import image_transform | |
| def _resolve_noise_schedule(train_cfg) -> callable: | |
| """Return the diffusion noise schedule used for T2S sampling.""" | |
| schedule_cfg = getattr(train_cfg, "mask_schedule", None) | |
| if schedule_cfg and hasattr(schedule_cfg, "schedule"): | |
| schedule_name = schedule_cfg.schedule | |
| schedule_kwargs = schedule_cfg.get("params", {}) | |
| return get_mask_schedule(schedule_name, **schedule_kwargs) | |
| schedule_name = train_cfg.training.get("mask_schedule", "cosine") | |
| return get_mask_schedule(schedule_name) | |
| class OmadaDemo: | |
| """Lightweight container that loads all inference assets once.""" | |
| def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None): | |
| ckpt_path = Path(checkpoint) | |
| if ckpt_path.name != "unwrapped_model": | |
| raise ValueError( | |
| "`--checkpoint` must point to an `unwrapped_model` directory. " | |
| f"Received: {checkpoint}" | |
| ) | |
| self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) | |
| self.train_cfg = load_train_config(train_config) | |
| self.uni_prompting, _ = build_uni_prompting(self.train_cfg) | |
| # Core models | |
| self.model = load_omada_from_checkpoint(str(ckpt_path), self.device) | |
| self.vq_audio = get_vq_model_audio(self.train_cfg, self.device) | |
| self.vq_image = get_vq_model_image(self.train_cfg, self.device) | |
| self.model.eval() | |
| self.vq_audio.eval() | |
| self.vq_image.eval() | |
| # Cached constants | |
| self.mask_token_id = int(self.model.config.mask_token_id) | |
| self.noise_schedule = _resolve_noise_schedule(self.train_cfg) | |
| self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050)) | |
| self.genders = ['female', 'male'] | |
| self.emotions = ['angry', 'happy', 'neutral', 'sad'] | |
| self.speeds = ['normal', 'fast', 'slow'] | |
| self.pitches = ['normal', 'high', 'low'] | |
| # Pre-computed offsets reused across calls | |
| self.text_vocab_size = len(self.uni_prompting.text_tokenizer) | |
| self.codebook_size = int(getattr(self.train_cfg.model.omada, "codebook_size", 8192)) | |
| self.speech_codebook = self.codebook_size | |
| self._temp_video_files = [] | |
| # ------------------------------------------------------------------ | |
| # Text-to-Speech | |
| # ------------------------------------------------------------------ | |
| def run_t2s( | |
| self, | |
| text: str, | |
| max_new_tokens: int, | |
| steps: int, | |
| block_length: int, | |
| temperature: float, | |
| cfg_scale: float, | |
| gender_choice: str, | |
| emotion_choice: str, | |
| speed_choice: str, | |
| pitch_choice: str, | |
| ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: | |
| if text is None or not text.strip(): | |
| return None, "Please provide text to synthesize." | |
| speech_len = int(max_new_tokens) | |
| if speech_len <= 0: | |
| return None, "Speech token length must be positive." | |
| gender = self._resolve_choice(gender_choice, self.genders) | |
| emotion = self._resolve_choice(emotion_choice, self.emotions) | |
| speed = self._resolve_choice(speed_choice, self.speeds) | |
| pitch = self._resolve_choice(pitch_choice, self.pitches) | |
| text = text.strip().upper() | |
| prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{random.choice(T2S_INSTRUCTION)}\n{text}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| audio_tokens = torch.full( | |
| (1, speech_len), | |
| fill_value=self.mask_token_id, | |
| dtype=torch.long, | |
| device=self.device, | |
| ) | |
| input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen") | |
| input_ids = input_ids.to(self.device) | |
| attention_mask = attention_mask.to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.t2s_generate_mmu_like( | |
| input_ids=input_ids, | |
| max_new_tokens=int(max_new_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| temperature=float(temperature), | |
| cfg_scale=float(cfg_scale), | |
| mask_token_id=self.mask_token_id, | |
| attention_mask=attention_mask, | |
| uni_prompting=self.uni_prompting, | |
| codebook_size=self.codebook_size, | |
| ) | |
| if not outputs: | |
| return None, "Generation produced no speech tokens." | |
| rel = outputs[0] | |
| if isinstance(rel, torch.Tensor): | |
| rel_ids = rel.detach().cpu().tolist() | |
| else: | |
| rel_ids = list(rel) | |
| if not rel_ids: | |
| return None, "Generation produced no speech tokens." | |
| speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids) | |
| condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}" | |
| wav = self.vq_audio.decode( | |
| speech_units, | |
| condition=condition, | |
| output_wav_file=os.path.join("/tmp", "omada_t2s.wav"), | |
| ) | |
| audio = (self.sample_rate, wav.astype(np.float32)) | |
| status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})." | |
| return audio, status | |
| # ------------------------------------------------------------------ | |
| # Speech-to-Text | |
| # ------------------------------------------------------------------ | |
| def run_s2t( | |
| self, | |
| audio_path: Optional[str], | |
| steps: int, | |
| block_length: int, | |
| max_new_tokens: int, | |
| remasking: str, | |
| ) -> Tuple[str, str]: | |
| if not audio_path: | |
| return "", "Please upload an audio file first." | |
| tokens = self.vq_audio.encode(audio_path).to(self.device) | |
| offset = self.text_vocab_size + self.speech_codebook | |
| tokens = tokens + offset | |
| spt = self.uni_prompting.sptids_dict | |
| audio_block = torch.cat( | |
| [ | |
| spt['<|s2t|>'].to(self.device).unsqueeze(0), | |
| spt['<|soa|>'].to(self.device).unsqueeze(0), | |
| tokens.to(self.device), | |
| spt['<|eoa|>'].to(self.device).unsqueeze(0), | |
| ], | |
| dim=1, | |
| ) | |
| prompt_text = random.choice(S2T_INSTRUCTION) | |
| chat_prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{prompt_text}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| prompt_tensor = self.uni_prompting.text_tokenizer( | |
| chat_prompt, | |
| return_tensors="pt", | |
| ).input_ids.to(self.device) | |
| input_ids = torch.cat([audio_block, prompt_tensor], dim=1) | |
| with torch.no_grad(): | |
| output_ids = self.model.mmu_generate( | |
| input_ids, | |
| max_new_tokens=int(max_new_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| remasking=str(remasking), | |
| ) | |
| decoded = self.uni_prompting.text_tokenizer.batch_decode( | |
| output_ids[:, input_ids.shape[1]:], | |
| skip_special_tokens=True, | |
| )[0] | |
| return decoded.strip(), "Transcription generated successfully." | |
| # ------------------------------------------------------------------ | |
| # Video-to-Text | |
| # ------------------------------------------------------------------ | |
| def run_v2t( | |
| self, | |
| video_path: Any, | |
| steps: int, | |
| block_length: int, | |
| max_new_tokens: int, | |
| ) -> Tuple[str, str]: | |
| resolved_path, converted = self._prepare_video_path(video_path) | |
| if not resolved_path: | |
| return "", "Please upload or record a video file first." | |
| try: | |
| video_tokens = self._extract_video_tokens(resolved_path) | |
| except Exception as exc: | |
| return "", f"Failed to process video: {exc}" | |
| spt = self.uni_prompting.sptids_dict | |
| question = random.choice(V2T_INSTRUCTION) | |
| prompt = ( | |
| "<|start_header_id|>user<|end_header_id|>\n" | |
| f"{question}" | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| prompt_ids = self.uni_prompting.text_tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| ).input_ids.to(self.device) | |
| input_ids = torch.cat( | |
| [ | |
| spt['<|v2t|>'].to(self.device).unsqueeze(0), | |
| spt['<|soi|>'].to(self.device).unsqueeze(0), | |
| video_tokens, | |
| spt['<|eoi|>'].to(self.device).unsqueeze(0), | |
| spt['<|sot|>'].to(self.device).unsqueeze(0), | |
| prompt_ids, | |
| ], | |
| dim=1, | |
| ).long() | |
| with torch.no_grad(): | |
| output_ids = self.model.mmu_generate( | |
| input_ids, | |
| max_new_tokens=int(max_new_tokens), | |
| steps=int(steps), | |
| block_length=int(block_length), | |
| ) | |
| decoded = self.uni_prompting.text_tokenizer.batch_decode( | |
| output_ids[:, input_ids.shape[1]:], | |
| skip_special_tokens=True, | |
| )[0] | |
| status_msg = "Video caption generated successfully." | |
| if converted: | |
| status_msg += " (Webcam recording converted to MP4.)" | |
| return decoded.strip(), status_msg | |
| # ------------------------------------------------------------------ | |
| # Helpers | |
| # ------------------------------------------------------------------ | |
| def _resolve_choice(self, choice: Optional[str], options): | |
| if choice is None or choice == 'random': | |
| return random.choice(options) | |
| return choice | |
| def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]: | |
| """Normalize Gradio video inputs (upload/webcam) to an MP4 filepath.""" | |
| candidate = None | |
| if isinstance(video_input, str): | |
| candidate = video_input | |
| elif isinstance(video_input, dict): | |
| candidate = ( | |
| video_input.get("video") | |
| or video_input.get("name") | |
| or video_input.get("path") | |
| ) | |
| elif isinstance(video_input, (list, tuple)) and video_input: | |
| candidate = str(video_input[0]) | |
| if not candidate: | |
| return None, False | |
| candidate = str(candidate) | |
| if not self._ensure_file_ready(candidate): | |
| return None, False | |
| if candidate.lower().endswith(".mp4"): | |
| return candidate, False | |
| converted = self._convert_to_mp4(candidate) | |
| if converted: | |
| return converted, True | |
| suffix = Path(candidate).suffix or ".webm" | |
| fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) | |
| os.close(fd) | |
| try: | |
| shutil.copy2(candidate, tmp_path) | |
| self._temp_video_files.append(tmp_path) | |
| return tmp_path, False | |
| except OSError: | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return candidate, False | |
| if candidate.lower().endswith(".mp4"): | |
| return candidate, False | |
| converted = self._convert_to_mp4(candidate) | |
| if converted: | |
| return converted, True | |
| suffix = Path(candidate).suffix or ".webm" | |
| fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) | |
| os.close(fd) | |
| try: | |
| shutil.copy2(candidate, tmp_path) | |
| self._temp_video_files.append(tmp_path) | |
| return tmp_path, False | |
| except OSError: | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return candidate, False | |
| def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool: | |
| """Ensure the uploaded/recorded file is fully written before processing.""" | |
| prev_size = -1 | |
| for _ in range(retries): | |
| try: | |
| size = os.path.getsize(path) | |
| except OSError: | |
| size = -1 | |
| if size <= 0: | |
| time.sleep(delay) | |
| continue | |
| if size == prev_size: | |
| return True | |
| prev_size = size | |
| time.sleep(delay) | |
| return prev_size > 0 | |
| def _convert_to_mp4(self, src_path: str) -> Optional[str]: | |
| """Convert arbitrary video file to MP4 using OpenCV (drops audio).""" | |
| cap = cv2.VideoCapture(src_path) | |
| if not cap.isOpened(): | |
| return None | |
| width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0) | |
| height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0) | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| if width <= 0 or height <= 0: | |
| cap.release() | |
| return None | |
| if not fps or np.isnan(fps) or fps <= 0: | |
| fps = 24.0 | |
| fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4") | |
| os.close(fd) | |
| writer = cv2.VideoWriter( | |
| tmp_path, | |
| cv2.VideoWriter_fourcc(*"mp4v"), | |
| float(fps), | |
| (width, height), | |
| ) | |
| if not writer.isOpened(): | |
| cap.release() | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return None | |
| frame_count = 0 | |
| try: | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| writer.write(frame) | |
| frame_count += 1 | |
| finally: | |
| cap.release() | |
| writer.release() | |
| if frame_count == 0: | |
| try: | |
| os.remove(tmp_path) | |
| except OSError: | |
| pass | |
| return None | |
| self._temp_video_files.append(tmp_path) | |
| return tmp_path | |
| def _extract_video_tokens(self, video_path: str) -> torch.Tensor: | |
| cap = cv2.VideoCapture(video_path) | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| if total_frames <= 0: | |
| cap.release() | |
| raise RuntimeError(f"No readable frames in {video_path}") | |
| indices = np.linspace(0, total_frames - 1, 8, dtype=int) | |
| frames = [] | |
| for idx in range(total_frames): | |
| ret, frame = cap.read() | |
| if idx in indices and ret: | |
| rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| pil = Image.fromarray(rgb) | |
| frames.append(image_transform(pil, resolution=self.train_cfg.dataset.preprocessing.resolution)) | |
| cap.release() | |
| if len(frames) == 0: | |
| raise RuntimeError("Failed to sample frames for V2T inference.") | |
| video_tensor = torch.stack(frames).to(self.device) | |
| video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size | |
| return video_tokens.long().to(self.device).view(1, -1) | |
| def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]): | |
| theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray") | |
| with gr.Blocks(title="OMada Stage1.3 Audio/Video Demo", css=CUSTOM_CSS, theme=theme) as demo: | |
| with gr.Column(elem_classes=["omada-header"]): | |
| if LOGO_DATA_URI: | |
| gr.HTML( | |
| f"<div class='omada-logo-wrapper'><img src=\"{LOGO_DATA_URI}\" alt=\"AIDAS Lab\" class=\"omada-logo-img\" /></div>" | |
| ) | |
| elif LOGO_PATH.exists(): | |
| gr.Image( | |
| value=str(LOGO_PATH), | |
| show_label=False, | |
| height=140, | |
| interactive=False, | |
| elem_classes=["omada-logo"], | |
| ) | |
| gr.Markdown( | |
| "## Omni-modal Diffusion Foundation Model\n### Pretrained Demo", | |
| elem_classes=["omada-title"], | |
| ) | |
| gr.Markdown( | |
| "Create speech, transcribe audio, and describe video from a single model. " | |
| "Use the advanced sections when you want tighter control.", | |
| elem_classes=["omada-tagline"], | |
| ) | |
| with gr.Row(elem_classes=["omada-layout"], equal_height=False): | |
| with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]): | |
| chatbox = gr.Chatbot(label="Session", height=420, sanitize_html=False) | |
| placeholder_map = { | |
| "Text β Speech": "Type the speech you want to generate...", | |
| "Speech β Text": "Upload audio on the right, then leave notes here if needed.", | |
| "Video β Text": "Upload video on the right, then leave notes here if needed.", | |
| } | |
| chat_input = gr.Textbox( | |
| label="Message", | |
| placeholder=placeholder_map["Text β Speech"], | |
| lines=3, | |
| ) | |
| with gr.Row(): | |
| send_button = gr.Button("Send", variant="primary") | |
| clear_button = gr.Button("Clear", variant="secondary") | |
| with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]): | |
| mode_selector = gr.Dropdown( | |
| ["Text β Speech", "Speech β Text", "Video β Text"], | |
| value="Text β Speech", | |
| label="Mode", | |
| ) | |
| with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Text-to-Speech Controls") | |
| with gr.Group(elem_classes=["omada-advanced"]): | |
| gr.Markdown("**Generation**") | |
| with gr.Row(): | |
| t2s_max_tokens = gr.Slider(2, 512, value=128, label="Speech token length", step=2) | |
| t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2) | |
| with gr.Row(): | |
| t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) | |
| t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1) | |
| t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) | |
| with gr.Group(elem_classes=["omada-advanced"]): | |
| gr.Markdown("**Voice styling**") | |
| with gr.Row(): | |
| t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender") | |
| t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion") | |
| with gr.Row(): | |
| t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed") | |
| t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch") | |
| if T2S_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample prompts**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=T2S_EXAMPLES, | |
| inputs=[chat_input], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Speech-to-Text Controls") | |
| s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"]) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| with gr.Row(): | |
| s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) | |
| s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) | |
| s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) | |
| s2t_remasking = gr.Dropdown( | |
| choices=["low_confidence", "random"], | |
| value="low_confidence", | |
| label="Remasking strategy", | |
| ) | |
| if S2T_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample clips**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=S2T_EXAMPLES, | |
| inputs=[s2t_audio], | |
| examples_per_page=4, | |
| ) | |
| with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel: | |
| with gr.Group(elem_classes=["omada-card"]): | |
| gr.Markdown("### Video-to-Text Controls") | |
| v2t_video = gr.Video( | |
| label="Upload or record video", | |
| format=None, | |
| height=256, | |
| sources=["upload", "webcam"], | |
| ) | |
| with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): | |
| with gr.Row(): | |
| v2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) | |
| v2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) | |
| v2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) | |
| if V2T_EXAMPLES: | |
| with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): | |
| gr.Markdown("**Sample videos**") | |
| with gr.Column(elem_classes=["omada-examples"]): | |
| gr.Examples( | |
| examples=V2T_EXAMPLES, | |
| inputs=[v2t_video], | |
| examples_per_page=4, | |
| ) | |
| def _toggle_controls(mode: str): | |
| return ( | |
| gr.update(visible=mode == "Text β Speech"), | |
| gr.update(visible=mode == "Speech β Text"), | |
| gr.update(visible=mode == "Video β Text"), | |
| gr.update(placeholder=placeholder_map.get(mode, chat_input.placeholder)), | |
| ) | |
| mode_selector.change( | |
| _toggle_controls, | |
| inputs=[mode_selector], | |
| outputs=[t2s_panel, s2t_panel, v2t_panel, chat_input], | |
| ) | |
| def _chat_handler( | |
| history, | |
| message, | |
| mode, | |
| audio_path, | |
| video_path, | |
| t2s_max_tokens, | |
| t2s_steps, | |
| t2s_block, | |
| t2s_temperature, | |
| t2s_cfg, | |
| t2s_gender, | |
| t2s_emotion, | |
| t2s_speed, | |
| t2s_pitch, | |
| s2t_steps, | |
| s2t_block, | |
| s2t_max_tokens, | |
| s2t_remasking, | |
| v2t_steps, | |
| v2t_block, | |
| v2t_max_tokens, | |
| ): | |
| history = history or [] | |
| message = (message or "").strip() | |
| response = "" | |
| if mode == "Text β Speech": | |
| if not message: | |
| status = "Please type some text for speech synthesis." | |
| response = _render_text_message(status, "") | |
| else: | |
| audio_result, status = app.run_t2s( | |
| message, | |
| t2s_max_tokens, | |
| t2s_steps, | |
| t2s_block, | |
| t2s_temperature, | |
| t2s_cfg, | |
| t2s_gender, | |
| t2s_emotion, | |
| t2s_speed, | |
| t2s_pitch, | |
| ) | |
| response = _render_audio_message(status, audio_result) | |
| display_user_raw = message or "[Speech request]" | |
| elif mode == "Speech β Text": | |
| if not audio_path: | |
| status = "Please upload or record an audio clip first." | |
| response = _render_text_message(status, "") | |
| else: | |
| transcript, status = app.run_s2t( | |
| audio_path, | |
| s2t_steps, | |
| s2t_block, | |
| s2t_max_tokens, | |
| s2t_remasking, | |
| ) | |
| response = _render_text_message(status, transcript) | |
| display_user_raw = message or "[Audio transcription request]" | |
| else: # Video β Text | |
| if not video_path: | |
| status = "Please upload or record a video first." | |
| response = _render_text_message(status, "") | |
| else: | |
| caption, status = app.run_v2t( | |
| video_path, | |
| v2t_steps, | |
| v2t_block, | |
| v2t_max_tokens, | |
| ) | |
| response = _render_text_message(status, caption) | |
| display_user_raw = message or "[Video caption request]" | |
| display_user = _format_user_message(display_user_raw) | |
| history = history + [(display_user, response)] | |
| return history, "" | |
| submit_inputs = [ | |
| chatbox, | |
| chat_input, | |
| mode_selector, | |
| s2t_audio, | |
| v2t_video, | |
| t2s_max_tokens, | |
| t2s_steps, | |
| t2s_block, | |
| t2s_temperature, | |
| t2s_cfg, | |
| t2s_gender, | |
| t2s_emotion, | |
| t2s_speed, | |
| t2s_pitch, | |
| s2t_steps, | |
| s2t_block, | |
| s2t_max_tokens, | |
| s2t_remasking, | |
| v2t_steps, | |
| v2t_block, | |
| v2t_max_tokens, | |
| ] | |
| submit_outputs = [chatbox, chat_input] | |
| chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) | |
| send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) | |
| def _clear_session(): | |
| return ( | |
| [], | |
| "", | |
| gr.update(value=None), | |
| gr.update(value=None), | |
| ) | |
| clear_button.click( | |
| _clear_session, | |
| inputs=None, | |
| outputs=[chatbox, chat_input, s2t_audio, v2t_video], | |
| ) | |
| demo.launch( | |
| share=share, | |
| server_name=server_name, | |
| server_port=server_port, | |
| ) | |
| def parse_args(): | |
| parser = argparse.ArgumentParser(description="OMada Gradio demo for audio/video tasks") | |
| parser.add_argument("--train-config", required=True, help="Path to the training YAML used to build tokenizer + VQ modules") | |
| parser.add_argument("--checkpoint", required=True, help="Path to an `unwrapped_model` directory") | |
| parser.add_argument("--device", default=None, help="Override device (e.g. cuda:0). Defaults to CUDA if available") | |
| parser.add_argument("--share", action="store_true", help="Enable public Gradio share link") | |
| parser.add_argument("--server-name", default="0.0.0.0", help="Host address for Blocks.launch") | |
| parser.add_argument("--server-port", type=int, default=None, help="Port for Blocks.launch") | |
| return parser.parse_args() | |
| def main(): | |
| args = parse_args() | |
| app = OmadaDemo(args.train_config, args.checkpoint, args.device) | |
| build_demo(app, args.share, args.server_name, args.server_port) | |
| if __name__ == "__main__": | |
| main() | |