"""Print the reduced unit sequence that the SPIRAL-FSQ-CTC S2U model yields for one audio file.

Usage:
    python <this script> --wav_path path/to/audio.wav [--device cuda]
"""
import argparse
import json
import os
import sys

import torch

# Make modules that live next to this script (e.g. speech_utils) importable.
# abspath() is needed because dirname(__file__) is "" or a cwd-relative path
# when the script is launched by a relative name.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# These project imports must come after the sys.path tweak above.
from speech_utils import load_S2U_model  # noqa: E402
from speech_tokenization.SPIRAL_L2_BN_FSQ_CTC.my_extract_unit_for_speech.extract_unit_construct_wav_unit_text import (  # noqa: E402
    get_S2U_ckpt_config_path,
    sample_extract_unit,
)


def main(argv=None):
    """Parse CLI arguments, load the S2U model, and print the reduced unit sequence.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, so command-line behavior is unchanged.

    The printed value is the second element returned by ``sample_extract_unit``
    (presumably a de-duplicated unit sequence -- confirm against that helper).
    Exits with status 2 via ``parser.error`` if ``--wav_path`` is not a file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--wav_path", type=str, required=True, help="Path to the input audio file.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on.")
    args = parser.parse_args(argv)

    # Fail fast, before the (potentially slow) checkpoint load, on a bad input path.
    if not os.path.isfile(args.wav_path):
        parser.error(f"--wav_path does not exist or is not a file: {args.wav_path}")

    # Checkpoint/config selector and model name expected by the S2U helpers.
    unit_type_s2u = '40ms_multilingual_8888'
    s2u_model_name = 'SPIRAL-FSQ-CTC'

    s2u_ckpt_path, s2u_config_path = get_S2U_ckpt_config_path(unit_type_s2u)
    s2u_model = load_S2U_model(s2u_ckpt_path, s2u_config_path, s2u_model_name)
    s2u_model.to(args.device).eval()

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        _, reduced_unit_sequence = sample_extract_unit(args.wav_path, s2u_model)

    print(reduced_unit_sequence)


if __name__ == "__main__":
    main()