# Hugging Face Space: Galgame Orpheus 3B TTS demo (runs on ZeroGPU)
import spaces  # enables ZeroGPU allocation via the @spaces.GPU decorator below
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import librosa
import gradio as gr
from snac import SNAC
import re
orpheus_model_id = 'NandemoGHS/Galgame-Orpheus-3B'
tokenizer = AutoTokenizer.from_pretrained(orpheus_model_id)
model = AutoModelForCausalLM.from_pretrained(
    orpheus_model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.eval().cuda()

snac_model_id = 'hubertsiuzdak/snac_24khz'
snac_model = SNAC.from_pretrained(snac_model_id)
snac_model.eval().cuda()

whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)
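# Special token IDs from the Orpheus vocabulary. The prompt is framed as a
# dialogue: a "human" turn carrying text and an "AI" turn carrying speech
# tokens; generation is prompted up to Start-of-Speech and the model
# continues with speech tokens.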
SOT_ID = 128000  # Start of Text (not used)
EOT_ID = 128009  # End of Text
SOS_ID = 128257  # Start of Speech
EOS_ID = 128258  # End of Speech
SOH_ID = 128259  # Start of Human
EOH_ID = 128260  # End of Human
SOA_ID = 128261  # Start of AI
EOA_ID = 128262  # End of AI
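# The full prompt assembled in infer() below therefore looks like:
#   [SOH] <ref text> [EOT] [EOH] [SOA] [SOS] <ref speech tokens> [EOS] [EOA]
#   [SOH] <target text> [EOT] [EOH] [SOA] [SOS]   <- model continues from here
# (the first line is omitted when no reference audio is supplied)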
# Regex substitutions applied before width normalization: strip layout
# artifacts and map punctuation variants to a canonical form.
REPLACE_MAP: dict[str, str] = {
    r"\t": "",
    r"\[n\]": "",
    r" ": "",
    r"γ": "",
    r"[;βΌββγγβͺβ«ξΎβ β‘β’β£β€β₯]": "",
    r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "",
    r"[\uff5e\u301C]": "γΌ",
    r"οΌ": "?",
    r"οΌ": "!",
    r"[ββ―γ]": "β",
    r"β₯": "β‘",
}
FULLWIDTH_ALPHA_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(
            list(range(0xFF21, 0xFF3B)) + list(range(0xFF41, 0xFF5B)),
            list(range(0x41, 0x5B)) + list(range(0x61, 0x7B)),
        )
    }
)
# NOTE: this is a plain codepoint shift; the halfwidth-katakana and katakana
# Unicode blocks are not ordered identically, so this mapping is only
# approximate.
HALFWIDTH_KATAKANA_TO_FULLWIDTH = str.maketrans(
    {
        chr(half): chr(full)
        for half, full in zip(range(0xFF61, 0xFF9F), range(0x30A1, 0x30FB))
    }
)
FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A))
    }
)
# Matches any character outside the supported set (kana, kanji, ASCII
# alphanumerics, and a few punctuation marks); presumably intended for input
# validation, but not applied in this demo.
INVALID_PATTERN = re.compile(
    r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
    r"\u0041-\u005A\u0061-\u007A"
    r"\u0030-\u0039"
    r"γγ!?β¦βͺβ‘β]"
)
def normalize(text: str) -> str:
    for pattern, replacement in REPLACE_MAP.items():
        text = re.sub(pattern, replacement, text)
    text = text.translate(FULLWIDTH_ALPHA_TO_HALFWIDTH)
    text = text.translate(FULLWIDTH_DIGITS_TO_HALFWIDTH)
    text = text.translate(HALFWIDTH_KATAKANA_TO_FULLWIDTH)
    # Collapse runs of three or more ellipsis characters down to two.
    text = re.sub(r"β¦{3,}", "β¦β¦", text)
    return text
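# For example (illustrative input, assuming U+301C for the wave dash):
#   normalize("οΌ΄οΌ΄οΌ³γ§γγοΌ") -> "TTSγ§γγΌ!"
# fullwidth letters become ASCII, the wave dash becomes the long-vowel mark
# "γΌ", and fullwidth "οΌ" becomes "!".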
def tokenize_audio(waveform):
    # SNAC expects input of shape (batch, channels, samples).
    waveform = waveform.unsqueeze(0)
    with torch.inference_mode():
        codes = snac_model.encode(waveform)
    # Flatten the three SNAC codebook levels (at temporal rates 1:2:4) into
    # 7 tokens per frame, each position shifted into its own 4096-wide slice
    # of the LLM vocabulary starting at ID 128266.
    all_codes = []
    for i in range(codes[0].shape[1]):
        all_codes.append(codes[0][0][i].item() + 128266)
        all_codes.append(codes[1][0][2 * i].item() + 128266 + 4096)
        all_codes.append(codes[2][0][4 * i].item() + 128266 + (2 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 1].item() + 128266 + (3 * 4096))
        all_codes.append(codes[1][0][(2 * i) + 1].item() + 128266 + (4 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 2].item() + 128266 + (5 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 3].item() + 128266 + (6 * 4096))
    return all_codes
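# redistribute_codes() is the inverse of the interleaving above: it strips
# the per-position 4096 offsets and regroups the flat token stream into the
# three codebook tensors expected by snac_model.decode().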
def redistribute_codes(code_list):
    # Drop any trailing partial frame (frames are 7 tokens long).
    new_length = (len(code_list) // 7) * 7
    if new_length == 0:
        return None
    code_list = code_list[:new_length]
    layer_1 = []
    layer_2 = []
    layer_3 = []
    for i in range(len(code_list) // 7):
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
        layer_3.append(code_list[7 * i + 6] - (6 * 4096))
    codes = [
        torch.tensor(layer_1).unsqueeze(0),
        torch.tensor(layer_2).unsqueeze(0),
        torch.tensor(layer_3).unsqueeze(0),
    ]
    codes = [c.cuda() for c in codes]
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat
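# A minimal round-trip sanity check (a sketch, assuming a CUDA device is
# available; not executed by the app):
#   wav = torch.zeros(1, 24000).cuda()  # one second of silence at 24 kHz
#   tokens = tokenize_audio(wav)
#   assert len(tokens) % 7 == 0
#   recon = redistribute_codes([t - 128266 for t in tokens])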
@spaces.GPU
def infer(sample_audio_path, target_text, temperature, top_p, repetition_penalty, progress=gr.Progress()):
    if not target_text or not target_text.strip():
        gr.Warning("Please input text to generate audio.")
        return None
    if len(target_text) > 300:
        gr.Warning("Text is too long. Please keep it under 300 characters.")
        target_text = target_text[:300]
    target_text = normalize(target_text)
    with torch.no_grad():
        if sample_audio_path:
            progress(0, 'Loading and trimming audio...')
            audio_array, sample_rate = librosa.load(sample_audio_path, sr=24000)
            if len(audio_array) / sample_rate > 15:
                gr.Warning("Trimming audio to the first 15 seconds.")
                num_samples_to_keep = int(sample_rate * 15)
                audio_array = audio_array[:num_samples_to_keep]
            prompt_wav = torch.from_numpy(audio_array).unsqueeze(0)
            prompt_wav = prompt_wav.to(dtype=torch.float32)
            progress(0.2, 'Transcribing reference audio...')
            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
            progress(0.4, 'Transcribed! Encoding audio...')
            # Encode the reference audio and wrap it, together with its
            # transcript, as a completed human/AI exchange.
            snac_dev = next(snac_model.parameters()).device
            voice_tokens = tokenize_audio(prompt_wav.to(device=snac_dev))
            ref_text_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0].tolist()
            prompt_ids = (
                [SOH_ID]
                + ref_text_ids
                + [EOT_ID]
                + [EOH_ID]
                + [SOA_ID]
                + [SOS_ID]
                + voice_tokens
                + [EOS_ID]
                + [EOA_ID]
            )
        else:
            prompt_ids = []
        progress(0.6, "Generating audio...")
        target_text_ids = tokenizer(target_text, return_tensors="pt").input_ids[0].tolist()
        # Append the target text as a new human turn and open the AI speech
        # turn; the model continues from SOS with speech tokens.
        prompt_ids.extend([SOH_ID])
        prompt_ids.extend(target_text_ids)
        prompt_ids.extend([EOT_ID])
        prompt_ids.extend([EOH_ID])
        prompt_ids.extend([SOA_ID])
        prompt_ids.extend([SOS_ID])
        input_ids = torch.tensor([prompt_ids], dtype=torch.int64).cuda()
        # Generate the speech tokens autoregressively.
        outputs = model.generate(
            input_ids,
            max_new_tokens=2048,
            eos_token_id=EOS_ID,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        generated_ids = outputs[0].tolist()
        progress(0.8, "Decoding generated audio...")
        try:
            # Take everything after the last Start-of-Speech token.
            last_sos_idx = len(generated_ids) - 1 - generated_ids[::-1].index(SOS_ID)
            speech_tokens = generated_ids[last_sos_idx + 1:]
        except ValueError:
            raise gr.Error("Audio generation failed: no start-of-speech token found.")
        if EOS_ID in speech_tokens:
            speech_tokens = speech_tokens[:speech_tokens.index(EOS_ID)]
        if not speech_tokens:
            raise gr.Error("Audio generation failed: no speech tokens were generated.")
        # Keep only audio tokens and remove the base vocabulary offset.
        base_offset = 128266
        adjusted_tokens = [token - base_offset for token in speech_tokens if token >= base_offset]
        gen_wav_tensor = redistribute_codes(adjusted_tokens)
        if gen_wav_tensor is None:
            raise gr.Error("Audio decoding failed.")
        gen_wav = gen_wav_tensor.cpu().squeeze()
        progress(1, 'Synthesized!')
        return (24000, gen_wav.numpy())
with gr.Blocks() as app_tts:
    gr.Markdown("# Galgame Orpheus 3B")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    with gr.Row():
        # A temperature of exactly 0 is invalid when do_sample=True, so the
        # slider starts at 0.1.
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.05, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Top-p")
        repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=1.5, value=1.1, step=0.05, label="Repetition Penalty")
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=[audio_output],
    )
with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

* [canopyai](https://github.com/canopyai) for the original [repo](https://github.com/canopyai/Orpheus-TTS)
* [mrfakename](https://huggingface.co/mrfakename) for the [Gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [SunderAli17](https://huggingface.co/SunderAli17) for the [Gradio demo code](https://huggingface.co/spaces/SunderAli17/llasa-3b-tts)
""")
with gr.Blocks() as app:
    gr.Markdown(
        """
# Galgame Orpheus 3B

This is a web UI for the Galgame Orpheus 3B TTS model. You can check out the model [here](https://huggingface.co/NandemoGHS/Galgame-Orpheus-3B).

The model is fine-tuned on Japanese audio data.

If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt.
"""
    )
    gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])

app.launch()