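"""Distributed text-to-audio generation with AudioLDM2.

Prompts are sharded evenly across ranks; each rank runs the diffusers
AudioLDM2Pipeline on its shard and writes the results as 16 kHz WAV files.
Launch with torchrun (see the example command at the bottom of the file).
"""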
import argparse
import os
from datetime import timedelta

import numpy as np
import scipy.io.wavfile  # bare `import scipy` does not expose scipy.io.wavfile

import torch
import torch.distributed as dist

from diffusers import AudioLDM2Pipeline

from colossalai.utils import get_current_device, set_seed
from javisdit.utils.inference_utils import load_prompts, get_save_path_name


def infer_audioldm2(
    rank, world_size, pipe, all_prompts, bs: int,
    audio_length_in_s: float = 10.24, match_duration: bool = False, output_dir: str = "./output"
):
    device = get_current_device()
    pipe = pipe.to(device)

    # Shard the prompt list evenly across ranks; the last rank may get a
    # shorter (possibly empty) slice.
    num_per_rank = int(np.ceil(len(all_prompts) / world_size))
    prompts = all_prompts[rank * num_per_rank : (rank + 1) * num_per_rank]

    # Derive output paths from global sample indices so filenames stay
    # consistent across ranks.
    save_paths = [
        get_save_path_name(
            output_dir,
            sample_idx=rank * num_per_rank + idx,
        ) + '.wav' for idx in range(len(prompts))
    ]

    for i in range(0, len(prompts), bs):
        if rank == 0:
            print(f'\nProcessing {i}/{len(prompts)}')
        audios = pipe(prompts[i:i+bs], num_inference_steps=200, audio_length_in_s=audio_length_in_s).audios
        for j, audio in enumerate(audios):
            try:
                # AudioLDM2 generates 16 kHz waveforms; write each as a WAV file.
                scipy.io.wavfile.write(save_paths[i+j], rate=16000, data=audio)
            except Exception as e:
                # Log and skip failed writes instead of aborting the whole run.
                print(e)
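
# NOTE: `match_duration` is accepted above but never used. A minimal sketch of
# how it might be honored (hypothetical helper, not part of the original code):
# zero-pad or truncate a generated waveform to an exact target duration.
def _fit_to_duration(audio: np.ndarray, duration_s: float, sr: int = 16000) -> np.ndarray:
    """Pad with trailing zeros or truncate `audio` to `duration_s` seconds."""
    target_len = int(round(duration_s * sr))
    if audio.shape[-1] >= target_len:
        return audio[..., :target_len]
    pad_width = [(0, 0)] * (audio.ndim - 1) + [(0, target_len - audio.shape[-1])]
    return np.pad(audio, pad_width)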

def main(args):
    ## prepare data and model
    dist.init_process_group("nccl", timeout=timedelta(hours=24))

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Bind each process to one GPU; torchrun exports LOCAL_RANK per process.
    # (The original used torch.cuda.device_count() as the world size, which
    # conflates GPU count with process count and breaks on multi-node runs.)
    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
    torch.cuda.set_device(local_rank)

    set_seed(args.seed)

    pipe = AudioLDM2Pipeline.from_pretrained(args.model_name_or_path)

    prompts = load_prompts(args.prompts)
    os.makedirs(args.output_dir, exist_ok=True)

    infer_audioldm2(
        rank, world_size, pipe, prompts,
        args.batch_size, args.audio_length_in_s, args.match_duration, args.output_dir,
    )

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1024)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--model_name_or_path', type=str, default='../../checkpoints/audioldm2')
    parser.add_argument('--audio_length_in_s', type=float, default=2.56)
    parser.add_argument('--prompts', type=str, default='data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_h100.csv')
    parser.add_argument('--output_dir', type=str, default='/home/haofei/kailiu/datasets/st_prior/audio/unpaired')
    parser.add_argument('--match_duration', action='store_true')
    args = parser.parse_args()

    main(args)
    # Example launch: torchrun --nproc_per_node=4 xx.py