import os
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
from glob import glob
import re
import shutil
import pdb

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings

warnings.filterwarnings('ignore')

import torch
import torch.distributed as dist
import wandb

import colossalai
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from peft import LoraConfig
from tqdm import tqdm

from javisdit.acceleration.checkpoint import set_grad_checkpoint
from javisdit.acceleration.parallel_states import get_data_parallel_group
from javisdit.datasets.datasets import VariableVideoTextDataset
from javisdit.datasets.dataloader import prepare_dataloader
from javisdit.registry import DATASETS, MODELS, SCHEDULERS, build_module
from javisdit.utils.ckpt_utils import load, load_checkpoint, model_gathering, model_sharding, record_model_param_shape, save
from javisdit.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from javisdit.utils.lr_scheduler import LinearWarmupLR
from javisdit.utils.misc import (
    Timer,
    all_reduce_mean,
    create_logger,
    create_tensorboard_writer,
    format_numel_str,
    get_model_numel,
    requires_grad,
    to_torch_dtype,
    check_exist_pickle,
)
from javisdit.utils.train_utils import VAMaskGenerator, create_colossalai_plugin, update_ema
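
# NOTE: a typical multi-GPU launch looks roughly like the following (hypothetical
# entry-point and config names; adapt them to your setup):
#   torchrun --standalone --nproc_per_node=8 train.py configs/train.py
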

def main():
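    # == parse configs and runtime variables ==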
    cfg = parse_configs(training=True)
    record_time = cfg.get("record_time", False)
    start_from_scratch = cfg.get("start_from_scratch", False)

    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
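    # == init distributed training ==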
    if cfg.get('host'):
        colossalai.launch_from_openmpi(cfg.host, cfg.port)
    else:
        dist.init_process_group(backend="nccl", timeout=timedelta(minutes=5))
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()
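    # == init experiment workspace ==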
    model_name = cfg.model["type"].replace("/", "-")
    exp_name, exp_dir = define_experiment_workspace(cfg, model_name=model_name)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()
    save_total_limit = cfg.get("save_total_limit", None)
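    # == init logger, tensorboard & wandb ==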
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
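    # == init ColossalAI booster ==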
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
        reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
    )
    booster = Booster(plugin=plugin)
    torch.set_num_threads(1)
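    # == build dataset ==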
    logger.info("Building dataset...")
    load_va_features = cfg.get('load_va_features', False)
    audio_only = cfg.get('audio_only', False)
    dataset = build_module(cfg.dataset, DATASETS, audio_cfg=cfg.get("audio_cfg"), audio_only=audio_only,
                           load_data=cfg.get("load_data"))
    logger.info("Dataset contains %s samples.", len(dataset))
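    # == build dataloader ==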
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    if cfg.get("load", None) is not None and isinstance(dataset, VariableVideoTextDataset) and not start_from_scratch:
        sampler_dict = torch.load(os.path.join(cfg.load, "sampler"))
        last_micro_batch_access_index = sampler_dict['last_micro_batch_access_index']
        dataloader_args['sampler_kwargs'] = {'last_micro_batch_access_index': last_micro_batch_access_index}
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)
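    # == build models: text encoder ==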
    logger.info("Building models...")
    text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
    if text_encoder is not None:
        text_encoder_output_dim = text_encoder.output_dim
        text_encoder_model_max_length = text_encoder.model_max_length
    else:
        text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)
        text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)
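    # == (optional) prior encoder ==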
    prior_encoder = build_module(cfg.get('prior_encoder', None), MODELS)
    if prior_encoder is not None:
        prior_encoder = prior_encoder.to(device, dtype).eval()
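    # == (optional) video VAE ==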
    vae = build_module(cfg.get("vae", None), MODELS)
    if vae is not None:
        vae = vae.to(device, dtype).eval()
        if getattr(dataset, "num_frames", None) is not None:
            input_size = (dataset.num_frames, *dataset.image_size)
            latent_size = vae.get_latent_size(input_size)
        else:
            latent_size = (None, None, None)
        vae_out_channels = vae.out_channels
    else:
        latent_size = (None, None, None)
        vae_out_channels = cfg.get("vae_out_channels", 4)
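    # == audio VAE ==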
    audio_vae = build_module(cfg.audio_vae, MODELS, device=device, dtype=dtype)
    if audio_vae is None:
        audio_vae_out_channels = cfg.get('audio_vae_out_channels', 8)
    else:
        audio_vae_out_channels = audio_vae.vae_out_channels
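    # == main diffusion model ==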
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae_out_channels,
            audio_in_channels=audio_vae_out_channels,
            caption_channels=text_encoder_output_dim,
            model_max_length=text_encoder_model_max_length,
            enable_sequence_parallelism=cfg.get("sp_size", 1) > 1,
        )
        .to(device, dtype)
        .train()
    )
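    # == (optional) LoRA fine-tuning ==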
    lora_enabled = cfg.get("lora_enabled", False)
    if lora_enabled:
        trainable_list = []
        for name, param in model.named_parameters():
            if param.requires_grad:
                trainable_list.append(f'base_model.model.{name}')

        lora_pretrained_dir = cfg.get("lora_pretrained_dir", None)
        if lora_pretrained_dir is None:
            lora_config = LoraConfig(
                r=cfg.get('lora_r', 16),
                lora_alpha=cfg.get('lora_alpha', 16),
                target_modules=cfg.get('lora_target_modules', []),
                lora_dropout=cfg.get('lora_dropout', 0),
            )
        else:
            logger.info(f"Loading lora config and weights from {lora_pretrained_dir}")
            lora_config = None
        model = booster.enable_lora(model, pretrained_dir=lora_pretrained_dir, lora_config=lora_config)

        lora_pretrained_path = cfg.get("lora_pretrained_path", None)
        if lora_pretrained_path is not None:
            lora_state_dict = torch.load(lora_pretrained_path, map_location='cpu')
            lora_state_dict = {k.replace('.weight', '.default.weight'): v for k, v in lora_state_dict.items()}
            missing_keys, unexpected_keys = model.load_state_dict(lora_state_dict, strict=False)
            logger.info(f"{len(lora_state_dict) - len(unexpected_keys)}/{len(lora_state_dict)} keys loaded from {lora_pretrained_path}.")

        for name, param in model.named_parameters():
            if name in trainable_list:
                param.requires_grad_(True)
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )
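    # == build EMA model ==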
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)
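    # == (optional) DPO reference model ==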
    dpo_enabled = cfg.get("dpo_enabled", False)
    if dpo_enabled:
        dpo_beta = cfg.get("dpo_beta", 500)
        ref_model = deepcopy(model)
        ref_model.requires_grad_(False)
        ref_model.eval()
    else:
        dpo_beta, ref_model = None, None
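    # == build scheduler, optimizer & lr scheduler ==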
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    warmup_steps = cfg.get("warmup_steps", None)
    if warmup_steps is None:
        lr_scheduler = None
    else:
        lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=cfg.get("warmup_steps"))
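    # == gradient checkpointing & mask generator ==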
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = VAMaskGenerator(cfg.mask_ratios)
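    # when pre-extracted VA features are loaded from disk, the VAE encoders are unused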
    if load_va_features:
        vae, audio_vae = None, None
        torch.cuda.empty_cache()
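    # == boost model, optimizer & dataloader for distributed training ==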
    logger.info("Preparing for distributed training...")
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")
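    # == training variables ==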
    cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = acc_step = 0
    running_loss_dict = {'loss': 0.0}
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
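    # == resume training from checkpoint ==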
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=None if start_from_scratch else lr_scheduler,
            sampler=None if start_from_scratch else sampler,
        )
        if not start_from_scratch:
            start_epoch, start_step = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)

    model_sharding(ema)
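    # == (optional) null-text condition for classifier-free guidance ==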
    if cfg.get('neg_prompt', None) is not None:
        y_null_model_args = text_encoder.encode([cfg.neg_prompt])
        y_null_model_args['y_null'] = y_null_model_args.pop('y', None)
        y_null_model_args['mask_null'] = y_null_model_args.pop('mask', None)
        logger.info(f'Using neg_prompt for classifier-free guidance training: {cfg.neg_prompt}')
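    # == synchronize ranks and set up per-step timers ==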
    dist.barrier()
    timers = {}
    timer_keys = [
        "move_data", "encode", "mask", "diffusion", "backward", "update_ema", "reduce_loss",
    ]
    for key in timer_keys:
        if record_time:
            timers[key] = Timer(key, coordinator=coordinator)
        else:
            timers[key] = nullcontext()
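    # == training loop ==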
    for epoch in range(start_epoch, cfg_epochs):
        sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
            ncols=50,
        ) as pbar:
            for step, batch in pbar:
                timer_list = []
                with timers["move_data"] as move_data_t:
                    x = batch.pop("video").to(device, dtype)
                    ax = batch.pop("audio").to(device, dtype)
                    if dpo_enabled:
                        x_rej = batch.pop("video_reject").to(device, dtype)
                        ax_rej = batch.pop("audio_reject").to(device, dtype)
                        x = torch.cat((x, x_rej), dim=0)
                        ax = torch.cat((ax, ax_rej), dim=0)
                    batch_num_frames = batch['num_frames']
                    batch_fps = batch['fps']
                    batch_audio_fps = batch['audio_fps']

                    batch_duration = batch_num_frames / batch_fps
                    assert len(torch.unique(batch_duration)) == 1, 'variable durations are temporarily unsupported'
                    y, raw_text = batch.get("text"), batch.get('raw_text', batch.get("text"))
                if record_time:
                    timer_list.append(move_data_t)
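                # encode video/audio into latents (or reuse preloaded features) and embed text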
                with timers["encode"] as encode_t:
                    with torch.no_grad():
                        if audio_only:
                            x = x.repeat(1, vae_out_channels, 1, 1, 1)
                        if load_va_features:
                            x = x.to(device, dtype)
                            ax = ax.to(device, dtype)
                        else:
                            if not audio_only:
                                x = vae.encode(x)
                            ax = audio_vae.encode_audio(ax, batch_audio_fps[0])

                        if cfg.get("load_text_features", False):
                            model_args = {"y": y.to(device, dtype)}
                            mask = batch.pop("mask")
                            if isinstance(mask, torch.Tensor):
                                mask = mask.to(device, dtype)
                            model_args["mask"] = mask
                        else:
                            model_args = text_encoder.encode(y)
                        if dpo_enabled:
                            model_args["mask"] = torch.cat([model_args["mask"], model_args["mask"]], dim=0)
                            model_args["y"] = torch.cat([model_args["y"], model_args["y"]], dim=0)

                        if prior_encoder is not None:
                            assert not dpo_enabled, "NotImplemented"
                            model_args.update(prior_encoder.encode(raw_text))
                if record_time:
                    timer_list.append(encode_t)
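                # sample random video/audio masks when mask_ratios is configured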
                with timers["mask"] as mask_t:
                    mask, ax_mask = None, None
                    if cfg.get("mask_ratios", None) is not None:
                        mask, ax_mask = mask_generator.get_masks(x, ax)
                        if dpo_enabled:
                            mask = torch.cat([mask, mask], dim=0)
                            ax_mask = torch.cat([ax_mask, ax_mask], dim=0)
                        model_args["x_mask"] = mask
                        model_args["ax_mask"] = ax_mask
                if record_time:
                    timer_list.append(mask_t)
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        model_args[k] = v.to(device, dtype)

                if cfg.get('neg_prompt', None) is not None:
                    model_args.update(y_null_model_args)

                model_args.update({
                    'audio_only': audio_only,
                    'dpo_enabled': dpo_enabled, 'dpo_beta': dpo_beta, 'ref_model': ref_model
                })
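                # compute the multimodal diffusion training losses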
                with timers["diffusion"] as loss_t:
                    x = {'video': x, 'audio': ax}
                    mask = {'video': mask, 'audio': ax_mask}
                    loss_dict = scheduler.multimodal_training_losses(model, x, model_args, mask=mask)
                if record_time:
                    timer_list.append(loss_t)
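                # backward pass, optimizer step and lr scheduler step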
                with timers["backward"] as backward_t:
                    loss = loss_dict["loss"].mean()
                    booster.backward(loss=loss, optimizer=optimizer)
                    optimizer.step()
                    optimizer.zero_grad()

                    if lr_scheduler is not None:
                        lr_scheduler.step()
                if record_time:
                    timer_list.append(backward_t)
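                # update EMA weights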
                with timers["update_ema"] as ema_t:
                    update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
                if record_time:
                    timer_list.append(ema_t)
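                # reduce loss across ranks and update running statistics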
                with timers["reduce_loss"] as reduce_loss_t:
                    all_reduce_mean(loss)
                    running_loss_dict['loss'] += loss.item()
                    for k, v in loss_dict.items():
                        if k != "loss":
                            if k not in running_loss_dict:
                                running_loss_dict[k] = 0.0
                            running_loss_dict[k] += all_reduce_mean(v).item()
                    global_step = epoch * num_steps_per_epoch + step
                    log_step += 1
                    acc_step += 1
                if record_time:
                    timer_list.append(reduce_loss_t)
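                # log to the progress bar, tensorboard and (optionally) wandb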
                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
                    avg_loss = {}
                    for k, v in running_loss_dict.items():
                        avg_loss[k] = v / log_step

                    print_loss = {k: f"{v:.4f}" for k, v in avg_loss.items()}
                    pbar.set_postfix({**print_loss, "step": step, "global_step": global_step})
                    logger.info({**print_loss, "step": step, "global_step": global_step})

                    for k, v in avg_loss.items():
                        tb_writer.add_scalar(k, v, global_step)

                    if cfg.get("wandb", False):
                        wandb_dict = {
                            "iter": global_step,
                            "acc_step": acc_step,
                            "epoch": epoch,
                            "loss": loss.item(),
                            **{f"avg_loss_{k}": v for k, v in avg_loss.items()},
                            "lr": optimizer.param_groups[0]["lr"],
                        }
                        if record_time:
                            wandb_dict.update(
                                {
                                    "debug/move_data_time": move_data_t.elapsed_time,
                                    "debug/encode_time": encode_t.elapsed_time,
                                    "debug/mask_time": mask_t.elapsed_time,
                                    "debug/diffusion_time": loss_t.elapsed_time,
                                    "debug/backward_time": backward_t.elapsed_time,
                                    "debug/update_ema_time": ema_t.elapsed_time,
                                    "debug/reduce_loss_time": reduce_loss_t.elapsed_time,
                                }
                            )
                        wandb.log(wandb_dict, step=global_step)

                    running_loss_dict = {"loss": 0.0}
                    log_step = 0
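                # periodically save checkpoints and prune old ones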
                ckpt_every = cfg.get("ckpt_every", 0)
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    model_gathering(ema, ema_shape_dict)
                    dist.barrier()
                    save_dir = save(
                        booster,
                        exp_dir,
                        model=model,
                        ema=ema,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        sampler=sampler,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=cfg.get("batch_size", None),
                        lora_enabled=lora_enabled,
                        lora_dir=cfg.get("lora_dir", "lora"),
                    )
                    if dist.get_rank() == 0:
                        model_sharding(ema)

                    logger.info(
                        "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        save_dir,
                    )
                    exp_dir_list = glob(os.path.join(exp_dir, 'epoch*-global_step*'))
                    exp_dir_list.sort(key=lambda x: int(re.search(r'global_step(\d+)', x).group(1)) if re.search(r'global_step(\d+)', x) else float('inf'))
                    if save_total_limit is not None and len(exp_dir_list) > save_total_limit:
                        checkpoint = exp_dir_list[0]
                        shutil.rmtree(checkpoint, ignore_errors=True)
                        logger.info(f"Deleted old checkpoint {checkpoint} to respect cfg.save_total_limit.")
                    dist.barrier()

                if record_time:
                    log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
                    for timer in timer_list:
                        log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
                    logger.info(log_str)
        sampler.reset()
        start_step = 0
        torch.cuda.empty_cache()
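    # == save the final checkpoint after training ==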
    model_gathering(ema, ema_shape_dict)
    save_dir = save(
        booster, exp_dir,
        model=model, ema=ema, optimizer=optimizer, lr_scheduler=lr_scheduler, sampler=sampler,
        epoch=epoch, step=step + 1, global_step=global_step + 1, batch_size=cfg.get("batch_size", None),
        lora_enabled=lora_enabled, lora_dir=cfg.get("lora_dir", "lora")
    )
    logger.info(
        "Saved final checkpoint at epoch %s, step %s, global_step %s to %s",
        epoch, step + 1, global_step + 1, save_dir,
    )
    dist.barrier()

if __name__ == "__main__":
    main()