# JAV-Gen/scripts/misc/auto_resume.py
# Uploaded by kaiw7 with huggingface_hub (commit e490e7e, verified).
import os
import os.path as osp
import shutil
from glob import glob
import subprocess
import warnings
warnings.filterwarnings('ignore')
def _list_checkpoints(save_dir):
    """Return checkpoint paths under *save_dir*, ordered by ascending step.

    Only entries matching ``epoch*`` whose name ends in ``_step<digits>`` are
    returned. Anything else (temporary or partially written directories, stray
    files) is skipped instead of raising ``ValueError`` while parsing the step.
    """
    found = []
    for path in glob(f'{save_dir}/epoch*'):
        _, sep, step = osp.basename(path).rpartition('_step')
        if sep and step.isascii() and step.isdigit():
            found.append((int(step), path))
    # Sort on the numeric step only; ties keep the order glob returned them in.
    found.sort(key=lambda item: item[0])
    return [path for _, path in found]


def auto_resume(command=None, save_dir=None, max_restarts=None):
    """Run a training command, relaunching it from the newest checkpoint until it succeeds.

    Before every launch the newest ``epoch*_step<N>`` entry under *save_dir*
    is looked up and, if one exists, ``--load <checkpoint>`` is appended to
    the command. The loop ends when the command exits with return code 0.

    Args:
        command: Argument list passed to ``subprocess.run``. Defaults to the
            stage-1 audio ``torchrun`` command built below.
        save_dir: Directory that is scanned for checkpoints (and passed as
            ``--output`` in the default command). Defaults to the hard-coded
            stage-1 output directory.
        max_restarts: Maximum number of relaunches after the first failed
            attempt. ``None`` (default) retries forever, as before.

    Raises:
        RuntimeError: If *max_restarts* is set and the command still fails
            after that many relaunches.
    """
    if save_dir is None:
        save_dir = '/data/yikai/mocha/pr/audio/alan/projects/JavisDiT/outputs/006-Wan2_1_T2V_1_3B'

    if command is None:
        # os.system("export CUDA_VISIBLE_DEVICES=0,1,2,4")
        # save_dir = './outputs'  # outputs, runs

        ############### train VALDM2 ###############
        # command = [
        #     "torchrun", "--standalone", "--nproc_per_node", "4",
        #     "scripts/vaffusion/train_valdm2.py",
        #     "configs/opensora-valdm2/train/stage3_iddpm.py",
        #     "--data-path", "data/meta/mmtrail_v01/meta_info_fmin1_fmax1000_au_sr16000_nospeech_fps24_training.csv",
        #     "--load-data", "data/feat/mmtrail_text_buffer_v01",
        #     # "--load", "PLACEHOLDER",
        # ]
        # model_name = 'valdm2'

        ############### gen audios ###############
        # run_cmd = "torchrun --nproc_per_node 6 data_pipeline/st_prior/src/gen_unpaired_audios.py".split(' ')

        ############### train ST-Prior ###############
        # command = [
        #     "torchrun", "--standalone", "--nproc_per_node", "4",
        #     "scripts/vaffusion/train_prior.py",
        #     # "configs/opensora-syncva/train/prior_stage2.py",
        #     # "--data-path", "data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_unpaired_audios.csv",
        #     # "configs/opensora-syncva/train/prior_stage1_feat.py",
        #     # "--data-path", "data/feat/mmtrail136k_tavgbench240k_st_prior_stage1",
        #     "configs/opensora-syncva/train/prior_stage2_feat.py",
        #     "--data-path", "data/feat/mmtrail136k_tavgbench240k_st_prior_stage2",
        #     # "--load", "PLACEHOLDER",
        # ]
        # model_name = 'stib'

        ############### train SyncVA - stage1 ###############
        # os.system("export CUDA_VISIBLE_DEVICES=4,5,6,7")
        data_path = "/data/yikai/mocha/pr/audio/alan/datasets/JAV-Audio-Data/JavisDiT_train_audio_v1.csv"
        command = [
            "torchrun", "--standalone", "--nproc_per_node", "4",
            "scripts/train.py",
            "configs/wan2.1/train/stage1_audio.py",
            "--output", save_dir,
            "--data-path", data_path,
            # "--load", "PLACEHOLDER",
        ]

        ############### train SyncVA - stage2 ###############
        # # os.system("export CUDA_VISIBLE_DEVICES=3,4,5,7")
        # save_dir = './runs/debug'
        # command = [
        #     "torchrun", "--standalone", "--nproc_per_node", "4",
        #     "scripts/vaffusion/train_syncva.py",
        #     # "configs/opensora-syncva/train/syncva_stage2_bench.py",
        #     "configs/opensora-syncva/train2/syncva_stage2_plain.py",
        #     "--output", save_dir,
        #     "--data-path", "data/meta/syncva_train_v2.csv",
        #     # "--data-path", "data/meta/st_prior/meta_info_fmin10_fmax1000_au_sr16000_mmtrail136k_tavgbench240k_h100.csv",
        #     # "--load-data", "data/feat/mmtrail136k_tavgbench240k_text_buffer",
        # ]
        # command = [
        #     "torchrun", "--standalone", "--nproc_per_node", "4",
        #     "scripts/vaffusion/train_syncva.py",
        #     "configs/opensora-syncva/train/syncva_stage2_feat.py",
        #     "--output", save_dir,
        #     "--data-path", "data/feat/mmtrail136k_tavgbench240k_va_feat_half",
        #     # "--load", "PLACEHOLDER",
        # ]
        # command = [
        #     "torchrun", "--standalone", "--nproc_per_node", "2",
        #     "scripts/vaffusion/train_syncva.py",
        #     "configs/opensora-syncva/train/syncva_stage2_bench2.py",
        #     "--output", save_dir,
        #     "--data-path", "data/meta/bench_landscape_train.csv",
        # ]
        # model_name = 'VASTDiT'

        # save_dir = '/mnt/workspace/checkpoint/audio_video'
        # data_path = '/mnt/HithinkOmniSSD/user_workspace/liukai4/datasets/JavisDiT/train/video/TAVGBench_train_fps16_sft_330k_ths.csv'
        # command = [
        #     "torchrun", "--standalone", "--nproc_per_node", "8",
        #     "scripts/train.py",
        #     "configs/wan2.1/train/stage2_audio_video.py",
        #     "--output", save_dir,
        #     "--data-path", data_path,
        #     # "--load", "PLACEHOLDER",
        # ]
        # model_name = 'Wan2_1'

    missing_file_marker = 'FileNotFoundError: [Errno 2] No such file or directory:'
    restarts = 0
    while True:
        # Re-scan on every attempt: the previous run may have written new
        # checkpoints, or the newest one may have been removed below.
        ckpt_dirs = _list_checkpoints(save_dir)
        run_cmd = list(command)
        if ckpt_dirs:
            run_cmd += ["--load", ckpt_dirs[-1]]
            # if len(ckpt_dirs) == 1:
            #     run_cmd += ["--start-from-scratch"]

        # Only stderr is captured (so it can be inspected afterwards); stdout
        # is inherited from this process.
        result = subprocess.run(run_cmd, stderr=subprocess.PIPE, text=True)
        print(result.stderr)

        if result.returncode == 0:
            return

        if ckpt_dirs and missing_file_marker in result.stderr:
            # failed to save previous checkpoints: drop the newest one so the
            # next attempt resumes from the checkpoint before it.
            # NOTE(review): the marker matches any FileNotFoundError printed
            # to stderr, not only ones caused by the checkpoint itself.
            shutil.rmtree(ckpt_dirs[-1])

        if max_restarts is not None and restarts >= max_restarts:
            raise RuntimeError(
                f'command failed with return code {result.returncode} '
                f'after {restarts} restart(s)'
            )
        restarts += 1
# Script entry point: start the resume loop only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    auto_resume()