import os
import argparse
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import List

import numpy as np
import cv2
import torch
from PIL import Image
from torch.utils.data import DataLoader
import wandb
from omegaconf import OmegaConf

from training.utils import image_transform
from inference.common import (
    load_train_config,
    get_vq_model_image,
    build_uni_prompting,
    load_omada_from_checkpoint,
    list_checkpoints,
    grid_dict,
    init_wandb,
    safe_log_table,
)
def sample_video_tokens(video_path: str, vq_model_image, uni_prompting, cfg, device) -> torch.Tensor:
    """Uniformly sample 8 frames from a video and encode them into image VQ token ids."""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        cap.release()
        raise RuntimeError(f"No frames in {video_path}")
    # 8 frame indices spread evenly across the clip
    indices = np.linspace(0, total_frames - 1, 8, dtype=int)
    frames = []
    for i in range(total_frames):
        ret, frame = cap.read()
        if i in indices:
            if not ret:
                continue
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame)
            frames.append(image_transform(pil_img, resolution=cfg.dataset.preprocessing.resolution))
    cap.release()
    if len(frames) < 8:
        raise RuntimeError(f"Insufficient frames from {video_path}")
    video_tensor = torch.stack(frames).to(device)
    # offset by text tokenizer length as in training evaluation
    video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer)
    video_tokens = video_tokens.view(1, -1)
    return video_tokens
def build_input_ids(video_tokens: torch.Tensor, question: str, uni_prompting, device) -> torch.Tensor:
    """Assemble the v2t sequence: <|v2t|> <|soi|> video tokens <|eoi|> <|sot|> chat prompt."""
    spt = uni_prompting.sptids_dict
    prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
    prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([
        spt['<|v2t|>'].to(device).unsqueeze(0),
        spt['<|soi|>'].to(device).unsqueeze(0),
        video_tokens,
        spt['<|eoi|>'].to(device).unsqueeze(0),
        spt['<|sot|>'].to(device).unsqueeze(0),
        prompt_tensor
    ], dim=1).long()
    return input_ids
def run_once(ckpt_path: str, hparams: dict, train_cfg, device):
    """Caption every .mp4 in video_dir with one checkpoint and one hyperparameter set, logging to W&B."""
    uni_prompting, tokenizer = build_uni_prompting(train_cfg)
    vq_img = get_vq_model_image(train_cfg, device)
    model = load_omada_from_checkpoint(ckpt_path, device)

    video_dir = hparams.get("video_dir", "/home/work/AIDAS/video/demo")
    questions = hparams.get("questions", ["Please provide a detailed description of the video."])
    steps = int(hparams.get("steps", 256))
    block_length = int(hparams.get("block_length", 128))
    max_new_tokens = int(hparams.get("max_new_tokens", 256))

    # W&B
    init_wandb(hparams.get("_infer_cfg", {}), "v2t", ckpt_path, {
        "steps": steps,
        "block_length": block_length,
        "max_new_tokens": max_new_tokens,
    })

    files = [f for f in os.listdir(video_dir) if f.lower().endswith(".mp4")]
    files.sort()

    rows = []
    for fname in files:
        vpath = os.path.join(video_dir, fname)
        try:
            vtoks = sample_video_tokens(vpath, vq_img, uni_prompting, train_cfg, device)
        except Exception:
            # skip unreadable or too-short videos
            continue
        for q in questions:
            inp = build_input_ids(vtoks, q, uni_prompting, device)
            with torch.no_grad():
                out_ids = model.mmu_generate(
                    inp,
                    max_new_tokens=max_new_tokens,
                    steps=steps,
                    block_length=block_length,
                )
            text = uni_prompting.text_tokenizer.batch_decode(
                out_ids[:, inp.shape[1]:], skip_special_tokens=True
            )[0]
            rows.append([fname, q, text])

    safe_log_table("samples/v2t", ["Video", "Question", "Caption"], rows)
    wandb.finish()
def main():
    parser = argparse.ArgumentParser(description="V2T Inference with CLI overrides or config grids")
    parser.add_argument("--train_config", required=True)
    parser.add_argument("--ckpt_root", required=True)
    parser.add_argument("--infer_config", required=False)
    parser.add_argument("--checkpoint", action="append",
                        help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or an experiment dir")
    # Generation overrides
    parser.add_argument("--steps", type=int)
    parser.add_argument("--block_length", type=int)
    parser.add_argument("--max_new_tokens", type=int)
    # Dataset overrides
    parser.add_argument("--video_dir")
    parser.add_argument("--question", action="append", help="Repeatable: --question 'text' --question 'another'")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_cfg = load_train_config(args.train_config)

    infer_cfg = {}
    if args.infer_config:
        infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True)

    # Resolve checkpoints: explicit --checkpoint > infer-config list > ckpt_root scan
    if args.checkpoint:
        ckpt_list = []
        for p in args.checkpoint:
            ckpt_list.extend(list_checkpoints(p))
    else:
        ckpts = infer_cfg.get("checkpoints") if infer_cfg else None
        if ckpts:
            ckpt_list = []
            for p in ckpts:
                ckpt_list.extend(list_checkpoints(p))
        else:
            ckpt_list = list_checkpoints(args.ckpt_root)
    if not ckpt_list:
        raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.")

    # Any CLI override collapses the sweep to a single hyperparameter set
    override_present = any([
        args.steps is not None, args.block_length is not None, args.max_new_tokens is not None,
        args.video_dir is not None, args.question is not None,
    ])
    if override_present or not infer_cfg:
        single = {
            "steps": args.steps if args.steps is not None else 256,
            "block_length": args.block_length if args.block_length is not None else 128,
            "max_new_tokens": args.max_new_tokens if args.max_new_tokens is not None else 256,
            "video_dir": args.video_dir or "/home/work/AIDAS/video/demo",
            "questions": args.question or ["Please provide a detailed description of the video."],
        }
        single["_infer_cfg"] = infer_cfg
        combos = [single]
    else:
        # Otherwise expand the generation grid from the infer config
        gen_grid = infer_cfg.get("generation", {
            "steps": [256],
            "block_length": [128],
            "max_new_tokens": [256],
        })
        combos = grid_dict(gen_grid)
        dcfg = infer_cfg.get("dataset", {
            "video_dir": "/home/work/AIDAS/video/demo",
            "questions": ["Please provide a detailed description of the video."],
        })
        if args.video_dir is not None:
            dcfg["video_dir"] = args.video_dir
        if args.question is not None:
            dcfg["questions"] = args.question
        for c in combos:
            if args.steps is not None:
                c["steps"] = args.steps
            if args.block_length is not None:
                c["block_length"] = args.block_length
            if args.max_new_tokens is not None:
                c["max_new_tokens"] = args.max_new_tokens
            c.update(dcfg)
            c["_infer_cfg"] = infer_cfg

    for ckpt in ckpt_list:
        for hp in combos:
            run_once(ckpt, hp, train_cfg, device)


if __name__ == "__main__":
    main()
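
# Example invocation (a sketch only: the script filename, config path, and directories
# below are illustrative placeholders; the flags themselves are defined in main() above):
#
#   python v2t_inference.py \
#       --train_config configs/train.yaml \
#       --ckpt_root /path/to/experiment \
#       --steps 256 --block_length 128 --max_new_tokens 256 \
#       --video_dir /path/to/videos \
#       --question "Please provide a detailed description of the video."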