import os
import os.path as osp
import shutil
from glob import glob
import subprocess
import warnings

warnings.filterwarnings('ignore')


def _checkpoint_step(path):
    """Return the integer following the last ``_step`` in *path*, or None.

    None is returned when the text after the last ``_step`` (or the whole
    path, if ``_step`` does not occur) is not a valid integer.
    """
    try:
        return int(path.split('_step')[-1])
    except ValueError:
        return None


def _find_checkpoints(save_dir):
    """Return ``{save_dir}/epoch*`` paths that carry a step number, lowest step first.

    Entries matching ``epoch*`` whose name has no parsable ``_step<int>``
    suffix are skipped rather than aborting the whole resume loop.
    """
    candidates = [p for p in glob(f'{save_dir}/epoch*') if _checkpoint_step(p) is not None]
    return sorted(candidates, key=_checkpoint_step)


def auto_resume(command=None, save_dir=None):
    """Run a training command repeatedly until it exits with return code 0.

    Before every attempt the newest checkpoint under ``save_dir`` (the
    ``epoch*`` entry with the largest ``_step`` number) is looked up; if one
    exists, ``--load <checkpoint>`` is appended to the command. The child's
    stderr is captured and printed once the child exits. If the child fails
    and its stderr contains a ``FileNotFoundError`` message, the newest
    checkpoint directory is deleted so the next attempt resumes from the
    previous one.

    Args:
        command: Argument list to launch. Defaults to the stage-1 audio
            ``torchrun`` invocation defined below.
        save_dir: Directory that is scanned for ``epoch*`` checkpoints.
            Defaults to the hard-coded output directory defined below.

    Returns:
        None. The function only returns after an attempt exits with code 0.
    """
    # ------------------------------------------------------------------
    # Alternative launch configurations, kept for reference.
    # ------------------------------------------------------------------
    # os.system("export CUDA_VISIBLE_DEVICES=0,1,2,4")
    # save_dir = './outputs'  # outputs, runs

    ############### train VALDM2 ###############
    # command = [
    #     "torchrun", "--standalone", "--nproc_per_node", "4",
    #     "scripts/vaffusion/train_valdm2.py",
    #     "configs/opensora-valdm2/train/stage3_iddpm.py",
    #     "--data-path", "data/meta/mmtrail_v01/meta_info_fmin1_fmax1000_au_sr16000_nospeech_fps24_training.csv",
    #     "--load-data", "data/feat/mmtrail_text_buffer_v01",
    #     # "--load", "PLACEHOLDER",
    # ]
    # model_name = 'valdm2'

    ############### gen audios ###############
    # run_cmd = "torchrun --nproc_per_node 6 data_pipeline/st_prior/src/gen_unpaired_audios.py".split(' ')

    ############### train ST-Prior ###############
    # command = [
    #     "torchrun", "--standalone", "--nproc_per_node", "4",
    #     "scripts/vaffusion/train_prior.py",
    #     # "configs/opensora-syncva/train/prior_stage2.py",
    #     # "--data-path", "data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_unpaired_audios.csv",
    #     # "configs/opensora-syncva/train/prior_stage1_feat.py",
    #     # "--data-path", "data/feat/mmtrail136k_tavgbench240k_st_prior_stage1",
    #     "configs/opensora-syncva/train/prior_stage2_feat.py",
    #     "--data-path", "data/feat/mmtrail136k_tavgbench240k_st_prior_stage2",
    #     # "--load", "PLACEHOLDER",
    # ]
    # model_name = 'stib'

    ############### train SyncVA - stage2 ###############
    # # os.system("export CUDA_VISIBLE_DEVICES=3,4,5,7")
    # save_dir = './runs/debug'
    # command = [
    #     "torchrun", "--standalone", "--nproc_per_node", "4",
    #     "scripts/vaffusion/train_syncva.py",
    #     # "configs/opensora-syncva/train/syncva_stage2_bench.py",
    #     "configs/opensora-syncva/train2/syncva_stage2_plain.py",
    #     "--output", save_dir,
    #     "--data-path", "data/meta/syncva_train_v2.csv",
    #     # "--data-path", "data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_h100.csv",
    #     # "--load-data", "data/feat/mmtrail136k_tavgbench240k_text_buffer",
    # ]
    # command = [
    #     "torchrun", "--standalone", "--nproc_per_node", "4",
    #     "scripts/vaffusion/train_syncva.py",
    #     "configs/opensora-syncva/train/syncva_stage2_feat.py",
    #     "--output", save_dir,
    #     "--data-path", "data/feat/mmtrail136k_tavgbench240k_va_feat_half",
    #     # "--load", "PLACEHOLDER",
    # ]
    # command = [
    #     "torchrun", "--standalone", "--nproc_per_node", "2",
    #     "scripts/vaffusion/train_syncva.py",
    #     "configs/opensora-syncva/train/syncva_stage2_bench2.py",
    #     "--output", save_dir,
    #     "--data-path", "data/meta/bench_landscape_train.csv",
    # ]
    # model_name = 'VASTDiT'

    # save_dir = '/mnt/workspace/checkpoint/audio_video'
    # data_path = '/mnt/HithinkOmniSSD/user_workspace/liukai4/datasets/JavisDiT/train/video/TAVGBench_train_fps16_sft_330k_ths.csv'
    # command = [
    #     "torchrun", "--standalone", "--nproc_per_node", "8",
    #     "scripts/train.py",
    #     "configs/wan2.1/train/stage2_audio_video.py",
    #     "--output", save_dir,
    #     "--data-path", data_path,
    #     # "--load", "PLACEHOLDER",
    # ]
    # model_name = 'Wan2_1'

    ############### train SyncVA - stage1 (active default) ###############
    # os.system("export CUDA_VISIBLE_DEVICES=4,5,6,7")
    if save_dir is None:
        save_dir = '/data/yikai/mocha/pr/audio/alan/projects/JavisDiT/outputs/006-Wan2_1_T2V_1_3B'
    if command is None:
        data_path = "/data/yikai/mocha/pr/audio/alan/datasets/JAV-Audio-Data/JavisDiT_train_audio_v1.csv"
        command = [
            "torchrun", "--standalone", "--nproc_per_node", "4",
            "scripts/train.py",
            "configs/wan2.1/train/stage1_audio.py",
            "--output", save_dir,
            "--data-path", data_path,
            # "--load", "PLACEHOLDER",
        ]
    command = list(command)

    while True:
        # Re-scan on every attempt: the previous run may have written new
        # checkpoints, or the newest one may have been removed below.
        ckpt_dirs = _find_checkpoints(save_dir)
        if ckpt_dirs:
            run_cmd = command + ["--load", ckpt_dirs[-1]]
            # if len(ckpt_dirs) == 1:
            #     run_cmd += ["--start-from-scratch"]
        else:
            run_cmd = command

        # stderr is captured (not streamed) so it can be inspected after exit;
        # stdout is inherited from this process.
        result = subprocess.run(run_cmd, stderr=subprocess.PIPE, text=True)
        print(result.stderr)

        if result.returncode == 0:
            break
        if ckpt_dirs and 'FileNotFoundError: [Errno 2] No such file or directory:' in result.stderr:
            # failed to save previous checkpoints: drop the newest one so the
            # next attempt resumes from the checkpoint before it.
            shutil.rmtree(ckpt_dirs[-1])


if __name__ == '__main__':
    auto_resume()