# JAV-Gen/scripts/misc/train_prior.py
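# Trains the spatio-temporal (ST) prior model for JavisDiT with ColossalAI.
# A typical launch might look like this (hypothetical example; the exact config
# path and flags depend on your setup):
#   torchrun --nproc_per_node=8 scripts/misc/train_prior.py configs/prior/train.py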
import os
import random
from contextlib import nullcontext
from datetime import timedelta
from pprint import pformat
from glob import glob
import shutil
import re
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.distributed as dist
import wandb
import colossalai
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from javisdit.acceleration.checkpoint import set_grad_checkpoint
from javisdit.acceleration.parallel_states import get_data_parallel_group
from javisdit.datasets.dataloader import prepare_dataloader
from javisdit.registry import DATASETS, MODELS, build_module
from javisdit.utils.ckpt_utils import load, save
from javisdit.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from javisdit.utils.misc import (
Timer,
all_reduce_mean,
create_logger,
create_tensorboard_writer,
format_numel_str,
get_model_numel,
to_torch_dtype,
)
from javisdit.utils.train_utils import create_colossalai_plugin


def main():
# ======================================================
# 1. configs & runtime variables
# ======================================================
# == parse configs ==
cfg = parse_configs(training=True)
record_time = cfg.get("record_time", False)
# == device and dtype ==
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
# == colossalai init distributed training ==
    # NOTE: an explicit timeout bounds how long collective ops wait if some processes exit early or hang
if cfg.get('host'):
colossalai.launch_from_openmpi(cfg.host, cfg.port)
else:
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=5))
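        # pin each process to one GPU, wrapping by the number of visible devices per node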
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(cfg.get("seed", 1024))
coordinator = DistCoordinator()
device = get_current_device()
# == init exp_dir ==
    model_name = None  # optionally e.g. 'prior'
exp_name, exp_dir = define_experiment_workspace(cfg, model_name=model_name)
coordinator.block_all()
if coordinator.is_master():
os.makedirs(exp_dir, exist_ok=True)
save_training_config(cfg.to_dict(), exp_dir)
coordinator.block_all()
save_total_limit = cfg.get("save_total_limit", None)
# == init logger, tensorboard & wandb ==
logger = create_logger(exp_dir)
logger.info("Experiment directory created at %s", exp_dir)
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
if coordinator.is_master():
tb_writer = create_tensorboard_writer(exp_dir)
if cfg.get("wandb", False):
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
# == init ColossalAI booster ==
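    # the default "zero2" plugin shards optimizer states and gradients (ZeRO stage 2)
    # across data-parallel ranks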
plugin = create_colossalai_plugin(
plugin=cfg.get("plugin", "zero2"),
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
)
booster = Booster(plugin=plugin)
# ======================================================
# 2. build dataset and dataloader
# ======================================================
logger.info("Building dataset...")
# == load preprocessed data ==
load_va_features = cfg.get('load_va_features', False)
save_data = cfg.get('save_data', None)
if save_data is not None:
os.makedirs(save_data, exist_ok=True)
# == build dataset ==
dataset = build_module(cfg.dataset, DATASETS, audio_cfg=cfg.get("audio_cfg"),
load_data=cfg.get("load_data"))
logger.info("Dataset contains %s samples.", len(dataset))
# == build dataloader ==
batch_size = cfg.get("batch_size", 1)
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
num_workers=cfg.get("num_workers", 4),
seed=cfg.get("seed", 1024),
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
prefetch_factor=cfg.get("prefetch_factor", None),
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
**dataloader_args,
)
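    # sequence-parallel ranks consume the same samples, so they do not add to the
    # effective (data-parallel) batch size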
total_batch_size = batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
logger.info("Total batch size: %s", total_batch_size)
num_steps_per_epoch = len(dataloader)
# ======================================================
# 3. build model
# ======================================================
logger.info("Building models...")
# == build video vae model ==
vae = build_module(cfg.get("vae", None), MODELS)
if vae is not None:
vae = vae.to(device, dtype).eval()
vae_out_channels = vae.out_channels
else:
vae_out_channels = cfg.get("vae_out_channels", 4)
# == build audio vae model ==
audio_vae = build_module(cfg.audio_vae, MODELS, device=device, dtype=dtype)
if audio_vae is None:
audio_vae_out_channels = cfg.get('audio_vae_out_channels', 8)
else:
audio_vae_out_channels = audio_vae.vae_out_channels
# == build st-prior model ==
model = build_module(cfg.model, MODELS,
video_in_channel=vae_out_channels,
audio_in_channel=audio_vae_out_channels,
).to(device, dtype).train()
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
"[ST-Prior] Trainable model params: %s, Total model params: %s",
format_numel_str(model_numel_trainable),
format_numel_str(model_numel),
)
# == setup prior optimizer ==
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()),
adamw_mode=True,
lr=cfg.get("lr", 1e-5),
weight_decay=cfg.get("weight_decay", 0),
eps=cfg.get("adam_eps", 1e-8),
)
lr_scheduler = None
# == additional preparation ==
if cfg.get("grad_checkpoint", False):
set_grad_checkpoint(model)
    if load_va_features:
        # precomputed latents are loaded from disk, so the frozen VAE encoders
        # are no longer needed; drop the references to free their memory
        vae = audio_vae = None
    torch.cuda.empty_cache()
# =======================================================
# 4. distributed training preparation with colossalai
# =======================================================
logger.info("Preparing for distributed training...")
# == boosting ==
    # NOTE: set the default dtype first so model initialization matches the training
    # dtype, then reset it to fp32 because the diffusion scheduler is built in fp32
torch.set_default_dtype(dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
logger.info("Boosting model for distributed training")
# == global variables ==
cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = 0
    running_loss = 0.0
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
# == resume ==
if cfg.get("load", None) is not None:
logger.info("Loading checkpoint")
start_epoch, start_step = load(
booster,
cfg.load,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=sampler,
)
dist.barrier()
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
# =======================================================
# 5. training loop
# =======================================================
dist.barrier()
    timers = {}
    timer_keys = ["move_data", "encode", "forward", "backward"]
    for key in timer_keys:
        # nullcontext() turns each `with timers[...]` block into a no-op when
        # timing is disabled, so the loop body reads the same either way
        timers[key] = Timer(key, coordinator=coordinator) if record_time else nullcontext()
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
sampler.set_epoch(epoch)
dataiter = iter(dataloader)
logger.info("Beginning epoch %s...", epoch)
# == training loop in an epoch ==
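        # when resuming mid-epoch, the restored sampler is expected to skip
        # already-consumed samples; enumerate(..., start=start_step) keeps `step` aligned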
with tqdm(
enumerate(dataiter, start=start_step),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
for step, batch in pbar:
                timer_list = []
                bs = batch['video'].shape[0]
                # number of negative (augmented) clips per sample, read from any augmentation type
                neg_num = next(iter(batch['neg_videos'].values())).shape[1]
with timers["move_data"] as move_data_t:
vx = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
ax = batch.pop("audio").to(device, dtype) # [B, 1, T, S]
# [BxN, C, T, H, W]
neg_vx = {aug_type: aug_vx.flatten(0, 1).to(device, dtype) \
for aug_type, aug_vx in batch.pop('neg_videos').items()}
# [BxN, 1, T, S]
neg_ax = {aug_type: aug_ax.flatten(0, 1).to(device, dtype) \
for aug_type, aug_ax in batch.pop('neg_audios').items()}
timer_list.append(move_data_t)
# == vae encoding ==
with timers["encode"] as encode_t:
if load_va_features:
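                        # latents were precomputed offline; only the [B, N, ...] layout
                        # of the flattened negatives needs to be restored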
vdims = vx.shape[1:]
neg_vx = {aug_type: aug_vx.view(bs, neg_num, *vdims) for \
aug_type, aug_vx in neg_vx.items()}
adims = ax.shape[1:]
neg_ax = {aug_type: aug_ax.view(bs, neg_num, *adims) for \
aug_type, aug_ax in neg_ax.items()}
else:
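                        # encode positives and all negatives in a single VAE pass:
                        # concatenate along the batch dim, encode once, then split the
                        # latents back into per-augmentation chunks of the original sizes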
size_list = [vx.shape[0], *[v.shape[0] for v in neg_vx.values()]]
with torch.no_grad():
for x, neg_x, encode_func in \
[[vx, neg_vx, vae.encode], [ax, neg_ax, audio_vae.encode_audio]]:
x = torch.cat([x, *list(neg_x.values())], dim=0)
x = encode_func(x)
x_list = x.split(size_list, dim=0)
dims = x_list[0].shape[1:]
for i, aug_type in enumerate(neg_x.keys()):
neg_x[aug_type] = x_list[i+1].view(bs, neg_num, *dims)
                                # stash the positive latents under a temporary key; the dicts are
                                # mutated in place, so 'raw' survives outside this loop
                                neg_x['raw'] = x_list[0]
                        vx, ax = neg_vx.pop('raw'), neg_ax.pop('raw')
timer_list.append(encode_t)
# == prior extraction & loss calculation ==
with timers["forward"] as forward_t:
text = batch.pop('text')
kwargs = {
'mode': 'calc_loss',
'video': vx, 'audio': ax,
'neg_videos': neg_vx, 'neg_audios': neg_ax,
'frame_width': batch.get('width'),
'frame_height': batch.get('height'),
}
if batch.get('onset', None) is not None:
kwargs.update({'onset': batch['onset'].to(device, dtype)})
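                    # the prior model returns a scalar training loss plus a dict of
                    # per-term values for logging (see the logging block below)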
prior_loss, log_dict = model(text, **kwargs)
timer_list.append(forward_t)
# == generator backward & update ==
with timers["backward"] as backward_t:
optimizer.zero_grad()
booster.backward(loss=prior_loss, optimizer=optimizer)
optimizer.step()
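                    # average across data-parallel ranks so the logged loss is comparable;
                    # gradients are unaffected since backward has already run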
all_reduce_mean(prior_loss)
running_loss += prior_loss.item()
timer_list.append(backward_t)
# == update log info ==
global_step = epoch * num_steps_per_epoch + step
log_step += 1
# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
avg_loss = running_loss / log_step
                    logger.info(
                        "loss: %.3f | step: %s | global_step: %s | %s",
                        avg_loss, step, global_step,
                        " | ".join(f"{k}: {v:.3f}" for k, v in log_dict.items()),
                    )
# tensorboard
tb_writer.add_scalar("loss", prior_loss.item(), global_step)
for k, v in log_dict.items():
tb_writer.add_scalar(k, v, global_step)
# wandb
                    if cfg.get("wandb", False):
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": prior_loss.item(),
"avg_loss": avg_loss,
**log_dict,
},
step=global_step,
)
running_loss = 0.0
log_step = 0
# == checkpoint saving ==
ckpt_every = cfg.get("ckpt_every", 0)
if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
save(
booster,
exp_dir,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
sampler=sampler,
)
dist.barrier()
logger.info(
"Saved checkpoint at epoch %s step %s global_step %s to %s",
epoch,
step + 1,
global_step + 1,
exp_dir,
)
                    if dist.get_rank() == 0 and save_total_limit is not None:
                        ckpt_dirs = glob(os.path.join(exp_dir, 'epoch*-global_step*'))
                        # sort by global_step so the oldest checkpoint is pruned first
                        ckpt_dirs.sort(
                            key=lambda p: int(m.group(1)) if (m := re.search(r'global_step(\d+)', p)) else float('inf')
                        )
                        if len(ckpt_dirs) > save_total_limit:
                            oldest = ckpt_dirs[0]
                            shutil.rmtree(oldest, ignore_errors=True)
                            logger.info("Deleted oldest checkpoint %s to respect cfg.save_total_limit", oldest)
dist.barrier()
if record_time and dist.get_rank() == 0:
log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list:
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
logger.info(log_str)
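                # guard against the step counter overshooting the nominal epoch length,
                # e.g. when enumerate is offset by start_step after a mid-epoch resume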
if step >= num_steps_per_epoch:
break
sampler.reset()
start_step = 0


if __name__ == "__main__":
main()