File size: 1,070 Bytes
7bd8b78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
import os
import sys
import torch
import json

sys.path.append(os.path.dirname(__file__))

from speech_utils import load_S2U_model
from speech_tokenization.SPIRAL_L2_BN_FSQ_CTC.my_extract_unit_for_speech.extract_unit_construct_wav_unit_text import (
    get_S2U_ckpt_config_path,
    sample_extract_unit,
)

def main(argv=None):
    """Extract speech units from one audio file and print them.

    Loads the speech-to-unit (S2U) model selected by ``--unit_type`` /
    ``--model_name`` (defaults: '40ms_multilingual_8888' / 'SPIRAL-FSQ-CTC'),
    moves it to the chosen device in eval mode, and runs
    ``sample_extract_unit`` on ``--wav_path`` without gradient tracking.
    The second value returned by ``sample_extract_unit`` (the reduced unit
    sequence) is printed to stdout.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` means ``sys.argv[1:]``, as with
            ``argparse.ArgumentParser.parse_args``.

    Raises:
        SystemExit: On invalid arguments, including a ``--wav_path`` that is
            not an existing file (exit code 2, via ``parser.error``).
    """
    parser = argparse.ArgumentParser(
        description="Extract discrete speech units from a single audio file."
    )
    parser.add_argument("--wav_path", type=str, required=True, help="Path to the input audio file.")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run the model on (default: 'cuda' if available, else 'cpu').",
    )
    parser.add_argument(
        "--unit_type",
        type=str,
        default="40ms_multilingual_8888",
        help="S2U unit type; passed to get_S2U_ckpt_config_path to locate the checkpoint/config.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="SPIRAL-FSQ-CTC",
        help="S2U model name passed to load_S2U_model; should correspond to the --unit_type checkpoint.",
    )
    args = parser.parse_args(argv)

    # Validate the input before the (potentially slow) model load, and report
    # a bad path as a normal usage error rather than a traceback.
    if not os.path.isfile(args.wav_path):
        parser.error(f"--wav_path does not exist or is not a file: {args.wav_path}")

    # Default to CUDA only when it is actually available so the script also
    # runs on CPU-only hosts; an explicit --device is always honored.
    # NOTE(review): assumes sample_extract_unit follows the model's device
    # rather than hard-coding 'cuda' internally — confirm.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    s2u_ckpt_path, s2u_config_path = get_S2U_ckpt_config_path(args.unit_type)
    s2u_model = load_S2U_model(s2u_ckpt_path, s2u_config_path, args.model_name)
    s2u_model.to(device).eval()

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        _, reduced_unit_sequence = sample_extract_unit(args.wav_path, s2u_model)

    print(reduced_unit_sequence)

if __name__ == "__main__":
    # Script entry point: parse sys.argv and run the extraction only when this
    # file is executed directly, never on import.
    main()