# JAV-Gen/scripts/misc/inference_audioldm2.py
import argparse
import os
import os.path as osp
from datetime import timedelta

import numpy as np
import pandas as pd
import scipy.io.wavfile  # `import scipy` alone does not expose scipy.io.wavfile
import torch
import torch.distributed as dist
from tqdm import tqdm

from diffusers import AudioLDM2Pipeline
from colossalai.utils import get_current_device, set_seed
from javisdit.utils.inference_utils import load_prompts, get_save_path_name

def infer_audioldm2(
    rank, world_size, pipe, all_prompts, bs: int,
    audio_length_in_s: float = 10.24, match_duration: bool = False, output_dir: str = "./output"
):
    """Generate audio for this rank's shard of prompts and save each clip as a 16 kHz WAV file."""
    device = get_current_device()
    pipe = pipe.to(device)

    # Shard the prompt list evenly across ranks.
    num_per_node = int(np.ceil(len(all_prompts) / world_size))
    prompts = []
    for i in range(rank * num_per_node, min((rank + 1) * num_per_node, len(all_prompts))):
        row = all_prompts[i]
        prompts.append(row)
    save_paths = [
        get_save_path_name(
            output_dir,
            sample_idx=rank * num_per_node + idx,
        ) + '.wav' for idx in range(len(prompts))
    ]

    # Batched generation; each batch runs 200 diffusion steps.
    for i in range(0, len(prompts), bs):
        if rank == 0:
            print(f'\nProcessing {i}/{len(prompts)}')
        audios = pipe(prompts[i:i+bs], num_inference_steps=200, audio_length_in_s=audio_length_in_s).audios
        for j, audio in enumerate(audios):
            try:
                # Write the generated waveform; AudioLDM2 outputs 16 kHz audio.
                scipy.io.wavfile.write(save_paths[i+j], rate=16000, data=audio)
            except Exception as e:
                print(e)
                # data.iloc[i+j, 'unpaired_audio_path'] = ""

def main(args):
    ## prepare data and model
    dist.init_process_group("nccl", timeout=timedelta(hours=24))
    rank = dist.get_rank()
    # Assumes a single-node launch: one process per local GPU.
    world_size = torch.cuda.device_count()
    torch.cuda.set_device(rank % world_size)
    set_seed(args.seed)
    pipe = AudioLDM2Pipeline.from_pretrained(args.model_name_or_path)

    prompts = load_prompts(args.prompts)
    os.makedirs(args.output_dir, exist_ok=True)
    infer_audioldm2(rank, world_size, pipe, prompts, args.batch_size,
                    args.audio_length_in_s, args.match_duration, args.output_dir)

    dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1024)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model_name_or_path', type=str, default='../../checkpoints/audioldm2')
    parser.add_argument('--audio_length_in_s', type=float, default=2.56)
    parser.add_argument('--prompts', type=str, default='data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_h100.csv')
    parser.add_argument('--output_dir', type=str, default='/home/haofei/kailiu/datasets/st_prior/audio/unpaired')
    parser.add_argument('--match_duration', action='store_true')
    args = parser.parse_args()

    main(args)

# torchrun --nproc_per_node=4 xx.py
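#
# A fuller launch sketch, assuming a single node with 4 GPUs; the prompt CSV and output
# directory below are placeholders, not paths shipped with the repo:
#   torchrun --nproc_per_node=4 scripts/misc/inference_audioldm2.py \
#       --model_name_or_path ../../checkpoints/audioldm2 \
#       --prompts /path/to/prompts.csv \
#       --output_dir /path/to/output_audio \
#       --batch_size 32 --audio_length_in_s 2.56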