# JAV-Gen/scripts/train.py
import os
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
from glob import glob
import re
import shutil
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.distributed as dist
import wandb
import colossalai
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from peft import LoraConfig
from tqdm import tqdm
from javisdit.acceleration.checkpoint import set_grad_checkpoint
from javisdit.acceleration.parallel_states import get_data_parallel_group
from javisdit.datasets.datasets import VariableVideoTextDataset
from javisdit.datasets.dataloader import prepare_dataloader
from javisdit.registry import DATASETS, MODELS, SCHEDULERS, build_module
from javisdit.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from javisdit.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from javisdit.utils.lr_scheduler import LinearWarmupLR
from javisdit.utils.misc import (
Timer,
all_reduce_mean,
create_logger,
create_tensorboard_writer,
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from javisdit.utils.train_utils import VAMaskGenerator, create_colossalai_plugin, update_ema
def main():
# ======================================================
# 1. configs & runtime variables
# ======================================================
# == parse configs ==
cfg = parse_configs(training=True)
record_time = cfg.get("record_time", False)
start_from_scratch = cfg.get("start_from_scratch", False)
# == device and dtype ==
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
# == colossalai init distributed training ==
    # NOTE: a very large timeout may be needed here to keep slow processes from exiting early
if cfg.get('host'):
colossalai.launch_from_openmpi(cfg.host, cfg.port)
else:
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=5)) # hours=24
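        # map each process to a local GPU (global rank modulo GPUs per node)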
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(cfg.get("seed", 1024))
coordinator = DistCoordinator()
device = get_current_device()
# == init exp_dir ==
model_name = cfg.model["type"].replace("/", "-")
exp_name, exp_dir = define_experiment_workspace(cfg, model_name=model_name)
coordinator.block_all()
if coordinator.is_master():
os.makedirs(exp_dir, exist_ok=True)
save_training_config(cfg.to_dict(), exp_dir)
coordinator.block_all()
save_total_limit = cfg.get("save_total_limit", None)
# == init logger, tensorboard & wandb ==
logger = create_logger(exp_dir)
logger.info("Experiment directory created at %s", exp_dir)
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
if coordinator.is_master():
tb_writer = create_tensorboard_writer(exp_dir)
if cfg.get("wandb", False):
wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
# == init ColossalAI booster ==
plugin = create_colossalai_plugin(
plugin=cfg.get("plugin", "zero2"),
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
)
booster = Booster(plugin=plugin)
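    # Cap intra-op CPU threads per rank: with several ranks plus dataloader
    # workers on one host, larger thread pools oversubscribe the CPUs.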
torch.set_num_threads(1)
# ======================================================
# 2. build dataset and dataloader
# ======================================================
logger.info("Building dataset...")
# == load preprocessed data ==
load_va_features = cfg.get('load_va_features', False)
audio_only = cfg.get('audio_only', False)
    # == build dataset ==  # TODO(Oct): audio part
dataset = build_module(cfg.dataset, DATASETS, audio_cfg=cfg.get("audio_cfg"), audio_only=audio_only,
load_data=cfg.get("load_data"))
logger.info("Dataset contains %s samples.", len(dataset))
    # == build dataloader ==  # TODO(Oct): audio part
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.get("batch_size", None),
num_workers=cfg.get("num_workers", 4),
seed=cfg.get("seed", 1024),
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
prefetch_factor=cfg.get("prefetch_factor", None),
)
if cfg.get("load", None) is not None and isinstance(dataset, VariableVideoTextDataset) and not start_from_scratch:
sampler_dict = torch.load(os.path.join(cfg.load, "sampler"))
last_micro_batch_access_index = sampler_dict['last_micro_batch_access_index']
dataloader_args['sampler_kwargs'] = {'last_micro_batch_access_index': last_micro_batch_access_index}
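    # Illustrative bucket_config entry (hypothetical numbers; the real schema
    # comes from the experiment config): resolution -> {num_frames: (prob, batch_size)},
    # e.g. bucket_config = {"240p": {51: (1.0, 8)}, "480p": {51: (0.5, 2)}}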
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
**dataloader_args,
)
num_steps_per_epoch = len(dataloader)
# ======================================================
# 3. build model
# ======================================================
logger.info("Building models...")
# == build text-encoder ==
text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
if text_encoder is not None:
text_encoder_output_dim = text_encoder.output_dim
text_encoder_model_max_length = text_encoder.model_max_length
else:
text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)
text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)
# == build prior-encoder ==
prior_encoder = build_module(cfg.get('prior_encoder', None), MODELS)
if prior_encoder is not None:
prior_encoder = prior_encoder.to(device, dtype).eval()
# == build video vae == # TODO
vae = build_module(cfg.get("vae", None), MODELS)
if vae is not None:
vae = vae.to(device, dtype).eval()
if getattr(dataset, "num_frames", None) is not None:
input_size = (dataset.num_frames, *dataset.image_size)
latent_size = vae.get_latent_size(input_size)
else:
latent_size = (None, None, None)
vae_out_channels = vae.out_channels
else:
latent_size = (None, None, None)
vae_out_channels = cfg.get("vae_out_channels", 4)
# == build audio vae ==
audio_vae = build_module(cfg.audio_vae, MODELS, device=device, dtype=dtype)
if audio_vae is None:
audio_vae_out_channels = cfg.get('audio_vae_out_channels', 8)
else:
audio_vae_out_channels = audio_vae.vae_out_channels
# == build javisdit diffusion model ==
model = (
build_module(
cfg.model,
MODELS,
input_size=latent_size,
in_channels=vae_out_channels,
audio_in_channels=audio_vae_out_channels,
caption_channels=text_encoder_output_dim,
model_max_length=text_encoder_model_max_length,
enable_sequence_parallelism=cfg.get("sp_size", 1) > 1,
)
.to(device, dtype)
.train()
)
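    # The diffusion transformer consumes video latents (`in_channels`), audio
    # latents (`audio_in_channels`), and caption embeddings (`caption_channels`);
    # all three widths were derived above from the VAEs and the text encoder.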
# == setup lora ==
lora_enabled = cfg.get("lora_enabled", False)
if lora_enabled:
        # NOTE: enabling LoRA freezes all original parameters; record the trainable ones so they can be unfrozen afterwards
trainable_list = []
for name, param in model.named_parameters():
if param.requires_grad:
trainable_list.append(f'base_model.model.{name}')
lora_pretrained_dir = cfg.get("lora_pretrained_dir", None)
if lora_pretrained_dir is None:
lora_config = LoraConfig(
r=cfg.get('lora_r', 16),
lora_alpha=cfg.get('lora_alpha', 16),
target_modules=cfg.get('lora_target_modules', []),
lora_dropout=cfg.get('lora_dropout', 0),
)
else:
logger.info(f"Loading lora config and weights from {lora_pretrained_dir}")
lora_config = None
model = booster.enable_lora(model, pretrained_dir=lora_pretrained_dir, lora_config=lora_config)
lora_pretrained_path = cfg.get("lora_pretrained_path", None)
if lora_pretrained_path is not None:
lora_state_dict = torch.load(lora_pretrained_path, map_location='cpu')
lora_state_dict = {k.replace('.weight', '.default.weight'): v for k, v in lora_state_dict.items()}
missing_keys, unexpected_keys = model.load_state_dict(lora_state_dict, strict=False)
logger.info(f"{len(lora_state_dict)-len(unexpected_keys)}/{len(lora_state_dict)} keys loaded from {lora_pretrained_path}.")
for name, param in model.named_parameters():
if name in trainable_list:
param.requires_grad_(True)
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
"Trainable model params: %s, Total model params: %s",
format_numel_str(model_numel_trainable),
format_numel_str(model_numel),
)
# == build ema for model ==
ema = deepcopy(model).to(torch.float32).to(device)
requires_grad(ema, False)
ema_shape_dict = record_model_param_shape(ema)
ema.eval()
update_ema(ema, model, decay=0, sharded=False)
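    # update_ema applies the usual exponential moving average,
    #     ema_p <- decay * ema_p + (1 - decay) * p,
    # so the decay=0 call above simply copies the current weights into `ema`.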
# == DPO training ==
dpo_enabled = cfg.get("dpo_enabled", False)
if dpo_enabled:
dpo_beta = cfg.get("dpo_beta", 500)
ref_model = deepcopy(model)
ref_model.requires_grad_(False)
ref_model.eval()
else:
dpo_beta, ref_model = None, None
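    # For DPO, a frozen copy of the current model acts as the reference policy;
    # dpo_beta (presumably the temperature of the DPO objective, consumed inside
    # the scheduler) controls how strongly preferences are enforced.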
# == setup loss function, build scheduler ==
scheduler = build_module(cfg.scheduler, SCHEDULERS)
# == setup optimizer ==
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()),
adamw_mode=True,
lr=cfg.get("lr", 1e-4),
weight_decay=cfg.get("weight_decay", 0),
eps=cfg.get("adam_eps", 1e-8),
)
warmup_steps = cfg.get("warmup_steps", None)
if warmup_steps is None:
lr_scheduler = None
else:
lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=cfg.get("warmup_steps"))
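    # LinearWarmupLR ramps the learning rate linearly up to cfg.lr over
    # `warmup_steps` optimizer steps and holds it constant afterwards.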
# == additional preparation ==
if cfg.get("grad_checkpoint", False):
set_grad_checkpoint(model)
if cfg.get("mask_ratios", None) is not None:
mask_generator = VAMaskGenerator(cfg.mask_ratios)
    if load_va_features:
        # VA latents are preloaded by the dataset, so the VAE encoders are no
        # longer needed; drop the references and free cached GPU memory.
        vae, audio_vae = None, None
        torch.cuda.empty_cache()
# =======================================================
# 4. distributed training preparation with colossalai
# =======================================================
logger.info("Preparing for distributed training...")
# == boosting ==
    # NOTE: set the default dtype first so model initialization is consistent with the mixed-precision dtype, then reset it to fp32 because the diffusion scheduler runs in fp32
torch.set_default_dtype(dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
logger.info("Boosting model for distributed training")
# == global variables ==
cfg_epochs = cfg.get("epochs", 1000)
start_epoch = start_step = log_step = acc_step = 0
running_loss_dict = {'loss': 0.0}
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
# == resume ==
if cfg.get("load", None) is not None:
logger.info("Loading checkpoint")
ret = load(
booster,
cfg.load,
model=model,
ema=ema,
optimizer=optimizer,
lr_scheduler=None if start_from_scratch else lr_scheduler,
sampler=None if start_from_scratch else sampler,
)
if not start_from_scratch:
start_epoch, start_step = ret
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
model_sharding(ema)
# == prepare negprompt text embedding ==
    if cfg.get('neg_prompt', None) is not None:
        y_null_model_args = text_encoder.encode([cfg.neg_prompt])  # yields "y" and "mask"
        y_null_model_args['y_null'] = y_null_model_args.pop('y', None)  # avoid conflict with the per-batch "y"
        y_null_model_args['mask_null'] = y_null_model_args.pop('mask', None)
        # auto-broadcast across the batch, including in DPO mode
        logger.info(f'Using neg_prompt for classifier-free guidance training: {cfg.neg_prompt}')
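    # Encoding the negative prompt once here (instead of every step) gives the
    # scheduler a reusable "null" condition for classifier-free guidance training.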
# =======================================================
# 5. training loop
# =======================================================
dist.barrier()
timers = {}
timer_keys = [
"move_data", "encode", "mask", "diffusion", "backward", "update_ema", "reduce_loss",
]
for key in timer_keys:
if record_time:
timers[key] = Timer(key, coordinator=coordinator)
else:
timers[key] = nullcontext()
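    # With record_time disabled, every timer is a nullcontext, so the
    # `with timers[...]` blocks in the loop below add no overhead.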
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info("Beginning epoch %s...", epoch)
# == training loop in an epoch ==
with tqdm(
enumerate(dataloader_iter, start=start_step),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
initial=start_step,
total=num_steps_per_epoch,
ncols=50
) as pbar:
for step, batch in pbar:
timer_list = []
with timers["move_data"] as move_data_t:
x = batch.pop("video").to(device, dtype) # [B, C, Tv, H, W]
ax = batch.pop("audio").to(device, dtype) # [B, 1, Ta, M] , TODO [B, 1, Ta]
if dpo_enabled:
x_rej = batch.pop("video_reject").to(device, dtype) # [B, C, Tv, H, W]
ax_rej = batch.pop("audio_reject").to(device, dtype) # [B, 1, Ta, M]
x = torch.cat((x, x_rej), dim=0) # [B*2, C, Tv, H, W]
ax = torch.cat((ax, ax_rej), dim=0) # [B*2, 1, Ta, M]
batch_num_frames = batch['num_frames']
batch_fps = batch['fps']
                    batch_audio_fps = batch['audio_fps']  # per-sample audio sample rates, e.g. 16000 Hz
batch_duration = batch_num_frames / batch_fps
                    assert len(torch.unique(batch_duration)) == 1, 'variable durations are temporarily unsupported'
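                    # e.g. (hypothetical numbers) 64 frames at 16 fps -> 4.0 s clips; uniform
                    # durations presumably keep the video and audio latent sequences
                    # temporally aligned within a batch.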
y, raw_text = batch.get("text"), batch.get('raw_text', batch.get("text"))
if record_time:
timer_list.append(move_data_t)
# == visual and text encoding ==
with timers["encode"] as encode_t:
with torch.no_grad():
# Prepare visual and audio inputs
                        if audio_only:  # fake video input: tile the placeholder to the expected latent channel count
x = x.repeat(1, vae_out_channels, 1, 1, 1)
if load_va_features:
x = x.to(device, dtype)
ax = ax.to(device, dtype)
else:
if not audio_only:
                                x = vae.encode(x)  # [B, C, T, H/P, W/P]
                            # NOTE: for the AudioLDM-style VAE the input is an audio
                            # spectrogram; encoded latents are [B, 8, T, D] for
                            # audioldm2 and [B, T, D] for hunyuan
                            ax = audio_vae.encode_audio(ax, batch_audio_fps[0])  # [B, C, T, M]
# Prepare text inputs
if cfg.get("load_text_features", False):
model_args = {"y": y.to(device, dtype)}
mask = batch.pop("mask")
if isinstance(mask, torch.Tensor):
mask = mask.to(device, dtype)
model_args["mask"] = mask
else:
model_args = text_encoder.encode(y)
if dpo_enabled:
model_args["mask"] = torch.cat([model_args["mask"], model_args["mask"]], dim=0)
model_args["y"] = torch.cat([model_args["y"], model_args["y"]], dim=0)
# Prepare spatio-temporal prior
if prior_encoder is not None:
assert not dpo_enabled, "NotImplemented"
model_args.update(prior_encoder.encode(raw_text))
if record_time:
timer_list.append(encode_t)
# == temporal mask ==
with timers["mask"] as mask_t:
mask, ax_mask = None, None
if cfg.get("mask_ratios", None) is not None:
mask, ax_mask = mask_generator.get_masks(x, ax) # shape(B, T)
if dpo_enabled:
mask = torch.cat([mask, mask], dim=0)
ax_mask = torch.cat([ax_mask, ax_mask], dim=0)
model_args["x_mask"] = mask
model_args["ax_mask"] = ax_mask
if record_time:
timer_list.append(mask_t)
# == video meta info ==
for k, v in batch.items():
if isinstance(v, torch.Tensor):
model_args[k] = v.to(device, dtype)
# == prepare neg prompt text embeddings args ==
if cfg.get('neg_prompt', None) is not None:
model_args.update(y_null_model_args)
# == prepare training mode args ==
model_args.update({
'audio_only': audio_only,
'dpo_enabled': dpo_enabled, 'dpo_beta': dpo_beta, 'ref_model': ref_model
})
# == diffusion loss computation ==
with timers["diffusion"] as loss_t:
# loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
x = {'video': x, 'audio': ax}
mask = {'video': mask, 'audio': ax_mask}
loss_dict = scheduler.multimodal_training_losses(model, x, model_args, mask=mask)
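                        # loss_dict carries the joint "loss" plus any auxiliary terms;
                        # keys other than "loss" are reduced and logged separately below.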
if record_time:
timer_list.append(loss_t)
# == backward & update ==
with timers["backward"] as backward_t:
loss = loss_dict["loss"].mean()
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()
# update learning rate
if lr_scheduler is not None:
lr_scheduler.step()
if record_time:
timer_list.append(backward_t)
# == update EMA ==
with timers["update_ema"] as ema_t:
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
if record_time:
timer_list.append(ema_t)
# == update log info ==
with timers["reduce_loss"] as reduce_loss_t:
all_reduce_mean(loss)
running_loss_dict['loss'] += loss.item()
for k, v in loss_dict.items():
if k != "loss":
if k not in running_loss_dict:
running_loss_dict[k] = 0.0
running_loss_dict[k] += all_reduce_mean(v).item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
if record_time:
timer_list.append(reduce_loss_t)
# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
avg_loss = {}
for k, v in running_loss_dict.items():
avg_loss[k] = v / log_step
# progress bar
print_loss = {k: f"{v:.4f}" for k, v in avg_loss.items()}
pbar.set_postfix({**print_loss, "step": step, "global_step": global_step})
logger.info({**print_loss, "step": step, "global_step": global_step})
# tensorboard
for k, v in avg_loss.items():
tb_writer.add_scalar(k, v, global_step)
# wandb
if cfg.get("wandb", False):
wandb_dict = {
"iter": global_step,
"acc_step": acc_step,
"epoch": epoch,
"loss": loss.item(),
**{f"avg_loss_{k}": v for k, v in avg_loss.items()},
"lr": optimizer.param_groups[0]["lr"],
}
if record_time:
wandb_dict.update(
{
"debug/move_data_time": move_data_t.elapsed_time,
"debug/encode_time": encode_t.elapsed_time,
"debug/mask_time": mask_t.elapsed_time,
"debug/diffusion_time": loss_t.elapsed_time,
"debug/backward_time": backward_t.elapsed_time,
"debug/update_ema_time": ema_t.elapsed_time,
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
}
)
wandb.log(wandb_dict, step=global_step)
running_loss_dict = {"loss": 0.0}
log_step = 0
# == checkpoint saving ==
ckpt_every = cfg.get("ckpt_every", 0)
if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
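                    # Gather the ZeRO-sharded EMA weights into full tensors (shapes in
                    # ema_shape_dict) so a complete checkpoint can be saved; rank 0
                    # re-shards them below after writing.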
model_gathering(ema, ema_shape_dict)
dist.barrier()
save_dir = save(
booster,
exp_dir,
model=model,
ema=ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=sampler,
epoch=epoch,
step=step + 1,
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
lora_enabled=lora_enabled,
lora_dir=cfg.get("lora_dir", "lora")
)
if dist.get_rank() == 0:
model_sharding(ema)
logger.info(
"Saved checkpoint at epoch %s, step %s, global_step %s to %s",
epoch,
step + 1,
global_step + 1,
save_dir,
)
exp_dir_list = glob(os.path.join(exp_dir, 'epoch*-global_step*'))
exp_dir_list.sort(key=lambda x: int(re.search(r'global_step(\d+)', x).group(1)) if re.search(r'global_step(\d+)', x) else float('inf'))
if save_total_limit is not None and len(exp_dir_list) > save_total_limit:
checkpoint = exp_dir_list[0]
shutil.rmtree(checkpoint, ignore_errors=True)
logger.info(f"{checkpoint} has been deleted successfully as cfg.save_total_limit!")
dist.barrier()
if record_time:
log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list:
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
logger.info(log_str)
sampler.reset()
start_step = 0
torch.cuda.empty_cache()
model_gathering(ema, ema_shape_dict)
save_dir = save(
booster, exp_dir,
model=model, ema=ema, optimizer=optimizer, lr_scheduler=lr_scheduler, sampler=sampler,
epoch=epoch, step=step + 1, global_step=global_step + 1, batch_size=cfg.get("batch_size", None),
lora_enabled=lora_enabled, lora_dir=cfg.get("lora_dir", "lora")
)
logger.info(
"Saved final checkpoint at epoch %s, step %s, global_step %s to %s",
epoch, step + 1, global_step + 1, save_dir,
)
dist.barrier()
if __name__ == "__main__":
main()
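# Example launch (hypothetical paths and flags; adapt to your cluster):
#   torchrun --standalone --nproc_per_node 8 scripts/train.py configs/train_config.py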