#!/usr/bin/env python3
"""
Utility script to sanity-check data loaders defined in train_omada_inst.py
without constructing the full training stack.

Example:
    python MMaDA/tools/run_dataloaders.py config=MMaDA/configs/omada_instruction_tuning.yaml \
        --flows v2t --num-workers 0 --max-batches 10
"""
from __future__ import annotations

import argparse
import logging
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Ensure the repository root is importable when executing from an arbitrary cwd.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from training.data import VideoCaptionDataset  # noqa: E402
from training.utils import image_transform  # noqa: E402

LOGGER = logging.getLogger("run_dataloaders")

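# Arguments are parsed in two passes: argparse consumes the flags declared
# below, while any leftover `key=value` tokens (including the required
# `config=...`) are treated as OmegaConf overrides on top of the YAML config.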
def _parse_args() -> Tuple[argparse.Namespace, DictConfig]:
    parser = argparse.ArgumentParser(description="Run Omada dataloaders without the trainer.")
    parser.add_argument(
        "--flows",
        default="v2t",
        help="Comma-separated list of dataloaders to exercise (currently supports: v2t). "
        "Use 'all' to run every available flow.",
    )
    parser.add_argument(
        "--max-batches",
        type=int,
        default=0,
        help="Stop after this many batches per loader (0 means iterate the entire epoch).",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=None,
        help="Override DataLoader num_workers (falls back to config.dataset.params.num_workers).",
    )
    parser.add_argument(
        "--persistent-workers",
        dest="persistent_workers",
        action="store_true",
        help="Force persistent_workers=True regardless of config.",
    )
    parser.add_argument(
        "--no-persistent-workers",
        dest="persistent_workers",
        action="store_false",
        help="Force persistent_workers=False regardless of config.",
    )
    # Tri-state default: None means "defer to the config value".
    parser.set_defaults(persistent_workers=None)
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Torch manual seed for reproducibility.",
    )
    args, unknown = parser.parse_known_args()
    cli_conf = OmegaConf.from_cli(unknown)
    if "config" not in cli_conf:
        parser.error("Please provide the training config via 'config=/path/to/config.yaml'.")
    yaml_conf = OmegaConf.load(cli_conf.config)
    # CLI overrides win over values loaded from the YAML file.
    merged = OmegaConf.merge(yaml_conf, cli_conf)
    return args, merged

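# The collate fn returns None (rather than raising) when every sample in the
# batch is unusable, so the iteration loop can log the empty batch and move on.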
def _collate_v2t(batch: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Minimal collate fn mirroring train_omada_inst.collate_fn_v2t."""
    filtered: List[Dict[str, Any]] = [sample for sample in batch if sample is not None]
    if not filtered:
        return None
    videos: List[torch.Tensor] = []
    captions: List[Any] = []
    for sample in filtered:
        frames = sample.get("video")
        caption = sample.get("caption")
        if frames is None:
            continue
        try:
            # Stack the sample's per-frame tensors into a single video tensor.
            tensor = torch.stack(frames, dim=0)
        except Exception:
            LOGGER.exception("Failed to stack frames for sample %s", sample)
            raise
        videos.append(tensor)
        captions.append(caption)
    if not videos:
        return None
    return {
        "video": torch.stack(videos, dim=0),
        "captions": captions,
    }

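# A minimal smoke check for _collate_v2t (assuming samples shaped like the
# dicts VideoCaptionDataset yields: a list of per-frame tensors plus a
# caption) might look like:
#
#     fake = [{"video": [torch.zeros(3, 8, 8)] * 2, "caption": "hi"}]
#     batch = _collate_v2t(fake)
#     assert batch["video"].shape == (1, 2, 3, 8, 8)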
def _build_v2t_loader(
    cfg: DictConfig,
    tokenizer,
    *,
    num_workers: int,
    persistent_workers: bool,
    pin_memory: bool,
) -> DataLoader:
    speech_cfg = getattr(cfg.dataset.params, "video_speech_dataset", {})
    if not isinstance(speech_cfg, dict):
        # OmegaConf nodes are not plain dicts; resolve to a real dict first.
        speech_cfg = OmegaConf.to_container(speech_cfg, resolve=True)
    speech_cfg = speech_cfg or {}
    dataset = VideoCaptionDataset(
        transform=image_transform,
        tokenizer=tokenizer,
        max_seq_length=int(cfg.dataset.preprocessing.max_seq_length),
        resolution=int(cfg.dataset.preprocessing.resolution),
        sample_method=speech_cfg.get("sample_method", "uniform"),
        dataset_name=speech_cfg.get("llavavid_dataset_name", "llavavid"),
        num_frames=int(speech_cfg.get("num_frames", 8)),
    )
    batch_size = int(max(1, cfg.training.batch_size_v2t))
    LOGGER.info(
        "Instantiated VideoCaptionDataset with %d samples; batch_size=%d num_workers=%d",
        len(dataset),
        batch_size,
        num_workers,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # persistent_workers is only legal when num_workers > 0.
        persistent_workers=persistent_workers if num_workers > 0 else False,
        collate_fn=_collate_v2t,
        drop_last=False,
    )

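# The iteration harness below simply drains batches from a loader, counting
# processed samples and empty (skipped) batches, and logs a summary whether
# or not iteration completed cleanly.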
def _iterate_loader(name: str, loader: DataLoader, max_batches: int) -> None:
    LOGGER.info("Starting iteration over '%s' (max_batches=%s)", name, max_batches or "full epoch")
    start = time.time()
    failures = 0
    processed = 0
    step = 0  # Defined before the try so the handlers below never see an unbound name.
    try:
        for step, batch in enumerate(loader, start=1):
            if batch is None:
                failures += 1
                LOGGER.warning("[%s] Received empty batch at step %d", name, step)
                continue
            processed += batch["video"].size(0)
            if max_batches and step >= max_batches:
                break
    except Exception:
        LOGGER.exception("Loader '%s' raised an exception at batch %d", name, step)
        raise
    finally:
        duration = time.time() - start
        LOGGER.info(
            "Finished '%s': steps=%d samples=%d failures=%d elapsed=%.2fs",
            name,
            step,
            processed,
            failures,
            duration,
        )

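# Entry point: seed torch, resolve worker settings (CLI flags take precedence
# over the config), build the requested loaders, and iterate each one.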
def main() -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )
    args, cfg = _parse_args()
    torch.manual_seed(args.seed)
    pin_memory = bool(getattr(cfg.dataset.params, "pin_memory", False))
    if args.num_workers is None:
        num_workers = int(getattr(cfg.dataset.params, "num_workers", 0))
    else:
        num_workers = max(0, args.num_workers)
    if args.persistent_workers is None:
        persistent_workers = bool(getattr(cfg.dataset.params, "persistent_workers", False))
    else:
        persistent_workers = bool(args.persistent_workers)
    flows_arg = [item.strip().lower() for item in args.flows.split(",") if item.strip()]
    if "all" in flows_arg:
        flows = {"v2t"}
    else:
        flows = set(flows_arg)
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.omada.tokenizer_path, padding_side="left")
    loaders: Dict[str, DataLoader] = {}
    if "v2t" in flows:
        loaders["v2t"] = _build_v2t_loader(
            cfg,
            tokenizer,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory,
        )
    if not loaders:
        LOGGER.error("No loaders selected. Supported flows: v2t")
        sys.exit(1)
    for name, loader in loaders.items():
        _iterate_loader(name, loader, args.max_batches)


if __name__ == "__main__":
    main()