|
|
import argparse |
|
|
import scipy |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import os |
|
|
import os.path as osp |
|
|
from tqdm import tqdm |
|
|
from datetime import timedelta |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
|
|
|
from diffusers import AudioLDM2Pipeline |
|
|
|
|
|
from colossalai.utils import get_current_device, set_seed |
|
|
from javisdit.utils.inference_utils import load_prompts, get_save_path_name |
|
|
|
|
|
|
|
|
def infer_audioldm2(
    rank, world_size, pipe, all_prompts, bs: int,
    audio_length_in_s: float = 10.24, match_duration: bool = False, output_dir: str = "./output",
    num_inference_steps: int = 200, sample_rate: int = 16000,
):
    """Generate audio for this rank's share of ``all_prompts`` and save it as WAV files.

    The prompts are split into contiguous shards of ``ceil(len(all_prompts) / world_size)``
    entries; the process identified by ``rank`` handles shard number ``rank``. Each
    clip is written to ``get_save_path_name(output_dir, sample_idx=<global index>) + '.wav'``,
    where the global index is the prompt's position in ``all_prompts``.

    Args:
        rank: Index of the calling process; selects the shard and gates progress output.
        world_size: Number of shards the prompt list is divided into.
        pipe: Callable pipeline; it is moved to the current device and invoked as
            ``pipe(prompts, num_inference_steps=..., audio_length_in_s=...)``, and its
            result must expose an ``audios`` sequence.
        all_prompts: Integer-indexable collection of prompts with a ``len()``.
        bs: Number of prompts passed to ``pipe`` per call; must be positive.
        audio_length_in_s: Forwarded to ``pipe`` unchanged.
        match_duration: Accepted for interface compatibility; currently unused.
        output_dir: Directory handed to ``get_save_path_name``.
        num_inference_steps: Forwarded to ``pipe``; defaults to the previously
            hard-coded value of 200.
        sample_rate: Sample rate written into the WAV header; defaults to the
            previously hard-coded value of 16000.

    Raises:
        ValueError: If ``bs`` is not a positive integer.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee that
    # `scipy.io.wavfile` is reachable as an attribute on every SciPy version, and the
    # resulting AttributeError would be swallowed by the per-file handler below.
    from scipy.io import wavfile

    if bs < 1:
        raise ValueError(f'bs must be a positive integer, got {bs}')

    device = get_current_device()
    pipe = pipe.to(device)

    # Contiguous shard [start, stop) of the global prompt list for this rank.
    num_per_node = int(np.ceil(len(all_prompts) / world_size))
    start = rank * num_per_node
    stop = min(start + num_per_node, len(all_prompts))
    # Index element by element (rather than slicing) so any integer-indexable
    # container behaves exactly as it did before.
    prompts = [all_prompts[i] for i in range(start, stop)]

    # File names are derived from the global prompt index so shards never collide.
    save_paths = [
        get_save_path_name(output_dir, sample_idx=start + idx) + '.wav'
        for idx in range(len(prompts))
    ]

    for batch_start in range(0, len(prompts), bs):
        if rank == 0:
            print(f'\nProcessing {batch_start}/{len(prompts)}')
        batch_end = batch_start + bs
        audios = pipe(
            prompts[batch_start:batch_end],
            num_inference_steps=num_inference_steps,
            audio_length_in_s=audio_length_in_s,
        ).audios
        for path, audio in zip(save_paths[batch_start:batch_end], audios):
            try:
                wavfile.write(path, rate=sample_rate, data=audio)
            except Exception as e:
                # Best effort: one failed write must not abort the remaining clips,
                # but report which file was affected.
                print(f'Failed to write {path}: {e}')
|
|
|
|
|
|
|
|
def main(args):
    """Run distributed AudioLDM2 inference for the parsed command-line ``args``.

    Initializes the NCCL process group, binds this process to a local CUDA device,
    seeds the run, loads the pipeline and prompts, creates the output directory and
    delegates generation to ``infer_audioldm2``. The process group is destroyed on
    exit, including when inference raises.

    Args:
        args: Namespace providing ``seed``, ``model_name_or_path``, ``prompts``,
            ``output_dir``, ``batch_size``, ``audio_length_in_s`` and
            ``match_duration``.
    """
    dist.init_process_group("nccl", timeout=timedelta(hours=24))
    try:
        rank = dist.get_rank()
        # Shard by the number of participating processes, not by the number of GPUs
        # visible on this machine: the two differ on multi-node runs or when fewer
        # processes than GPUs are launched, which would leave prompts unprocessed
        # or processed twice.
        world_size = dist.get_world_size()
        # The local GPU count is only used to map the global rank onto a device.
        torch.cuda.set_device(rank % torch.cuda.device_count())

        set_seed(args.seed)

        pipe = AudioLDM2Pipeline.from_pretrained(args.model_name_or_path)

        prompts = load_prompts(args.prompts)
        os.makedirs(args.output_dir, exist_ok=True)

        infer_audioldm2(
            rank, world_size, pipe, prompts, args.batch_size,
            args.audio_length_in_s, args.match_duration, args.output_dir,
        )
    finally:
        # Always release the process group, even if loading or inference failed.
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in the order listed here.
    cli_options = [
        ('--seed', {'type': int, 'default': 1024}),
        ('--batch_size', {'type': int, 'default': 32}),
        ('--model_name_or_path', {'type': str, 'default': '../../checkpoints/audioldm2'}),
        ('--audio_length_in_s', {'type': float, 'default': 2.56}),
        ('--prompts', {'type': str, 'default': 'data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_h100.csv'}),
        ('--output_dir', {'type': str, 'default': '/home/haofei/kailiu/datasets/st_prior/audio/unpaired'}),
        ('--match_duration', {'action': 'store_true'}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Hand the parsed namespace to the entry point.
    main(args)
|
|
|