# JAV-Gen/scripts/misc/extract_feat.py
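# Offline feature-extraction script: encodes videos and audio with the (audio) VAE,
# optionally encodes captions with the text encoder, and stores the results in
# fixed-size bins via FeatureSaver for later training.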
import os
from pprint import pformat
from itertools import islice
import colossalai
import torch
import torch.distributed as dist
from tqdm import tqdm
from javisdit.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from javisdit.datasets.dataloader import prepare_dataloader
from javisdit.datasets.datasets import VariableVideoTextDataset
from javisdit.registry import DATASETS, MODELS, build_module
from javisdit.utils.config_utils import parse_configs, save_training_config
from javisdit.utils.misc import FeatureSaver, Timer, create_logger, to_torch_dtype
def main():
torch.set_grad_enabled(False)
# ======================================================
# 1. configs & runtime variables
# ======================================================
# == parse configs ==
cfg = parse_configs(training=False)
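# parse_configs is expected to yield at least the keys read below; a rough,
# non-authoritative sketch of such a config (all values are placeholders):
#   dtype = "bf16"
#   bin_size = 250                      # batches per saved feature bin
#   save_dir = "/path/to/features"
#   start_index, end_index = 0, 100     # range of bins to extract
#   save_text_features = True
#   dataset = dict(type="VariableVideoTextDataset", ...)
#   bucket_config = {...}
#   vae = dict(...); audio_vae = dict(...); text_encoder = dict(...)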
# == device and dtype ==
assert torch.cuda.is_available(), "Feature extraction currently requires at least one GPU."
device = "cuda"
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
dtype = to_torch_dtype(cfg_dtype)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# == colossalai init distributed training ==
colossalai.launch_from_torch({})
set_data_parallel_group(dist.group.WORLD)
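# All ranks share one data-parallel group; the dataloader below presumably uses it
# to shard batches across processes so each rank extracts a disjoint subset.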
# == init logger ==
logger = create_logger()
logger.info("Configuration:\n %s", pformat(cfg.to_dict()))
# ======================================================
# 2. build dataset and dataloader
# ======================================================
# == global variables ==
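# bin_size: number of dataloader batches accumulated into one saved feature bin.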
bin_size = cfg.bin_size
save_text_features = cfg.get("save_text_features", False)
save_compressed_text_features = cfg.get("save_compressed_text_features", False)
if save_compressed_text_features:
raise NotImplementedError
save_text_only = cfg.get("save_text_only", False)
# resume from a specific bin index
start_index = cfg.get("start_index", 0)
end_index = cfg.get("end_index")
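# The sampler resume position is counted in batches, so convert the starting bin
# index into a batch offset.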
last_micro_batch_access_index = start_index * bin_size
start_step = 0
# create save directory
assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
save_dir = os.path.join(cfg.save_dir, f"s{start_index}_e{end_index}")
os.makedirs(save_dir, exist_ok=True)
save_training_config(cfg.to_dict(), save_dir)
logger.info("Saving features to %s", save_dir)
saver = FeatureSaver(save_dir, bin_size, start_bin=start_index)
start_step = saver.get_num_saved()
if start_step > 0:
logger.info(f'Found existing data. Start from step {start_step}.')
last_micro_batch_access_index += start_step
logger.info("Building dataset...")
# == build dataset ==
dataset = build_module(cfg.dataset, DATASETS, audio_cfg=cfg.get("audio_cfg"))
logger.info("Dataset contains %s samples.", len(dataset))
# == build dataloader ==
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.get("batch_size", None),
num_workers=cfg.get("num_workers", 4),
seed=cfg.get("seed", 1024),
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
prefetch_factor=1,
sampler_kwargs={}
)
assert isinstance(dataset, VariableVideoTextDataset)
dataloader_args['sampler_kwargs'] = {'last_micro_batch_access_index': last_micro_batch_access_index}
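# Tell the variable-video sampler to resume from the computed batch offset so that
# already-extracted batches are skipped.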
dataloader, _ = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
**dataloader_args,
)
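# With a bucket_config, samples of similar resolution / frame count are assumed to be
# grouped into the same batch, which keeps tensor shapes uniform within a batch.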
num_steps_per_epoch = len(dataloader)
# dataloader.batch_sampler.load_state_dict({"last_micro_batch_access_index": start_index})
# == number of bins ==
num_bin = num_steps_per_epoch // bin_size
logger.info("Number of batches: %s", num_steps_per_epoch)
logger.info("Bin size: %s", bin_size)
logger.info("Number of bins: %s", num_bin)
num_bin_to_process = (min(num_bin, end_index) if end_index is not None else num_bin) - start_index
logger.info("Start index: %s", start_index)
logger.info("End index: %s", end_index)
logger.info("Number of batches to process: %s", num_bin_to_process)
# ======================================================
# 3. build model
# ======================================================
logger.info("Building models...")
# == build text-encoder and vae ==
text_encoder = build_module(cfg.get('text_encoder', None), MODELS, device=device, dtype=dtype)
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
audio_vae = build_module(cfg.audio_vae, MODELS, device=device, dtype=dtype)
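# Only frozen encoders are needed here (gradients are globally disabled above);
# no diffusion model is built for feature extraction.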
# =======================================================
# 4. feature extraction loop
# =======================================================
# == extraction loop over the selected bins ==
dataloader_iter = iter(dataloader)
log_time = cfg.get("log_time", False)
total_steps = num_bin_to_process * bin_size
for _ in tqdm(range(start_step, total_steps), initial=start_step, total=total_steps):
with Timer("step", log=log_time):
with Timer("data loading", log=log_time):
batch = next(dataloader_iter)
neg_vx, neg_ax = None, None
if not save_text_only:
vx = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
ax = batch.pop("audio").to(device, dtype) # [B, C, T, M]
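# Negative-augmentation clips are expected as [B, N_aug, ...] per augmentation type;
# flatten the batch and augmentation dims so they can be encoded alongside the positives.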
if dataset.neg_aug:
neg_vx = {aug_type: aug_vx.flatten(0, 1).to(device, dtype) \
for aug_type, aug_vx in batch.pop('neg_videos').items()}
neg_ax = {aug_type: aug_ax.flatten(0, 1).to(device, dtype) \
for aug_type, aug_ax in batch.pop('neg_audios').items()}
else:
vx, ax = None, None
y = batch.get("text")
raw_text = batch.get('raw_text', batch.get("text"))
batch_num_frames = batch['num_frames']
batch_fps = batch['fps']
batch_duration = batch_num_frames / batch_fps
assert len(torch.unique(batch_duration)) == 1, 'variable durations temporarily unsupported'
if not save_text_only:
with Timer("vae", log=log_time):
if neg_vx is not None:
size_list = [vx.shape[0], *[v.shape[0] for v in neg_vx.values()]]
bs, neg_num = vx.shape[0], dataset.neg_aug
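# Encode positives and all negative-augmentation clips in a single forward pass,
# then split the encoded batch back per augmentation type; the encoded positives
# are recovered from the temporary 'raw' entry right after the loop.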
for x, neg_x, encode_func in \
[[vx, neg_vx, vae.encode], [ax, neg_ax, audio_vae.encode_audio]]:
x = torch.cat([x, *list(neg_x.values())], dim=0)
x = encode_func(x)
x_list = x.split(size_list, dim=0)
dims = x_list[0].shape[1:]
for i, aug_type in enumerate(neg_x.keys()):
neg_x[aug_type] = x_list[i+1].view(bs, neg_num, *dims)
neg_x['raw'] = x_list[0]
vx, ax = neg_vx.pop('raw'), neg_ax.pop('raw')
else:
vx = vae.encode(vx)
ax = audio_vae.encode_audio(ax) # [B, C, T, M]
with Timer("feature to cpu", log=log_time):
vx = vx.cpu()
ax = ax.cpu()
if dataset.neg_aug:
neg_vx = {k: v.cpu() for k, v in neg_vx.items()}
neg_ax = {k: v.cpu() for k, v in neg_ax.items()}
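# Assemble the per-batch record that will be serialized by FeatureSaver.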
batch_dict = {
"index": batch["index"],
"x": vx,
"ax": ax,
"text": y,
"raw_text": raw_text,
"fps": batch["fps"].to(dtype),
"audio_fps": batch["audio_fps"].to(dtype),
"height": batch["height"].to(dtype),
"width": batch["width"].to(dtype),
"num_frames": batch["num_frames"].to(dtype),
}
if dataset.neg_aug:
batch_dict.update({
"neg_vx": neg_vx,
"neg_ax": neg_ax
})
if dataset.require_onset:
batch_dict.update({
"onset": batch["onset"].to(dtype),
})
if save_text_features:
with Timer("text", log=log_time):
text_infos = text_encoder.encode(y)
y_feat = text_infos["y"]
y_mask = text_infos["mask"]
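# y: caption token embeddings, mask: the corresponding attention/padding mask
# (assumed layout of the text encoder's encode() output).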
# if save_compressed_text_features:
# y_feat, y_mask = model.encode_text(y_feat, y_mask)
# y_mask = torch.tensor(y_mask)
with Timer("feature to cpu", log=log_time):
y_feat = y_feat.cpu()
y_mask = y_mask.cpu()
batch_dict.update({
"y": y_feat, "mask": y_mask,
})
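# Accumulate this batch; save() is assumed to flush a completed bin to disk once
# bin_size batches have been collected.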
saver.update(batch_dict)
saver.save()
if __name__ == "__main__":
main()