import os
from pprint import pformat

import colossalai
import torch
import torch.distributed as dist
from tqdm import tqdm

from javisdit.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from javisdit.datasets.dataloader import prepare_dataloader
from javisdit.datasets.datasets import VariableVideoTextDataset
from javisdit.registry import DATASETS, MODELS, build_module
from javisdit.utils.config_utils import parse_configs, save_training_config
from javisdit.utils.misc import FeatureSaver, Timer, create_logger, to_torch_dtype


def main():
    torch.set_grad_enabled(False)

    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Feature extraction currently requires at least one GPU."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
|
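    # == init distributed env ==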
    colossalai.launch_from_torch({})
    set_data_parallel_group(dist.group.WORLD)
|
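    # == init logger ==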
    logger = create_logger()
    logger.info("Configuration:\n %s", pformat(cfg.to_dict()))
|
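    # ======================================================
    # runtime variables
    # ======================================================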
    # == feature extraction switches ==
    bin_size = cfg.bin_size
    save_text_features = cfg.get("save_text_features", False)
    save_compressed_text_features = cfg.get("save_compressed_text_features", False)
    if save_compressed_text_features:
        raise NotImplementedError
    save_text_only = cfg.get("save_text_only", False)

    # == range of bins to process ==
    start_index = cfg.get("start_index", 0)
    end_index = cfg.get("end_index")
    last_micro_batch_access_index = start_index * bin_size
    start_step = 0

    # == prepare save dir ==
    assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
    save_dir = os.path.join(cfg.save_dir, f"s{start_index}_e{end_index}")
    os.makedirs(save_dir, exist_ok=True)
    save_training_config(cfg.to_dict(), save_dir)
    logger.info("Saving features to %s", save_dir)

    # == init feature saver and resume from previously saved steps ==
    saver = FeatureSaver(save_dir, bin_size, start_bin=start_index)
    start_step = saver.get_num_saved()
    if start_step > 0:
        logger.info("Found existing data. Start from step %s.", start_step)
        last_micro_batch_access_index += start_step
|
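    # ======================================================
    # build dataset and dataloader
    # ======================================================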
    logger.info("Building dataset...")
    dataset = build_module(cfg.dataset, DATASETS, audio_cfg=cfg.get("audio_cfg"))
    logger.info("Dataset contains %s samples.", len(dataset))

    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=1,
        sampler_kwargs={},
    )
    assert isinstance(dataset, VariableVideoTextDataset)
    # resume the bucketed sampler from the last accessed micro batch
    dataloader_args["sampler_kwargs"] = {"last_micro_batch_access_index": last_micro_batch_access_index}

    dataloader, _ = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)

    # == bin accounting ==
    num_bin = num_steps_per_epoch // bin_size
    logger.info("Number of batches: %s", num_steps_per_epoch)
    logger.info("Bin size: %s", bin_size)
    logger.info("Number of bins: %s", num_bin)
    if end_index is None:
        end_index = num_bin
    num_bin_to_process = min(num_bin, end_index) - start_index
    logger.info("Start index: %s", start_index)
    logger.info("End index: %s", end_index)
    logger.info("Number of bins to process: %s", num_bin_to_process)
|
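    # ======================================================
    # build models
    # ======================================================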
    logger.info("Building models...")
    text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
    audio_vae = build_module(cfg.audio_vae, MODELS, device=device, dtype=dtype)
|
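    # ======================================================
    # feature extraction loop
    # ======================================================
    # Each step pulls one pre-bucketed batch, encodes video/audio with the
    # frozen VAEs (and optionally text), and hands the result to FeatureSaver,
    # which appears to flush a bin to disk once `bin_size` batches accumulate.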
    dataloader_iter = iter(dataloader)
    log_time = cfg.get("log_time", False)
    total_steps = num_bin_to_process * bin_size
    for _ in tqdm(range(start_step, total_steps), initial=start_step, total=total_steps):
        with Timer("step", log=log_time):
            with Timer("data loading", log=log_time):
                batch = next(dataloader_iter)
                neg_vx, neg_ax = None, None
                if not save_text_only:
                    vx = batch.pop("video").to(device, dtype)
                    ax = batch.pop("audio").to(device, dtype)
                    if dataset.neg_aug:
                        # negative augmentations arrive as (B, N, ...); flatten to (B*N, ...) for encoding
                        neg_vx = {aug_type: aug_vx.flatten(0, 1).to(device, dtype)
                                  for aug_type, aug_vx in batch.pop("neg_videos").items()}
                        neg_ax = {aug_type: aug_ax.flatten(0, 1).to(device, dtype)
                                  for aug_type, aug_ax in batch.pop("neg_audios").items()}
                else:
                    vx, ax = None, None
                y = batch.get("text")
                raw_text = batch.get("raw_text", batch.get("text"))
                batch_num_frames = batch["num_frames"]
                batch_fps = batch["fps"]
                batch_duration = batch_num_frames / batch_fps
                assert len(torch.unique(batch_duration)) == 1, "variable durations temporarily unsupported"

            if not save_text_only:
                with Timer("vae", log=log_time):
                    if neg_vx is not None:
                        # encode positive and negative clips in a single forward pass, then split
                        size_list = [vx.shape[0], *[v.shape[0] for v in neg_vx.values()]]
                        bs, neg_num = vx.shape[0], dataset.neg_aug
                        for x, neg_x, encode_func in \
                                [[vx, neg_vx, vae.encode], [ax, neg_ax, audio_vae.encode_audio]]:
                            x = torch.cat([x, *list(neg_x.values())], dim=0)
                            x = encode_func(x)
                            x_list = x.split(size_list, dim=0)
                            dims = x_list[0].shape[1:]
                            # restore the (B, N, ...) layout of the negative latents
                            for i, aug_type in enumerate(neg_x.keys()):
                                neg_x[aug_type] = x_list[i + 1].view(bs, neg_num, *dims)
                            neg_x["raw"] = x_list[0]
                        vx, ax = neg_vx.pop("raw"), neg_ax.pop("raw")
                    else:
                        vx = vae.encode(vx)
                        ax = audio_vae.encode_audio(ax)
                with Timer("feature to cpu", log=log_time):
                    vx = vx.cpu()
                    ax = ax.cpu()
                    if dataset.neg_aug:
                        neg_vx = {k: v.cpu() for k, v in neg_vx.items()}
                        neg_ax = {k: v.cpu() for k, v in neg_ax.items()}

            batch_dict = {
                "index": batch["index"],
                "x": vx,
                "ax": ax,
                "text": y,
                "raw_text": raw_text,
                "fps": batch["fps"].to(dtype),
                "audio_fps": batch["audio_fps"].to(dtype),
                "height": batch["height"].to(dtype),
                "width": batch["width"].to(dtype),
                "num_frames": batch["num_frames"].to(dtype),
            }
            if dataset.neg_aug:
                batch_dict.update({
                    "neg_vx": neg_vx,
                    "neg_ax": neg_ax,
                })
            if dataset.require_onset:
                batch_dict.update({
                    "onset": batch["onset"].to(dtype),
                })

            if save_text_features:
                with Timer("text", log=log_time):
                    text_infos = text_encoder.encode(y)
                    y_feat = text_infos["y"]
                    y_mask = text_infos["mask"]
                with Timer("feature to cpu", log=log_time):
                    y_feat = y_feat.cpu()
                    y_mask = y_mask.cpu()
                batch_dict.update({
                    "y": y_feat, "mask": y_mask,
                })

            saver.update(batch_dict)
            saver.save()


if __name__ == "__main__":
main() |
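
# A typical launch (illustrative; colossalai.launch_from_torch reads the env
# vars that torchrun sets, and the script/config paths here are placeholders):
#   torchrun --standalone --nproc_per_node 1 <this_script>.py <config>.py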