#!/usr/bin/env python
"""
Utility script to stress-test LLaVA-Video frame decoding in isolation.

This runs the `VideoCaptionDataset` loader on a single node so that we can watch
for files that consistently time out or wedge dataloader workers.
"""

from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import torch
from torch.utils.data import DataLoader

ROOT_DIR = Path(__file__).resolve().parents[2]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

from training import data as video_data_module  # noqa: E402
from training.data import VideoCaptionDataset  # noqa: E402
from training.utils import image_transform as default_image_transform  # noqa: E402


def _resolve_llavavid_root(root_arg: Optional[str]) -> Path:
    """Return the LLaVA-Video cache directory to read from.

    Args:
        root_arg: Explicit path from the command line, or ``None``/empty to use
            ``data/video/LLaVA-Video-178K`` under ``ROOT_DIR``.

    Raises:
        FileNotFoundError: If the resolved path is not an existing directory.
    """
    if root_arg:
        root = Path(root_arg).expanduser().resolve()
    else:
        root = ROOT_DIR / "data" / "video" / "LLaVA-Video-178K"
    # The dataset needs a directory; a stray file at this path is as unusable
    # as a missing one, so reject both up front.
    if not root.is_dir():
        raise FileNotFoundError(f"LLaVA-Video root directory not found: {root}")
    return root


def _identity_collate(batch: List[Optional[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Return the batch as a plain list with every ``None`` sample removed.

    No stacking is performed, so the result may be shorter than the requested
    batch size (or empty). The caller uses that shortfall to count failures.
    """
    return [sample for sample in batch if sample is not None]


def _has_content(value: Any) -> bool:
    """Report whether ``value`` holds any data, without relying on ``bool()``.

    ``bool()`` on a tensor or array with more than one element raises instead of
    answering, so sized containers are judged by their element count.
    """
    if value is None:
        return False
    if isinstance(value, torch.Tensor):
        return value.numel() > 0
    try:
        return len(value) > 0
    except TypeError:
        # Unsized objects (numbers, 0-d arrays, ...) fall back to truthiness.
        return bool(value)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the decode sweep."""
    parser = argparse.ArgumentParser(
        description="Decode-check LLaVA-Video samples with the existing dataset logic."
    )
    parser.add_argument(
        "--llavavid-root",
        type=str,
        default=None,
        help=(
            "Path to the LLaVA-Video-178K cache directory. "
            "Defaults to data/video/LLaVA-Video-178K relative to repo root."
        ),
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=256,
        help=(
            "Number of successfully decoded samples to collect (across all "
            "DataLoader workers) before stopping. "
            "Set to -1 to sweep the entire dataset once."
        ),
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for the diagnostic DataLoader.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of DataLoader workers to spawn. Set to match your training run.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=8,
        help="Number of frames to request from load_video_mp4.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="Resolution passed to the dataset transform.",
    )
    parser.add_argument(
        "--sample-method",
        type=str,
        default="uniform",
        choices=("uniform", "random"),
        help="Frame sampling strategy.",
    )
    parser.add_argument(
        "--report-every",
        type=int,
        default=10,
        help="Print a progress line every N successfully decoded samples.",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=30.0,
        help=(
            "Maximum seconds to wait for a single batch before treating it as a stall. "
            "Set to 0 to disable."
        ),
    )
    return parser.parse_args()


def _maybe_set_thread_limits() -> None:
    """Default the common math-library thread-count variables to ``1``.

    Values already present in the environment are left untouched.
    """
    # Avoid oversubscribing CPU threads when the loader uses multiple workers.
    # NOTE(review): these are set after `torch` has been imported; whether the
    # already-loaded libraries still honour them should be confirmed.
    for name in (
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ):
        os.environ.setdefault(name, "1")


def main() -> None:
    """Run the decode sweep and print per-file traces plus a final summary.

    `load_video_mp4` on the data module is temporarily replaced by a wrapper
    that times every call; the original is restored on exit, including on error.
    The sweep ends when the requested number of samples has been decoded, the
    dataset is exhausted, or a batch fails to arrive within ``--timeout``.
    """
    args = _parse_args()
    _maybe_set_thread_limits()

    llavavid_root = _resolve_llavavid_root(args.llavavid_root)
    print(f"[INFO] Using LLaVA-Video root: {llavavid_root}")

    original_loader = video_data_module.load_video_mp4

    def traced_loader(*loader_args, **loader_kwargs):
        """Call the real loader, logging the path, outcome and duration."""
        video_path = loader_kwargs.get("video_path")
        if video_path is None and loader_args:
            video_path = loader_args[0]

        start = time.monotonic()
        try:
            frames = original_loader(*loader_args, **loader_kwargs)
        except Exception as exc:  # pylint: disable=broad-except
            duration = time.monotonic() - start
            print(f"[ERROR] {video_path} raised {exc.__class__.__name__} after {duration:.2f}s: {exc}")
            raise

        duration = time.monotonic() - start
        # Frames may come back as a tensor/array, for which plain truthiness raises.
        status = "OK" if _has_content(frames) else "NONE"
        print(f"[TRACE] {status:>4} | {duration:6.2f}s | {video_path}")
        return frames

    # Only call sites that look the function up on the module at call time see
    # the wrapper. NOTE(review): worker processes started with a non-fork start
    # method re-import the module and would not carry this patch - confirm.
    video_data_module.load_video_mp4 = traced_loader

    try:
        dataset = VideoCaptionDataset(
            transform=default_image_transform,
            tokenizer=None,
            max_seq_length=256,
            resolution=args.resolution,
            dataset_name="llavavid",
            llavavid_path=str(llavavid_root),
            llavavid_local_files_only=True,
            sample_method=args.sample_method,
            num_frames=args.num_frames,
        )

        dataset_size = len(dataset)
        if dataset_size == 0:
            print("[ERROR] Dataset returned zero length. Check the root directory/config.")
            sys.exit(1)

        # DataLoader enforces `timeout` only when it collects batches from
        # worker processes; with num_workers=0 a hang cannot be interrupted
        # and slow batches are merely reported after the fact.
        loader_timeout = args.timeout if args.num_workers > 0 and args.timeout > 0 else 0

        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=_identity_collate,
            pin_memory=False,
            drop_last=False,
            timeout=loader_timeout,
        )

        target = str(args.num_samples) if args.num_samples > 0 else "all"
        print(
            f"[INFO] Starting decode sweep: "
            f"{target} samples, batch_size={args.batch_size}, num_workers={args.num_workers}"
        )

        decoded = 0
        attempted = 0
        failed = 0
        start_time = time.monotonic()
        last_report = start_time
        # Next `decoded` count at which a progress line is due.
        next_report_at = args.report_every

        batches = iter(dataloader)
        while True:
            wait_start = time.monotonic()
            try:
                batch = next(batches)
            except StopIteration:
                break
            except RuntimeError as exc:
                # DataLoader signals an expired `timeout` with a RuntimeError;
                # anything else is a genuine failure and must propagate.
                if loader_timeout > 0 and "timed out" in str(exc):
                    print(
                        f"[WARN] No batch arrived within {args.timeout}s; "
                        f"treating it as a stall: {exc}"
                    )
                    break
                raise

            now = time.monotonic()
            waited = now - wait_start
            if args.timeout > 0 and waited > args.timeout:
                print(
                    f"[WARN] Batch took {waited:.1f}s to arrive "
                    f"(timeout is {args.timeout}s)."
                )

            # The final batch may be short because drop_last=False, so derive
            # the expected size from what is left rather than from batch_size.
            actual = len(batch)
            expected = max(min(args.batch_size, dataset_size - attempted), actual)
            attempted += expected
            failed += expected - actual
            decoded += sum(1 for sample in batch if _has_content(sample.get("video")))

            if args.num_samples > 0 and decoded >= args.num_samples:
                break

            if args.report_every > 0 and decoded >= next_report_at:
                elapsed = now - last_report
                total_elapsed = now - start_time
                print(
                    f"[INFO] {decoded} successful samples "
                    f"(attempted={attempted}, failed={failed}) "
                    f"in {total_elapsed:.1f}s (+{elapsed:.1f}s since last report)."
                )
                last_report = now
                # Advance past the current count so failed batches do not
                # re-trigger the same milestone and large batches do not skip it.
                next_report_at = (decoded // args.report_every + 1) * args.report_every

        total_elapsed = time.monotonic() - start_time
        print(
            f"[RESULT] Completed sweep: decoded={decoded}, attempted={attempted}, "
            f"failed={failed}, elapsed={total_elapsed:.1f}s."
        )
    finally:
        video_data_module.load_video_mp4 = original_loader


if __name__ == "__main__":
    main()