"""Distributed AudioLDM2 text-to-audio generation.

Each process generates audio for a contiguous shard of the prompt list and
writes one WAV file per prompt into ``--output_dir``.

Launch with e.g.: torchrun --nproc_per_node=4 xx.py
"""
import argparse
import scipy
import scipy.io.wavfile  # explicit: `import scipy` alone does not guarantee the submodule is loaded
import pandas as pd
import numpy as np
import os
import os.path as osp
from tqdm import tqdm
from datetime import timedelta

import torch
import torch.distributed as dist
from diffusers import AudioLDM2Pipeline
from colossalai.utils import get_current_device, set_seed

from javisdit.utils.inference_utils import load_prompts, get_save_path_name


def infer_audioldm2(
    rank,
    world_size,
    pipe,
    all_prompts,
    bs: int,
    audio_length_in_s: float = 10.24,
    match_duration: bool = False,
    output_dir: str = "./output",
    num_inference_steps: int = 200,
    sample_rate: int = 16000,
):
    """Generate audio for this rank's shard of ``all_prompts`` and save it as WAV files.

    The prompts are split into ``world_size`` contiguous chunks of
    ``ceil(len(all_prompts) / world_size)`` items, and rank ``r`` handles chunk
    ``r``. Each output path is ``get_save_path_name(output_dir, sample_idx=<global
    prompt index>) + '.wav'``, so file names depend only on a prompt's position
    in ``all_prompts``.

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes sharing ``all_prompts``.
        pipe: An ``AudioLDM2Pipeline``; it is moved to the current device here.
        all_prompts: Indexable sequence of prompts (as returned by ``load_prompts``).
        bs: Number of prompts per pipeline call; must be positive.
        audio_length_in_s: Clip length forwarded to the pipeline.
        match_duration: Accepted for CLI compatibility; currently unused.
        output_dir: Directory the WAV files are written into.
        num_inference_steps: Steps per pipeline call. The default of 200 is the
            value that was previously hard-coded.
        sample_rate: Rate written into the WAV header. The default of 16000 is
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``bs`` is not positive.
    """
    if bs <= 0:
        raise ValueError(f"batch size must be positive, got {bs}")

    device = get_current_device()
    pipe = pipe.to(device)

    # Contiguous shard [start, stop) of the global prompt list for this rank.
    num_per_rank = int(np.ceil(len(all_prompts) / world_size))
    start = rank * num_per_rank
    stop = min(start + num_per_rank, len(all_prompts))
    prompts = [all_prompts[i] for i in range(start, stop)]

    # Use the global prompt index so output names do not depend on the sharding.
    save_paths = [
        get_save_path_name(output_dir, sample_idx=start + idx) + ".wav"
        for idx in range(len(prompts))
    ]

    for i in range(0, len(prompts), bs):
        if rank == 0:
            print(f"\nProcessing {i}/{len(prompts)}")
        audios = pipe(
            prompts[i:i + bs],
            num_inference_steps=num_inference_steps,
            audio_length_in_s=audio_length_in_s,
        ).audios
        for save_path, audio in zip(save_paths[i:i + bs], audios):
            try:
                scipy.io.wavfile.write(save_path, rate=sample_rate, data=audio)
            except Exception as e:
                # Best-effort: one failed write should not abort the rest of the shard,
                # but report which file failed and on which rank.
                print(f"[rank {rank}] failed to write {save_path}: {e}")


def main(args):
    """Initialise NCCL, generate audio for this rank's prompts, then tear down the group."""
    dist.init_process_group("nccl", timeout=timedelta(hours=24))
    try:
        rank = dist.get_rank()
        # Shard across every process in the job, not across the GPUs on this node.
        # Using torch.cuda.device_count() here silently dropped prompt shards
        # whenever nproc_per_node < #GPUs, and left ranks idle when there were
        # more processes than local GPUs (e.g. multi-node runs).
        world_size = dist.get_world_size()
        # torchrun exports LOCAL_RANK. The fallback keeps the previous mapping.
        local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
        torch.cuda.set_device(local_rank)
        set_seed(args.seed)

        pipe = AudioLDM2Pipeline.from_pretrained(args.model_name_or_path)
        prompts = load_prompts(args.prompts)

        os.makedirs(args.output_dir, exist_ok=True)

        infer_audioldm2(
            rank,
            world_size,
            pipe,
            prompts,
            args.batch_size,
            args.audio_length_in_s,
            args.match_duration,
            args.output_dir,
        )
    finally:
        # Always release the process group, even if generation fails.
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1024)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model_name_or_path', type=str, default='../../checkpoints/audioldm2')
    parser.add_argument('--audio_length_in_s', type=float, default=2.56)
    parser.add_argument('--prompts', type=str, default='data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_h100.csv')
    parser.add_argument('--output_dir', type=str, default='/home/haofei/kailiu/datasets/st_prior/audio/unpaired')
    # Forwarded to infer_audioldm2 but currently has no effect there.
    parser.add_argument('--match_duration', action='store_true')
    args = parser.parse_args()

    main(args)

# torchrun --nproc_per_node=4 xx.py