"""CLI tool: synthesize a .wav file from a speech-unit sequence and a style condition.

Loads a U2S (unit-to-speech) model, converts the space-separated unit IDs into
a tensor, looks up the style-centroid embedding for the requested condition,
and writes the synthesized waveform to --output_path.
"""

import argparse
import os
import sys

import torch
from scipy.io.wavfile import write

# Make sibling project modules importable when run as a script.
sys.path.append(os.path.dirname(__file__))

from speech_utils import load_U2S_model, load_condition_centroid
from speech_tokenization.UVITS.text import text_to_sequence
from my_synthesis.my_synthesis_for_speech_unit_sequence_recombination import get_U2S_config_checkpoint_file


def main():
    """Parse CLI args, run U2S synthesis, and save the resulting audio.

    Raises:
        ValueError: if --condition is not a known style condition.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--speech_units", type=str, required=True,
                        help="Space-separated string of speech unit IDs.")
    parser.add_argument("--condition", type=str, required=True,
                        help="Style condition for synthesis.")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Path to save the output .wav file.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run the model on.")
    args = parser.parse_args()

    # Fixed unit vocabulary / checkpoint variant used by this script.
    unit_type_u2s = '40ms_multilingual_8888_xujing_cosyvoice_FT'
    u2s_config_file, u2s_checkpoint_file = get_U2S_config_checkpoint_file(unit_type_u2s)
    u2s_model, u2s_hps = load_U2S_model(u2s_config_file, u2s_checkpoint_file, unit_type_u2s)
    u2s_model.to(args.device).eval()

    # Map condition names to precomputed style-centroid embeddings.
    condition_file = "./speech_tokenization/condition_style_centroid/condition2style_centroid.txt"
    _, condition_embeddings = load_condition_centroid(condition_file)

    # Fail with an actionable message instead of a bare KeyError.
    if args.condition not in condition_embeddings:
        raise ValueError(
            f"Unknown condition '{args.condition}'. "
            f"Available conditions: {sorted(condition_embeddings)}"
        )

    with torch.no_grad():
        style_embedding = condition_embeddings[args.condition].to(args.device)

        # Convert the unit-ID string into a (1, T) LongTensor plus its length.
        unit_sequence_int = text_to_sequence(args.speech_units, u2s_hps.data.text_cleaners)
        unit_sequence_tensor = torch.LongTensor(unit_sequence_int).unsqueeze(0).to(args.device)
        unit_lengths = torch.LongTensor([unit_sequence_tensor.size(1)]).to(args.device)

        # Synthesize; take the first (and only) waveform and move it to CPU numpy.
        audio = u2s_model.synthesis_from_content_unit_style_embedding(
            unit_sequence_tensor, unit_lengths, style_embedding,
            noise_scale=.667, noise_scale_w=0.8, length_scale=1
        )[0][0, 0].data.cpu().float().numpy()

    write(args.output_path, u2s_hps.data.sampling_rate, audio)
    print(f"Successfully saved audio to {args.output_path}")


if __name__ == "__main__":
    main()