#!/usr/bin/env python
"""
Utility script to stress-test LLaVA-Video frame decoding in isolation.
This runs the `VideoCaptionDataset` loader on a single node so that we can
watch for files that consistently time out or wedge dataloader workers.
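
Example invocation (paths and worker counts are illustrative; adjust them to
your checkout and hardware):

    python script/debug_llavavid_decode.py \
        --llavavid-root data/video/LLaVA-Video-178K \
        --num-samples 256 --batch-size 1 --num-workers 4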
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
import torch
from torch.utils.data import DataLoader
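# Make the repository root importable so the `training` package resolves when
# this script is executed directly.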
ROOT_DIR = Path(__file__).resolve().parents[2]
if str(ROOT_DIR) not in sys.path:
sys.path.insert(0, str(ROOT_DIR))
from training import data as video_data_module # noqa: E402
from training.data import VideoCaptionDataset # noqa: E402
from training.utils import image_transform as default_image_transform # noqa: E402
def _resolve_llavavid_root(root_arg: Optional[str]) -> Path:
if root_arg:
root = Path(root_arg).expanduser().resolve()
else:
root = ROOT_DIR / "data" / "video" / "LLaVA-Video-178K"
if not root.exists():
raise FileNotFoundError(f"LLaVA-Video root directory not found: {root}")
return root
def _identity_collate(batch: List[Optional[Dict[str, Any]]]) -> List[Dict[str, Any]]:
"""Drop `None` samples that VideoCaptionDataset returns after repeated failures."""
filtered = [sample for sample in batch if sample is not None]
return filtered
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Decode-check LLaVA-Video samples with the existing dataset logic."
)
parser.add_argument(
"--llavavid-root",
type=str,
default=None,
help="Path to the LLaVA-Video-178K cache directory. Defaults to data/video/LLaVA-Video-178K relative to repo root.",
)
parser.add_argument(
"--num-samples",
type=int,
default=256,
        help=(
            "Total number of samples to attempt decoding across all DataLoader "
            "workers. Set to -1 to sweep the entire dataset once."
        ),
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size for the diagnostic DataLoader.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of DataLoader workers to spawn. Set to match your training run.",
)
parser.add_argument(
"--num-frames",
type=int,
default=8,
help="Number of frames to request from load_video_mp4.",
)
parser.add_argument(
"--resolution",
type=int,
default=256,
help="Resolution passed to the dataset transform.",
)
parser.add_argument(
"--sample-method",
type=str,
default="uniform",
choices=("uniform", "random"),
help="Frame sampling strategy.",
)
parser.add_argument(
"--report-every",
type=int,
default=10,
help="Print a progress line every N successfully decoded samples.",
)
parser.add_argument(
"--timeout",
type=float,
default=30.0,
        help="Maximum total seconds for the whole sweep before it is treated as stalled and aborted.",
)
return parser.parse_args()
def _maybe_set_thread_limits() -> None:
# Avoid oversubscribing CPU threads when the loader uses multiple workers.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
def main() -> None:
args = _parse_args()
_maybe_set_thread_limits()
llavavid_root = _resolve_llavavid_root(args.llavavid_root)
print(f"[INFO] Using LLaVA-Video root: {llavavid_root}")
original_loader = video_data_module.load_video_mp4
def traced_loader(*loader_args, **loader_kwargs):
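        # The path may arrive positionally or as a keyword; recover it either way
        # so the trace lines can name the offending file.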
video_path = loader_kwargs.get("video_path")
if video_path is None and loader_args:
video_path = loader_args[0]
start = time.time()
try:
frames = original_loader(*loader_args, **loader_kwargs)
except Exception as exc: # pylint: disable=broad-except
duration = time.time() - start
print(f"[ERROR] {video_path} raised {exc.__class__.__name__} after {duration:.2f}s: {exc}")
raise
duration = time.time() - start
        status = "OK" if frames is not None and len(frames) > 0 else "NONE"
print(f"[TRACE] {status:>4} | {duration:6.2f}s | {video_path}")
return frames
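    # Route every decode through the traced wrapper so per-file timings and
    # failures get printed. NOTE: with num_workers > 0 the patch only reaches the
    # workers under the default 'fork' start method; 'spawn' workers re-import
    # training.data and call the original loader.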
video_data_module.load_video_mp4 = traced_loader
try:
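        # The tokenizer is deliberately left as None: this diagnostic exercises
        # only the frame-decoding path, not caption tokenization.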
dataset = VideoCaptionDataset(
transform=default_image_transform,
tokenizer=None,
max_seq_length=256,
resolution=args.resolution,
dataset_name="llavavid",
llavavid_path=str(llavavid_root),
llavavid_local_files_only=True,
sample_method=args.sample_method,
num_frames=args.num_frames,
)
if len(dataset) == 0:
print("[ERROR] Dataset returned zero length. Check the root directory/config.")
sys.exit(1)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=_identity_collate,
pin_memory=False,
drop_last=False,
)
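        # shuffle=True samples broadly across the corpus instead of replaying the
        # same prefix every run. DataLoader also takes a `timeout=` argument
        # (effective only with workers) that raises if a batch is not delivered in
        # time, which is another way to surface truly wedged decodes.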
print(
f"[INFO] Starting decode sweep: "
f"{args.num_samples} samples, batch_size={args.batch_size}, num_workers={args.num_workers}"
)
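        # decoded: samples whose 'video' entry survived decoding; attempted: samples
        # requested from the loader; failed: samples the collate dropped after the
        # dataset returned None for them.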
decoded = 0
attempted = 0
failed = 0
start_time = time.time()
last_report = start_time
for batch_idx, batch in enumerate(dataloader, start=1):
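            # With drop_last=False the final batch may be smaller than batch_size,
            # which can slightly overstate `failed` on the last iteration.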
expected = args.batch_size
actual = len(batch)
attempted += expected
failed += max(expected - actual, 0)
            decoded += sum(1 for sample in batch if sample.get("video") is not None)
if args.num_samples > 0 and decoded >= args.num_samples:
break
now = time.time()
if args.report_every > 0 and decoded and decoded % args.report_every == 0:
elapsed = now - last_report
total_elapsed = now - start_time
print(
f"[INFO] {decoded} successful samples "
f"(attempted={attempted}, failed={failed}) "
f"in {total_elapsed:.1f}s (+{elapsed:.1f}s since last report)."
)
last_report = now
if now - start_time > args.timeout:
print(
f"[WARN] Exceeded timeout of {args.timeout}s without reaching target samples."
)
break
total_elapsed = time.time() - start_time
print(
f"[RESULT] Completed sweep: decoded={decoded}, attempted={attempted}, "
f"failed={failed}, elapsed={total_elapsed:.1f}s."
)
finally:
video_data_module.load_video_mp4 = original_loader
if __name__ == "__main__":
main()