# Hugging Face Space status: Running on Zero
import atexit
import html
import os
import re
import shutil
import socket
import subprocess
import tempfile
import time

# Keep `spaces` imported before `torch` (ZeroGPU hooks CUDA initialisation).
import gradio as gr
import spaces
import torch
import yt_dlp
from transformers import AutoModel, AutoProcessor
# SOCKS proxy URL for yt-dlp, set by _ensure_local_socks_tunnel() once an SSH
# tunnel answers on 127.0.0.1:1080; None means downloads connect directly.
# _tunnel_proc holds the background ssh Popen handle (None until started).
PROXY_URL = _tunnel_proc = None
def _write_temp_key_and_kh(key_str, kh_line):
    """Write an SSH private key and a known_hosts entry to temporary files.

    The key's line endings are normalised to ``\\n`` and a trailing newline is
    guaranteed (presumably because OpenSSH key parsing is strict about it).
    The known_hosts line is stripped and newline-terminated.

    Args:
        key_str: Private key text, e.g. from an environment secret.
        kh_line: A single known_hosts line for the SSH server.

    Returns:
        ``(key_path, known_hosts_path)``. Both files are created with
        ``delete=False``, so the caller is responsible for removing them.
    """
    key_clean = key_str.replace("\r\n", "\n").replace("\r", "\n")
    if not key_clean.endswith("\n"):
        key_clean += "\n"
    # Restrict permissions *before* the secret is written (ssh rejects private
    # keys with permissive modes); `with` closes the handle even if write fails.
    with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as keyf:
        os.chmod(keyf.name, 0o600)
        keyf.write(key_clean)
    with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as khf:
        khf.write(kh_line.strip() + "\n")
    return keyf.name, khf.name
def _validate_private_key(path):
    """Return True if ``path`` holds a private key usable without a passphrase.

    Uses ``ssh-keygen -y`` (derive the public key) as a parse check. When
    ``ssh-keygen`` is not on PATH the check is skipped and True is returned,
    leaving validation to ssh itself.
    """
    if not shutil.which("ssh-keygen"):
        return True
    try:
        # -P "" supplies an empty passphrase and stdin=DEVNULL guarantees no
        # interactive prompt can block startup. An encrypted key therefore
        # fails here, consistent with the tunnel's BatchMode=yes ssh, which
        # cannot prompt for a passphrase either.
        subprocess.run(
            ["ssh-keygen", "-y", "-P", "", "-f", path],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
            timeout=10,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        return False
    return True
def _ensure_local_socks_tunnel():
    """Start an SSH dynamic (SOCKS) tunnel on 127.0.0.1:1080 when configured.

    Reads SSH_SERVER, SSH_PORT (default "22"), SSH_PRIVATE_KEY and
    SSH_HOSTKEY from the environment. Nothing happens unless all of server,
    key and host key are set and an ``ssh`` binary is on PATH, or if
    ``PROXY_URL`` is already set.

    On success the module-level ``PROXY_URL`` becomes
    ``socks5h://127.0.0.1:1080`` and an atexit hook stops the tunnel and
    deletes the temporary key/known_hosts files. On any failure
    ``PROXY_URL`` stays None, the temp files are removed, and no ssh process
    is left running. ssh output goes to /tmp/ssh_tunnel.log.
    """
    global PROXY_URL, _tunnel_proc
    if PROXY_URL:
        return
    srv = os.getenv("SSH_SERVER")
    port = os.getenv("SSH_PORT", "22")
    key = os.getenv("SSH_PRIVATE_KEY")
    hk = os.getenv("SSH_HOSTKEY")
    if not (srv and key and hk and shutil.which("ssh")):
        return
    key_path, kh_path = _write_temp_key_and_kh(key, hk)

    def _remove_secret_files():
        # Best-effort: the decoded private key must not linger on disk.
        for path in (key_path, kh_path):
            try:
                os.remove(path)
            except OSError:
                pass

    if not _validate_private_key(key_path):
        _remove_secret_files()
        return
    cmd = [
        "ssh", "-NT", "-p", port, "-i", key_path,
        "-D", "127.0.0.1:1080",
        "-o", "IdentitiesOnly=yes",
        "-o", "ExitOnForwardFailure=yes",
        "-o", "BatchMode=yes",
        "-o", "StrictHostKeyChecking=yes",
        "-o", f"UserKnownHostsFile={kh_path}",
        "-o", "GlobalKnownHostsFile=/dev/null",
        "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=3",
        srv,
    ]
    # The child inherits the log file descriptor, so closing our handle after
    # Popen returns does not stop ssh from logging.
    with open("/tmp/ssh_tunnel.log", "w") as lf:
        _tunnel_proc = subprocess.Popen(cmd, stdout=lf, stderr=lf)
    proc = _tunnel_proc

    def _stop_tunnel():
        # Terminate ssh (escalating to kill) and remove the secret files.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
        _remove_secret_files()

    # Up to 40 rounds: each tries a connect (0.5 s timeout) and, if refused,
    # sleeps 0.25 s before retrying.
    for _ in range(40):
        if proc.poll() is not None:
            # ssh exited early (auth, host-key or forward failure): see the log.
            _remove_secret_files()
            return
        try:
            socket.create_connection(("127.0.0.1", 1080), 0.5).close()
        except OSError:
            time.sleep(0.25)
            continue
        PROXY_URL = "socks5h://127.0.0.1:1080"
        break
    else:
        # Never became reachable: tear it down rather than leave an orphaned
        # ssh process that nothing will use.
        _stop_tunnel()
        return
    atexit.register(_stop_tunnel)
# Bring up the optional SSH SOCKS tunnel at import time so PROXY_URL is settled
# before any YouTube download runs. It is a no-op unless SSH_SERVER,
# SSH_PRIVATE_KEY and SSH_HOSTKEY are set and `ssh` is on PATH.
_ensure_local_socks_tunnel()
# Hugging Face Hub checkpoint used for both AutoProcessor and AutoModel.
MODEL_ID = "nvidia/music-flamingo-hf"
# Content rendered in the page header ("hero") of the Gradio UI.
HERO_IMAGE_URL = "https://musicflamingo.github.io/logo-no-bg.png"
HERO_TITLE = "Music Flamingo: Scaling Music Understanding in Audio Language Models"
HERO_SUBTITLE = "Upload a song and ask anything — including captions, lyrics, genre, key, chords, or complex questions. Music Flamingo gives detailed answers."
# Author list, affiliations and contact block injected raw into the hero HTML.
# NOTE(review): the mailto addresses read "[email protected]", which looks like
# email-obfuscation residue from a web scrape. Restore the real addresses.
HERO_AUTHORS = """
<div style="margin-top: 8px; margin-bottom: 4px; padding: 8px 20px; text-align: center; max-width: 900px; margin-inline: auto;">
<p style="font-size: 0.95rem; line-height: 1.6; margin-bottom: 10px;">
<strong>Authors:</strong> Sreyan Ghosh<sup>1,2*</sup>, Arushi Goel<sup>1*</sup>, Lasha Koroshinadze<sup>2**</sup>, Sang-gil Lee<sup>1</sup>, Zhifeng Kong<sup>1</sup>, Joao Felipe Santos<sup>1</sup>,<br>Ramani Duraiswami<sup>2</sup>, Dinesh Manocha<sup>2</sup>, Wei Ping<sup>1</sup>, Mohammad Shoeybi<sup>1</sup>, Bryan Catanzaro<sup>1</sup>
</p>
<p style="font-size: 0.88rem; opacity: 0.75; margin-bottom: 8px;">
<sup>1</sup>NVIDIA, CA, USA | <sup>2</sup>University of Maryland, College Park, USA
</p>
<p style="font-size: 0.82rem; opacity: 0.65; font-style: italic; margin-bottom: 6px;">
*Equally contributed and led the project. Names randomly ordered. **Significant technical contribution.
</p>
<p style="font-size: 0.85rem; opacity: 0.7; margin-bottom: 0;">
<strong>Correspondence:</strong> <a href="mailto:[email protected]" style="color: inherit; text-decoration: underline;">[email protected]</a>, <a href="mailto:[email protected]" style="color: inherit; text-decoration: underline;">[email protected]</a>
</p>
</div>
"""
# Row of shields.io badges (paper, demo page, code, stars, checkpoint, dataset)
# shown under the authors in the hero section.
HERO_BADGES = """
<div style="display: flex; justify-content: center; margin-top: 6px; align-items: center;">
<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 8px;">
<a href="https://arxiv.org/abs/2511.10289"><img src="https://img.shields.io/badge/arXiv-2511.10289-AD1C18" alt="arXiv"></a>
<a href="https://research.nvidia.com/labs/adlr/MF/"><img src="https://img.shields.io/badge/Demo page-228B22" alt="Demo page"></a>
<a href="https://github.com/NVIDIA/audio-flamingo"><img src='https://img.shields.io/badge/Github-Audio Flamingo 3-9C276A' alt="Github"></a>
<a href="https://github.com/NVIDIA/audio-flamingo/stargazers"><img src="https://img.shields.io/github/stars/NVIDIA/audio-flamingo.svg?style=social" alt="Stars"></a>
<a href="https://huggingface.co/nvidia/music-flamingo-hf">
<img src="https://img.shields.io/badge/🤗-Checkpoints-ED5A22.svg" alt="Checkpoints">
</a>
<a href="https://huggingface.co/datasets/nvidia/MF-Skills">
<img src="https://img.shields.io/badge/🤗-Dataset: MF--Skills-ED5A22.svg" alt="Dataset">
</a>
</div>
</div>
"""
# Custom stylesheet passed to gr.Blocks(css=...). The class names (.hero*,
# .tab-nav, .panel-row, .glass-card, .accent-button, .footer-note) match the
# elem_classes / HTML classes used when the UI is built.
# NOTE(review): two rulesets (.gradio-input/.gradio-output and the Audio .wrap)
# are intentionally empty. The `button[variant="secondary"]` selector may never
# match if Gradio renders the variant as a class rather than an attribute; verify.
APP_CSS = """
:root {
--font-sans: ui-sans-serif, system-ui, sans-serif,
"Apple Color Emoji", "Segoe UI Emoji",
"Segoe UI Symbol", "Noto Color Emoji";
--font-mono: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
"Liberation Mono", "Courier New", monospace;
--app-font: var(--font-sans);
}
body {
font-family: var(--app-font);
}
.gradio-container {
font-family: var(--app-font);
max-width: 80rem !important; /* Tailwind max-w-7xl (1280px) */
width: 100%;
margin-inline: auto; /* mx-auto */
padding-inline: 1rem; /* px-4 */
padding-bottom: 64px;
}
.hero {
display: flex;
flex-direction: column;
align-items: center;
gap: 12px;
padding: 24px 24px 32px;
text-align: center;
}
.hero__logo {
width: 112px;
height: 112px;
border-radius: 50%;
box-shadow: 0 12px 40px rgba(0, 0, 0, 0.15);
}
.hero__title {
font-size: clamp(2.4rem, 5.4vw, 3.2rem);
font-weight: 700;
line-height: 1.5;
letter-spacing: -0.01em;
background: linear-gradient(120deg, #ff6bd6 0%, #af66ff 35%, #4e9cff 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent;
}
.hero__subtitle {
max-width: none;
font-size: 1.08rem;
opacity: 0.8;
}
.tab-nav {
border-radius: 18px;
border: 1px solid var(--border-color-primary);
padding: 6px;
margin: 0 18px 12px;
}
.tab-nav button {
border-radius: 12px !important;
}
.tab-nav button[aria-selected="true"] {
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
.panel-row {
gap: 24px !important;
align-items: stretch;
flex-wrap: wrap;
}
.glass-card {
border: 1px solid var(--border-color-primary);
border-radius: 26px;
padding: 28px;
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1);
display: flex;
flex-direction: column;
gap: 18px;
}
/* Glass card content styling */
.glass-card .gradio-input,
.glass-card .gradio-output {
/* Let Gradio handle default styling */
}
.glass-card label {
font-weight: 600;
letter-spacing: 0.01em;
}
/* Text input styling */
.glass-card textarea {
border-radius: 18px !important;
}
.glass-card textarea:focus {
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.25) !important;
}
/* Audio component fix */
.glass-card [data-testid="Audio"] .wrap {
/* Let Gradio handle default styling */
}
/* YouTube embed styling */
.glass-card [data-testid="HTML"] {
margin: 12px 0;
}
/* Load button styling */
.glass-card button[variant="secondary"] {
border-radius: 12px !important;
font-weight: 500 !important;
}
/* Action button styling */
.accent-button {
background: linear-gradient(120deg, #ff6bd6 0%, #8f5bff 45%, #4e9cff 100%) !important;
border-radius: 14px !important;
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15);
color: #ffffff !important;
font-weight: 600 !important;
letter-spacing: 0.01em;
padding: 0.85rem 1.5rem !important;
transition: transform 0.18s ease, box-shadow 0.18s ease;
}
.accent-button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
}
.accent-button:active {
transform: translateY(0px);
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.15);
}
.footer-note {
text-align: center;
opacity: 0.6;
margin-top: 28px;
font-size: 0.95rem;
}
"""
# Clickable examples as [youtube_url, prompt] rows. The column order must match
# gr.Examples(inputs=[youtube_url, prompt_in]).
EXAMPLE_YOUTUBE_PROMPTS = [
    ["https://youtu.be/ko70cExuzZM", "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates."],
    ["https://youtu.be/iywaBOMvYLI", "Generate a structured lyric sheet from the input music."],
    ["https://youtu.be/_mTRvJ9fugM", "Which line directly precedes the chorus?"],
]
# Load the processor and weights once at import time. The weights are placed
# on CPU in float32 here; infer() moves the model to CUDA per request when
# torch reports it available.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID, device_map="cpu", torch_dtype=torch.float32).eval()
# Maps YouTube URL -> (local mp3 path, video title) for download_youtube_audio().
_youtube_cache = {}
def clear_youtube_cache():
    """Clear the YouTube audio cache and delete cached files.

    Each cached file lives in its own ``tempfile.mkdtemp()`` directory created
    by download_youtube_audio(), so that whole directory is removed. Deletion
    is best-effort: entries whose file is already gone are skipped, and
    removal errors are ignored.
    """
    for file_path, _title in _youtube_cache.values():
        if os.path.exists(file_path):
            shutil.rmtree(os.path.dirname(file_path), ignore_errors=True)
    _youtube_cache.clear()
def truncate_title(title, max_length=50):
    """Shorten ``title`` for display so it does not wrap in the UI.

    Titles of length <= ``max_length`` come back unchanged. Longer ones keep
    their first ``max_length - 3`` characters followed by "...", which gives
    exactly ``max_length`` characters when ``max_length >= 3``.
    """
    too_long = len(title) > max_length
    return f"{title[:max_length - 3]}..." if too_long else title
def extract_youtube_id(url):
    """Extract the 11-character YouTube video ID from ``url``, or return None.

    Recognised forms. Matching uses unanchored ``re.search``, so the scheme
    and host prefixes such as ``www.``, ``m.`` or ``music.`` are irrelevant:

    * ``youtube.com/watch?v=ID``, where ``v`` may follow other query params
    * ``youtu.be/ID``
    * ``youtube.com/{embed,v,shorts,live}/ID`` and the same paths on
      ``youtube-nocookie.com``

    Returns None for empty/None input or when no pattern matches.
    """
    if not url:
        return None
    patterns = (
        # `??` tries "v= first" before skipping over other params, so the
        # first v= wins, as in the original watch?v= pattern.
        r"youtube\.com/watch\?(?:[^#\s]*?&)??v=([^&=%\?]{11})",
        r"youtu\.be/([^&=%\?]{11})",
        r"youtube(?:-nocookie)?\.com/(?:embed|v|shorts|live)/([^&=%\?]{11})",
    )
    for pattern in patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None
def generate_youtube_embed(url, title="YouTube Video"):
    """Return responsive iframe HTML embedding the YouTube video at ``url``.

    Returns "" when no video ID can be extracted. Both the ID and ``title``
    are HTML-escaped before interpolation into attributes. The ID comes from
    user-typed text, and the ID regex's ``[^&=%\\?]`` class admits characters
    such as ``"`` and ``<`` that could otherwise break out of the attribute.
    """
    video_id = extract_youtube_id(url)
    if not video_id:
        return ""
    safe_id = html.escape(video_id, quote=True)
    safe_title = html.escape(str(title), quote=True)
    # 56.25% bottom padding keeps a 16:9 box that scales with the width.
    embed_html = f"""
<div style="position: relative; width: 100%; height: 0; padding-bottom: 56.25%; border-radius: 12px; overflow: hidden; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3);">
    <iframe
        style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
        src="https://www.youtube.com/embed/{safe_id}"
        title="{safe_title}"
        frameborder="0"
        allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
        referrerpolicy="strict-origin-when-cross-origin"
        allowfullscreen>
    </iframe>
</div>
"""
    return embed_html
def download_youtube_audio(url, force_reload=False):
    """Download audio from a YouTube URL as mp3 and return ``(path, message)``.

    Results are memoised in ``_youtube_cache`` (url -> (path, title)). A cached
    file that still exists is reused unless ``force_reload`` is True, in which
    case the old temp directory is deleted and the audio is fetched again.
    Downloads go through ``PROXY_URL`` when the SSH tunnel is up.

    Returns:
        ``(file_path, "✅ ...")`` on success, or ``(None, "❌ ...")`` on
        invalid input or any failure. Exceptions are reported in the message,
        never raised.
    """
    temp_dir = None
    try:
        youtube_regex = re.compile(r"(https?://)?(www\.)?(youtube|youtu|youtube-nocookie)\.(com|be)/" r"(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})")
        if not url or not youtube_regex.match(url):
            return None, "❌ Invalid YouTube URL format"
        if not force_reload and url in _youtube_cache:
            cached_path, cached_title = _youtube_cache[url]
            if os.path.exists(cached_path):
                return cached_path, f"✅ Using cached: {truncate_title(cached_title)}"
        if force_reload and url in _youtube_cache:
            old_path, _ = _youtube_cache.pop(url)
            # Each cached file sits in its own mkdtemp() dir; drop it wholesale.
            shutil.rmtree(os.path.dirname(old_path), ignore_errors=True)
        temp_dir = tempfile.mkdtemp()
        ydl_opts = {
            "format": "bestaudio/best",
            "outtmpl": os.path.join(temp_dir, "%(title)s.%(ext)s"),
            "postprocessors": [
                {
                    "key": "FFmpegExtractAudio",
                    "preferredcodec": "mp3",
                    "preferredquality": "128",
                }
            ],
            "noplaylist": True,
        }
        if PROXY_URL:
            ydl_opts["proxy"] = PROXY_URL
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            # One extraction both downloads and returns metadata, instead of
            # extracting once for the title and again inside download().
            info = ydl.extract_info(url, download=True)
        title = (info or {}).get("title") or "Unknown"
        for name in os.listdir(temp_dir):
            if name.endswith(".mp3"):
                file_path = os.path.join(temp_dir, name)
                _youtube_cache[url] = (file_path, title)
                return file_path, f"✅ Downloaded: {truncate_title(title)}"
        shutil.rmtree(temp_dir, ignore_errors=True)
        return None, "❌ Failed to download audio file"
    except Exception as e:
        # Don't leak the per-download temp directory on failure.
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None, f"❌ Download error: {str(e)}"
def infer(audio_path, youtube_url, prompt_text):
    """Answer ``prompt_text`` about a piece of music with Music Flamingo.

    Audio source precedence: ``audio_path`` (uploaded, recorded or loaded
    audio) wins. Otherwise ``youtube_url`` is downloaded via
    download_youtube_audio().

    Returns:
        One string for the output textbox: a status line, a blank line, then
        the model's decoded answer. On invalid input or any exception it
        returns a message starting with "❌" instead.
    """
    # NOTE(review): `spaces` is imported but infer() is not decorated with
    # @spaces.GPU. On a ZeroGPU Space a GPU is normally only attached inside
    # such a function. Confirm whether the decorator is intentionally absent.
    try:
        # Resolve the input first so invalid requests never touch the model.
        url = (youtube_url or "").strip()
        if audio_path:
            final_audio_path = audio_path
            status_message = "✅ Using audio file"
        elif url:
            final_audio_path, status_message = download_youtube_audio(url)
            if not final_audio_path:
                return status_message
        else:
            return "❌ Please either upload an audio file or provide a YouTube URL."
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        conversations = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text or ""},
                        {"type": "audio", "path": final_audio_path},
                    ],
                }
            ]
        ]
        # NOTE: If `conversations` includes audio, apply_chat_template() decodes via load_audio()
        # to MONO float32 at 16 kHz by default. We omit `sampling_rate`, so the 16k default is used.
        # Processor assumes mono 1-D audio; stereo would require code changes. No audio ⇒ no effect here.
        batch = processor.apply_chat_template(
            conversations,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
        ).to(model.device)
        gen_ids = model.generate(**batch, max_new_tokens=4096, repetition_penalty=1.2)
        # generate() returns prompt + continuation; keep only the new tokens.
        inp_len = batch["input_ids"].shape[1]
        new_tokens = gen_ids[:, inp_len:]
        texts = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        result = texts[0] if texts else ""
        return f"{status_message}\n\n{result}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def load_youtube_audio(youtube_url):
    """Fetch YouTube audio for the Audio component and build the video embed.

    Always re-downloads (``force_reload=True``), replacing any cached copy.

    Returns:
        ``(audio_path_or_None, status_message, embed_html)``, matching the
        ``[audio_in, status_text, youtube_embed]`` outputs of the Load button.
        Empty, whitespace-only or None input yields an error message and no embed.
    """
    url = (youtube_url or "").strip()
    if not url:
        return None, "❌ Please enter a YouTube URL", ""
    embed_html = generate_youtube_embed(url)
    audio_path, message = download_youtube_audio(url, force_reload=True)
    # Normalise any falsy path to None so the Audio component is cleared.
    return audio_path or None, message, embed_html
with gr.Blocks(css=APP_CSS, theme=gr.themes.Soft(primary_hue="purple", secondary_hue="fuchsia")) as demo:
    # Page header: logo, title, subtitle, authors and badge row.
    gr.HTML(
        f"""
<div class="hero">
    <img src="{HERO_IMAGE_URL}" alt="Music Flamingo logo" class="hero__logo" />
    <h1 class="hero__title">{HERO_TITLE}</h1>
    <p class="hero__subtitle">{HERO_SUBTITLE}</p>
    {HERO_AUTHORS}
    {HERO_BADGES}
</div>
"""
    )
    # NOTE(review): gr.Tabs has no gr.Tab children here; it appears to be used
    # only as a styled container (.tab-nav). Confirm this renders as intended.
    with gr.Tabs(elem_classes="tab-nav"):
        with gr.Row(elem_classes="panel-row"):
            # Left card: audio source (file/mic or YouTube) and the prompt.
            with gr.Column(elem_classes=["glass-card"]):
                gr.Markdown("### 🎵 Audio Input")
                audio_in = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Upload Audio File",
                    show_label=True,
                )
                gr.Markdown("**OR**")
                youtube_url = gr.Textbox(label="YouTube URL", placeholder="https://www.youtube.com/watch?v=...", info="Paste any YouTube URL - we'll extract high-quality audio automatically")
                load_btn = gr.Button("🔄 Load Audio", variant="secondary", size="sm")
                status_text = gr.Textbox(label="Status", interactive=False, visible=False)
                youtube_embed = gr.HTML(label="Video Preview", visible=False)
                prompt_in = gr.Textbox(
                    label="Prompt",
                    value="Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
                    placeholder="Ask a question about the audio…",
                    lines=6,
                )
                gr.Examples(
                    examples=EXAMPLE_YOUTUBE_PROMPTS,
                    inputs=[youtube_url, prompt_in],
                    label="🎵 Example Prompts",
                )
                btn = gr.Button("Generate Answer", elem_classes="accent-button")
            # Right card: model output.
            with gr.Column(elem_classes=["glass-card"]):
                out = gr.Textbox(
                    label="Model Response",
                    lines=25,
                    placeholder="Model answers will appear here with audio-informed insights…",
                )
    # Load flow: clear the audio and show "Loading..." in one update (listing
    # status_text twice in outputs made the two values compete), then
    # download and embed, then reveal the video preview.
    load_btn.click(
        lambda: [None, gr.update(value="🔄 Loading audio...", visible=True)],
        outputs=[audio_in, status_text],
    ).then(
        fn=load_youtube_audio, inputs=[youtube_url], outputs=[audio_in, status_text, youtube_embed]
    ).then(lambda: gr.update(visible=True), outputs=[youtube_embed])
    btn.click(fn=infer, inputs=[audio_in, youtube_url, prompt_in], outputs=out)
    gr.HTML(
        """
<div class="footer-note">
© 2025 NVIDIA | Powered by 🤗 Transformers + Gradio
</div>
"""
    )
# Script entry point: launch the Gradio app when run directly.
# NOTE(review): share=True requests a public Gradio share link; on a hosted
# Space this presumably has no effect. Confirm it is intended for local runs.
if __name__ == "__main__":
    demo.launch(share=True)