import os
import spaces
import pickle
import subprocess
import torch
import gradio as gr
from dataclasses import asdict
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download
from time import time_ns
from uuid import uuid4
from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# Model/artifacts from HF Hub
REPO_ID = "amaai-lab/text2midi"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
TOKENIZER_PATH = hf_hub_download(repo_id=REPO_ID, filename="vocab_remi.pkl")
# Optional: soundfont used by the save_wav helper below for WAV previews.
SOUNDFONT_PATH = hf_hub_download(repo_id=REPO_ID, filename="soundfont.sf2")
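# NOTE (assumption): rendering WAV below shells out to the FluidSynth CLI, which must be
# installed on the host (e.g. via packages.txt on Spaces); the soundfont alone is not enough.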
# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render a MIDI file to a WAV file alongside it using the FluidSynth CLI."""
    wav_filepath = os.path.splitext(midi_path)[0] + ".wav"
    # Pass an argument list (no shell) so paths containing spaces are handled safely.
    cmd = [
        "fluidsynth", "-r", "16000", SOUNDFONT_PATH, "-g", "1.0",
        "--quiet", "--no-shell", midi_path, "-T", "wav", "-F", wav_filepath,
    ]
    subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL)
    return wav_filepath
# Helpers
def _unique_path(ext: str) -> str:
"""Create a unique file path in /tmp to avoid naming collisions."""
return os.path.join("/tmp", f"t2m_{time_ns()}_{uuid4().hex[:8]}{ext}")
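# e.g. _unique_path(".mid") -> "/tmp/t2m_1712345678901234567_1a2b3c4d.mid" (illustrative values)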
# Core Text -> MIDI
def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
device = "cuda" if torch.cuda.is_available() else "cpu"
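    # NOTE: the vocab, model weights, and T5 tokenizer are re-loaded on each call so this
    # function stays self-contained; caching them at module scope would cut per-request latency.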
# Load REMI vocab/tokenizer (pickle dict used by the provided model)
with open(TOKENIZER_PATH, "rb") as f:
r_tokenizer = pickle.load(f)
vocab_size = len(r_tokenizer)
model = Transformer(
vocab_size, # vocab size
768, # d_model
8, # nhead
2048, # dim_feedforward
18, # nlayers
1024, # max_seq_len
False, # use_rotary
8, # rotary_dim
device=device # device
)
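    # These positional hyperparameters must match the released checkpoint's training
    # configuration, or load_state_dict below will fail with shape mismatches.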
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    # For a single prompt the tokenizer already returns padded 2-D tensors,
    # so an extra pad_sequence pass is unnecessary.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
with torch.no_grad():
output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
output_list = output[0].tolist()
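    # Assumption: the REMI tokenizer's decode returns a symbolic-music object
    # (e.g. a symusic Score) that exposes dump_midi for serialization.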
generated_midi = r_tokenizer.decode(output_list)
midi_path = _unique_path(".mid")
generated_midi.dump_midi(midi_path)
return midi_path
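# Quick local smoke test (hypothetical prompt; runs on CPU when no GPU is present):
#   midi_file = generate_midi("A gentle lo-fi piano loop", temperature=0.9, max_len=500)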
# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
try:
midi_path = generate_midi(prompt, float(temperature), int(max_length))
labels = LabelList() # add MidiLabel entries here if you have metadata
return asdict(labels), midi_path
except Exception as e:
# On error: return JSON with error message, and no file
return {"message": f"Error: {e}"}, None
# HARP Model Card
model_card = ModelCard(
name="Text2MIDI (HARP)",
description="Generate MIDI from a text prompt using a transformer decoder conditioned on T5 embeddings.",
author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
tags=["text-to-music", "midi", "generation"]
)
# Gradio + HARP UI
with gr.Blocks() as demo:
gr.Markdown("## 🎶 text2midi")
# Inputs
prompt_in = gr.Textbox(label="Prompt").harp_required(True)
temperature_in = gr.Slider(minimum=0.8, maximum=1.1, value=0.9, step=0.1, label="Temperature", interactive=True)
maxlen_in = gr.Slider(minimum=500, maximum=1500, step=100, value=500, label="Max Length")
# Outputs (JSON FIRST for HARP, then MIDI)
labels_out = gr.JSON(label="Labels / Metadata")
midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")
# Build HARP endpoint
_ = build_endpoint(
model_card=model_card,
input_components=[prompt_in, temperature_in, maxlen_in],
output_components=[labels_out, midi_out], # JSON first
process_fn=process_fn
)
# Launch App
demo.launch(share=True, show_error=True, debug=True)