File size: 2,139 Bytes
7bd8b78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import os
import sys
import torch
from scipy.io.wavfile import write

sys.path.append(os.path.dirname(__file__))

from speech_utils import load_U2S_model, load_condition_centroid
from speech_tokenization.UVITS.text import text_to_sequence
from my_synthesis.my_synthesis_for_speech_unit_sequence_recombination import get_U2S_config_checkpoint_file

def main() -> None:
    """Synthesize audio from a speech-unit string and save it as a .wav file.

    Command-line arguments:
        --speech_units: space-separated string of speech unit IDs.
        --condition: style condition; used as the lookup key into the
            condition-centroid table loaded from ``condition_file``.
        --output_path: destination path of the generated .wav file.
        --device: torch device string (default ``"cuda"``).

    Exits through ``parser.error`` (status 2) when ``--speech_units`` is
    blank or when ``--condition`` is not a key of the centroid table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--speech_units", type=str, required=True, help="Space-separated string of speech unit IDs.")
    parser.add_argument("--condition", type=str, required=True, help="Style condition for synthesis.")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the output .wav file.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on.")
    args = parser.parse_args()

    # Reject a blank unit string up front instead of handing the model an
    # empty sequence.
    if not args.speech_units.strip():
        parser.error("--speech_units must contain at least one unit ID.")

    # The default device is "cuda"; fall back to CPU (with a warning) rather
    # than crashing on machines without a usable CUDA runtime.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.", file=sys.stderr)
        device = "cpu"

    # Resolve the style condition BEFORE loading the model so that a typo in
    # --condition fails fast with a readable message instead of a KeyError
    # traceback after the checkpoint has already been loaded.
    # NOTE(review): this path is relative to the current working directory,
    # not to this file — confirm the script is always launched from the
    # project root.
    condition_file = "./speech_tokenization/condition_style_centroid/condition2style_centroid.txt"
    _, condition_embeddings = load_condition_centroid(condition_file)
    try:
        style_centroid = condition_embeddings[args.condition]
    except KeyError:
        known = ", ".join(str(name) for name in condition_embeddings)
        parser.error(f"Unknown condition {args.condition!r}. Available conditions: {known}")

    unit_type_u2s = '40ms_multilingual_8888_xujing_cosyvoice_FT'
    u2s_config_file, u2s_checkpoint_file = get_U2S_config_checkpoint_file(unit_type_u2s)
    u2s_model, u2s_hps = load_U2S_model(u2s_config_file, u2s_checkpoint_file, unit_type_u2s)
    u2s_model.to(device).eval()

    with torch.no_grad():
        style_embedding = style_centroid.to(device)

        # Convert the unit string to integer IDs using the cleaners named in
        # the model's hyper-parameters, then add a leading batch dimension.
        unit_sequence_int = text_to_sequence(args.speech_units, u2s_hps.data.text_cleaners)
        unit_sequence_tensor = torch.LongTensor(unit_sequence_int).unsqueeze(0).to(device)
        # One-element tensor holding the length of the single sequence.
        unit_lengths = torch.LongTensor([unit_sequence_tensor.size(1)]).to(device)

        # The first returned element is indexed with [0, 0]; presumably it is
        # laid out as (batch, channel, samples) — TODO confirm against the model.
        audio = u2s_model.synthesis_from_content_unit_style_embedding(
            unit_sequence_tensor, unit_lengths, style_embedding,
            noise_scale=.667, noise_scale_w=0.8, length_scale=1
        )[0][0, 0].detach().cpu().float().numpy()

    # Create the destination directory if needed so the write cannot fail on
    # a missing parent folder.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    write(args.output_path, u2s_hps.data.sampling_rate, audio)

    print(f"Successfully saved audio to {args.output_path}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()