Spaces:
Sleeping
Sleeping
| import os | |
| import spaces | |
| import pickle | |
| import subprocess | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from dataclasses import asdict | |
| from transformers import T5Tokenizer | |
| from huggingface_hub import hf_hub_download | |
| from time import time_ns | |
| from uuid import uuid4 | |
| from transformer_model import Transformer | |
| from pyharp.core import ModelCard, build_endpoint | |
| from pyharp.labels import LabelList | |
# Model/artifacts pulled from the Hugging Face Hub when this module is imported.
REPO_ID = "amaai-lab/text2midi"

# hf_hub_download yields a local path for each file of REPO_ID; the downloads
# happen in order: checkpoint, REMI vocabulary pickle, then the soundfont.
# SOUNDFONT_PATH is only needed by save_wav (optional WAV preview).
MODEL_PATH, TOKENIZER_PATH, SOUNDFONT_PATH = (
    hf_hub_download(repo_id=REPO_ID, filename=artifact)
    for artifact in ("pytorch_model.bin", "vocab_remi.pkl", "soundfont.sf2")
)
# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render a MIDI file to WAV with the FluidSynth CLI and return the WAV path.

    The input is normalised to ``<dir>/<stem>.mid`` (so ``song.midi`` or
    ``song`` both render ``song.mid``), and the output is written next to it
    as ``<dir>/<stem>.wav`` using ``SOUNDFONT_PATH`` with ``-r 16000``.

    Rendering is best-effort: a non-zero FluidSynth exit status or a missing
    ``fluidsynth`` executable is not raised, so callers should check that the
    returned file exists before using it.

    Args:
        midi_path: Path of the MIDI file to render.

    Returns:
        Path where the WAV file is (or would have been) written.
    """
    import sys  # local import keeps this block self-contained

    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")
    # Pass argv as a list with no shell: paths containing spaces or shell
    # metacharacters are handled safely. Argument order matches the old
    # command string; stdout is discarded like the old `> /dev/null`.
    cmd = [
        "fluidsynth",
        "-r", "16000",
        SOUNDFONT_PATH,
        "-g", "1.0",
        "--quiet",
        "--no-shell",
        midi_filepath,
        "-T", "wav",
        "-F", wav_filepath,
    ]
    try:
        subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL)
    except FileNotFoundError:
        # Without a shell, a missing binary raises instead of exiting 127.
        # Keep the previous best-effort behaviour, but report why no WAV appeared.
        print("save_wav: 'fluidsynth' executable not found on PATH", file=sys.stderr)
    return wav_filepath
| # Helpers | |
| def _unique_path(ext: str) -> str: | |
| """Create a unique file path in /tmp to avoid naming collisions.""" | |
| return os.path.join("/tmp", f"t2m_{time_ns()}_{uuid4().hex[:8]}{ext}") | |
# Core Text -> MIDI
import functools

# NOTE(review): `spaces` is imported at the top of this file but no
# `@spaces.GPU` decorator is applied; if this Space runs on ZeroGPU,
# generate_midi probably needs one to be given a GPU — confirm.


@functools.lru_cache(maxsize=None)
def _load_remi_tokenizer():
    """Load the pickled REMI tokenizer from TOKENIZER_PATH once and reuse it."""
    # SECURITY: pickle.load can execute arbitrary code. This is acceptable only
    # because the file is downloaded from the fixed REPO_ID; never point
    # TOKENIZER_PATH at user-supplied data.
    with open(TOKENIZER_PATH, "rb") as f:
        return pickle.load(f)


@functools.lru_cache(maxsize=None)
def _load_text_tokenizer():
    """Load the FLAN-T5 tokenizer used to encode prompts once and reuse it."""
    return T5Tokenizer.from_pretrained("google/flan-t5-base")


@functools.lru_cache(maxsize=2)  # at most one entry each for "cpu" and "cuda"
def _load_model(device: str):
    """Build the text2midi Transformer on *device*, load its weights, and cache it."""
    vocab_size = len(_load_remi_tokenizer())
    model = Transformer(
        vocab_size,  # vocab size (size of the REMI tokenizer)
        768,  # d_model
        8,  # nhead
        2048,  # dim_feedforward
        18,  # nlayers
        1024,  # max_seq_len
        False,  # use_rotary
        8,  # rotary_dim
        device=device,
    )
    # torch.load also unpickles; the same trust caveat as the tokenizer applies.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    # NOTE(review): reusing one model across requests assumes model.generate
    # keeps no per-call state on the module — confirm in transformer_model.
    return model


def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    """Generate a MIDI file from a text prompt and return its path.

    The prompt is encoded with the FLAN-T5 tokenizer and passed to the
    text2midi Transformer. The generated token sequence is decoded by the REMI
    tokenizer and written to a fresh ``.mid`` path from ``_unique_path``.

    The REMI tokenizer, T5 tokenizer and model weights are loaded on first use
    and then cached, so only the first request pays the loading cost instead of
    every request.

    Args:
        prompt: Text description of the music; must contain non-whitespace.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_len: Generation length limit forwarded as ``max_len`` to
            ``model.generate``.

    Returns:
        Path of the written MIDI file.

    Raises:
        ValueError: If *prompt* is empty or only whitespace.
    """
    if not prompt or not prompt.strip():
        raise ValueError("Prompt must be a non-empty string.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    r_tokenizer = _load_remi_tokenizer()
    model = _load_model(device)
    tokenizer = _load_text_tokenizer()

    # A single string with return_tensors="pt" already yields (1, seq_len)
    # tensors, so no extra padding step is needed before moving to the device.
    # truncation=True clips over-long prompts to the tokenizer's maximum length.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)

    # Only one prompt is batched, so take the first (and only) sequence.
    generated_midi = r_tokenizer.decode(output[0].tolist())
    midi_path = _unique_path(".mid")
    generated_midi.dump_midi(midi_path)
    return midi_path
# HARP process function
# Return JSON first, MIDI second
def process_fn(prompt: str, temperature: float, max_length: int):
    """HARP entry point: run text->MIDI and return ``(labels_json, midi_path)``.

    The output order (JSON first, MIDI file second) matches the
    ``output_components`` list passed to ``build_endpoint``.

    Args:
        prompt: Text prompt from the UI.
        temperature: Sampling temperature (converted to ``float``).
        max_length: Generation length limit (converted to ``int``).

    Returns:
        ``(asdict(LabelList()), midi_path)`` on success, or
        ``({"message": "Error: ..."}, None)`` if anything raises.
    """
    import logging  # local import keeps this block self-contained

    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()  # add MidiLabel entries here if you have metadata
        return asdict(labels), midi_path
    except Exception as e:
        # Top-level boundary: the UI only gets the short message below, so
        # record the full traceback in the server logs for debugging.
        logging.getLogger(__name__).exception("text2midi generation failed")
        return {"message": f"Error: {e}"}, None
# Metadata describing this model; handed to pyharp's build_endpoint below.
model_card = ModelCard(
    name="Text2MIDI (HARP)",
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    tags=["text-to-music", "midi", "generation"],
    description="Generate MIDI from a text prompt using a transformer decoder conditioned on T5 embeddings.",
)
# Gradio UI wired up as a HARP endpoint.
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")

    # ---- Inputs: order matches process_fn(prompt, temperature, max_length) ----
    prompt_in = gr.Textbox(label="Prompt").harp_required(True)
    temperature_in = gr.Slider(
        label="Temperature",
        minimum=0.8,
        maximum=1.1,
        step=0.1,
        value=0.9,
        interactive=True,
    )
    maxlen_in = gr.Slider(
        label="Max Length",
        minimum=500,
        maximum=1500,
        step=100,
        value=500,
    )

    # ---- Outputs: JSON labels first, MIDI file second (HARP ordering) ----
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(
        label="Generated MIDI",
        type="filepath",
        file_types=[".mid", ".midi"],
    )

    # Register process_fn as the HARP endpoint; the return value is not used.
    _ = build_endpoint(
        model_card=model_card,
        process_fn=process_fn,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],
    )
# Launch App
if __name__ == "__main__":
    # Start the (blocking, debug=True) server only when this file is run as a
    # script — presumably how the Space starts it (`python app.py`); confirm.
    # Importing the module for tests or tooling then no longer launches a server.
    # NOTE(review): share=True is likely ignored when hosted on HF Spaces — confirm.
    demo.launch(share=True, show_error=True, debug=True)