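"""V2T (video-to-text) inference.

Samples eight evenly spaced frames per video, tokenizes them with the image
VQ model, and decodes a caption per (video, question) pair from one or more
OMada checkpoints, sweeping generation hyperparameters from CLI overrides or
a config grid. Captions are logged to Weights & Biases as a table.

Example invocation (paths are illustrative):
    python v2t_infer.py --train_config configs/train.yaml \
        --ckpt_root outputs/exp1 --steps 256 --question "What happens here?"
"""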
import os
import argparse
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import cv2
import torch
from PIL import Image
import wandb
from omegaconf import OmegaConf
from training.utils import image_transform
from inference.common import (
load_train_config,
get_vq_model_image,
build_uni_prompting,
load_omada_from_checkpoint,
list_checkpoints,
grid_dict,
init_wandb,
safe_log_table,
)


def sample_video_tokens(video_path: str, vq_model_image, uni_prompting, cfg, device) -> torch.Tensor:
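    """Encode 8 evenly spaced frames of `video_path` into VQ token ids.

    Returns a (1, 8 * tokens_per_frame) LongTensor whose ids are offset past
    the text-tokenizer vocabulary so text and visual tokens share one id space.
    """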
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames <= 0:
cap.release()
raise RuntimeError(f"No frames in {video_path}")
indices = np.linspace(0, total_frames - 1, 8, dtype=int)
frames = []
    for i in range(total_frames):
        ret, frame = cap.read()
        if not ret:
            # Decoder ran out of frames early; stop instead of issuing
            # further reads that will also fail.
            break
        if i in indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame)
            frames.append(image_transform(pil_img, resolution=cfg.dataset.preprocessing.resolution))
cap.release()
if len(frames) < 8:
raise RuntimeError(f"Insufficient frames from {video_path}")
video_tensor = torch.stack(frames).to(device)
# offset by text tokenizer length as in training evaluation
video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer)
video_tokens = video_tokens.view(1, -1)
return video_tokens


def build_input_ids(video_tokens: torch.Tensor, question: str, uni_prompting, device) -> torch.Tensor:
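    """Build the v2t prompt: task tag, image-span-wrapped video tokens, then
    the chat-formatted question."""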
spt = uni_prompting.sptids_dict
prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
input_ids = torch.cat([
spt['<|v2t|>'].to(device).unsqueeze(0),
spt['<|soi|>'].to(device).unsqueeze(0),
video_tokens,
spt['<|eoi|>'].to(device).unsqueeze(0),
spt['<|sot|>'].to(device).unsqueeze(0),
prompt_tensor
], dim=1).long()
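    # Resulting layout: [<|v2t|>, <|soi|>, video tokens, <|eoi|>, <|sot|>,
    # question tokens]; generation continues from the end of this prefix.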
return input_ids


def run_once(ckpt_path: str, hparams: dict, train_cfg, device):
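    """Caption every .mp4 in `video_dir` with one checkpoint and one
    hyperparameter combination, logging a W&B table of results."""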
uni_prompting, tokenizer = build_uni_prompting(train_cfg)
vq_img = get_vq_model_image(train_cfg, device)
model = load_omada_from_checkpoint(ckpt_path, device)
video_dir = hparams.get("video_dir", "/home/work/AIDAS/video/demo")
questions = hparams.get("questions", ["Please provide a detailed description of the video."])
steps = int(hparams.get("steps", 256))
block_length = int(hparams.get("block_length", 128))
max_new_tokens = int(hparams.get("max_new_tokens", 256))
# W&B
init_wandb(hparams.get("_infer_cfg", {}), "v2t", ckpt_path, {
"steps": steps,
"block_length": block_length,
"max_new_tokens": max_new_tokens,
})
files = [f for f in os.listdir(video_dir) if f.lower().endswith(".mp4")]
files.sort()
rows = []
for fname in files:
vpath = os.path.join(video_dir, fname)
        try:
            vtoks = sample_video_tokens(vpath, vq_img, uni_prompting, train_cfg, device)
        except Exception as e:
            print(f"Skipping {fname}: {e}", file=sys.stderr)
            continue
for q in questions:
inp = build_input_ids(vtoks, q, uni_prompting, device)
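            # mmu_generate decodes the answer after the prefix; `steps` and
            # `block_length` parameterize its block-wise iterative decoding.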
with torch.no_grad():
out_ids = model.mmu_generate(
inp,
max_new_tokens=max_new_tokens,
steps=steps,
block_length=block_length,
)
text = uni_prompting.text_tokenizer.batch_decode(
out_ids[:, inp.shape[1]:], skip_special_tokens=True
)[0]
rows.append([fname, q, text])
safe_log_table("samples/v2t", ["Video", "Question", "Caption"], rows)
wandb.finish()
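

# The optional --infer_config YAML is consumed below with the following shape
# (a sketch; keys match the parsing in main(), values are illustrative):
#
#   checkpoints:                 # optional list of checkpoint paths/dirs
#     - /path/to/exp/checkpoint-1000
#   generation:                  # lists; grid_dict expands the cross product
#     steps: [128, 256]
#     block_length: [128]
#     max_new_tokens: [256]
#   dataset:
#     video_dir: /home/work/AIDAS/video/demo
#     questions:
#       - "Please provide a detailed description of the video."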
def main():
parser = argparse.ArgumentParser(description="V2T Inference with CLI overrides or config grids")
parser.add_argument("--train_config", required=True)
parser.add_argument("--ckpt_root", required=True)
parser.add_argument("--infer_config", required=False)
parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or experiment dir")
# Generation overrides
parser.add_argument("--steps", type=int)
parser.add_argument("--block_length", type=int)
parser.add_argument("--max_new_tokens", type=int)
# Dataset overrides
parser.add_argument("--video_dir")
parser.add_argument("--question", action="append", help="Repeatable: --question 'text' --question 'another'")
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_cfg = load_train_config(args.train_config)
infer_cfg = {}
if args.infer_config:
infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True)
if args.checkpoint:
ckpt_list = []
for p in args.checkpoint:
ckpt_list.extend(list_checkpoints(p))
else:
ckpts = infer_cfg.get("checkpoints") if infer_cfg else None
if ckpts:
ckpt_list = []
for p in ckpts:
ckpt_list.extend(list_checkpoints(p))
else:
ckpt_list = list_checkpoints(args.ckpt_root)
if not ckpt_list:
raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.")
override_present = any([
args.steps is not None, args.block_length is not None, args.max_new_tokens is not None,
args.video_dir is not None, args.question is not None,
])
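    # Any CLI override (or a missing infer config) collapses the sweep to a
    # single hyperparameter combination; otherwise the config's generation
    # grid is expanded.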
if override_present or not infer_cfg:
single = {
"steps": args.steps if args.steps is not None else 256,
"block_length": args.block_length if args.block_length is not None else 128,
"max_new_tokens": args.max_new_tokens if args.max_new_tokens is not None else 256,
"video_dir": args.video_dir or "/home/work/AIDAS/video/demo",
"questions": args.question or ["Please provide a detailed description of the video."],
}
single["_infer_cfg"] = infer_cfg
combos = [single]
else:
gen_grid = infer_cfg.get("generation", {
"steps": [256],
"block_length": [128],
"max_new_tokens": [256],
})
combos = grid_dict(gen_grid)
dcfg = infer_cfg.get("dataset", {
"video_dir": "/home/work/AIDAS/video/demo",
"questions": ["Please provide a detailed description of the video."],
})
if args.video_dir is not None:
dcfg["video_dir"] = args.video_dir
if args.question is not None:
dcfg["questions"] = args.question
for c in combos:
if args.steps is not None:
c["steps"] = args.steps
if args.block_length is not None:
c["block_length"] = args.block_length
if args.max_new_tokens is not None:
c["max_new_tokens"] = args.max_new_tokens
c.update(dcfg)
c["_infer_cfg"] = infer_cfg
for ckpt in ckpt_list:
for hp in combos:
run_once(ckpt, hp, train_cfg, device)


if __name__ == "__main__":
main()