Spaces: Running on Zero
| import argparse | |
| import os | |
| import sys | |
| import torch | |
| from scipy.io.wavfile import write | |
| sys.path.append(os.path.dirname(__file__)) | |
| from speech_utils import load_U2S_model, load_condition_centroid | |
| from speech_tokenization.UVITS.text import text_to_sequence | |
| from my_synthesis.my_synthesis_for_speech_unit_sequence_recombination import get_U2S_config_checkpoint_file | |
def _resolve_condition_file(path):
    """Return *path*, falling back to the same relative path under this script's directory.

    The module appends its own directory to ``sys.path`` so sibling packages
    import regardless of the working directory; this gives the data-file lookup
    the same tolerance. Absolute paths and paths that already exist relative to
    the current working directory are returned unchanged (original behaviour).
    """
    if os.path.isabs(path) or os.path.exists(path):
        return path
    candidate = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
    # If the fallback is missing as well, hand back what the user asked for so
    # the loader's error names the path that was actually requested.
    return candidate if os.path.exists(candidate) else path


def main(argv=None):
    """Synthesize a .wav file from a speech-unit string and a style condition.

    Steps: parse CLI arguments, load the condition-to-style-centroid table,
    load the U2S model for a fixed unit type, convert the unit string with
    ``text_to_sequence`` using the model's configured text cleaners, run
    ``synthesis_from_content_unit_style_embedding`` under ``torch.no_grad()``,
    and write the waveform at the model's configured sampling rate.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse read
            ``sys.argv[1:]``, which matches the original behaviour.

    Raises:
        SystemExit: through ``parser.error`` (exit status 2) when
            ``--speech_units`` is blank or encodes to an empty sequence, or when
            ``--condition`` is not present in the loaded centroid table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--speech_units", type=str, required=True, help="Space-separated string of speech unit IDs.")
    parser.add_argument("--condition", type=str, required=True, help="Style condition for synthesis.")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the output .wav file.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on.")
    parser.add_argument(
        "--condition_file",
        type=str,
        default="./speech_tokenization/condition_style_centroid/condition2style_centroid.txt",
        help="Condition-to-style-centroid table; a relative path falls back to the script directory.",
    )
    # argparse does not apply `type` to non-string defaults, so these defaults
    # reach the model exactly as the previously hard-coded literals did.
    parser.add_argument("--noise_scale", type=float, default=.667, help="noise_scale passed to U2S synthesis.")
    parser.add_argument("--noise_scale_w", type=float, default=0.8, help="noise_scale_w passed to U2S synthesis.")
    parser.add_argument("--length_scale", type=float, default=1, help="length_scale passed to U2S synthesis.")
    args = parser.parse_args(argv)

    # Check cheap inputs before the (slow) model load.
    if not args.speech_units.strip():
        parser.error("--speech_units must contain at least one unit ID")

    _, condition_embeddings = load_condition_centroid(_resolve_condition_file(args.condition_file))
    # condition_embeddings is indexed by the condition string below; presumably
    # a dict keyed by condition name — TODO confirm against load_condition_centroid.
    if args.condition not in condition_embeddings:
        available = ", ".join(sorted(str(key) for key in condition_embeddings))
        parser.error(f"unknown --condition {args.condition!r}; available: {available}")

    unit_type_u2s = '40ms_multilingual_8888_xujing_cosyvoice_FT'
    u2s_config_file, u2s_checkpoint_file = get_U2S_config_checkpoint_file(unit_type_u2s)
    u2s_model, u2s_hps = load_U2S_model(u2s_config_file, u2s_checkpoint_file, unit_type_u2s)
    u2s_model.to(args.device).eval()

    with torch.no_grad():
        style_embedding = condition_embeddings[args.condition].to(args.device)
        unit_sequence_int = text_to_sequence(args.speech_units, u2s_hps.data.text_cleaners)
        if len(unit_sequence_int) == 0:
            parser.error("--speech_units produced an empty unit sequence")
        # Batch of one: shape (1, sequence_length).
        unit_sequence_tensor = torch.LongTensor(unit_sequence_int).unsqueeze(0).to(args.device)
        unit_lengths = torch.LongTensor([unit_sequence_tensor.size(1)]).to(args.device)
        outputs = u2s_model.synthesis_from_content_unit_style_embedding(
            unit_sequence_tensor, unit_lengths, style_embedding,
            noise_scale=args.noise_scale,
            noise_scale_w=args.noise_scale_w,
            length_scale=args.length_scale,
        )
        # First returned item, then element [0, 0] of it; presumably
        # (batch, channel, samples) — TODO confirm against the model.
        audio = outputs[0][0, 0].detach().cpu().float().numpy()

    # `write` does not create parent directories, so create them first.
    os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
    write(args.output_path, u2s_hps.data.sampling_rate, audio)
    print(f"Successfully saved audio to {args.output_path}")
if __name__ == "__main__":
    # Run the CLI only when executed directly; importing this module has no side effects.
    main()