Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """Utility to reproduce and debug the speech DataLoader used in training. | |
| This script pulls the speech dataset configuration from the Omada | |
| instruction-tuning config, instantiates the same `MixedSpeechTextDataset`, and | |
| iterates a configurable number of batches while measuring how long each fetch | |
| takes. Use it to spot slow or stuck samples without launching the full training | |
| job. | |
| Typical usage:: | |
| python AIDAS/MMaDA/script/debug_speech_dataloader.py \ | |
| --config AIDAS/MMaDA/configs/omada_instruction_tuning.yaml \ | |
| --flow s2t --max-batches 5 --num-workers 1 --timeout 0 | |
| Pass `--inspect-items` for a direct `dataset[idx]` sweep when a specific sample | |
| seems suspicious. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import itertools | |
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Iterable, List | |
| from omegaconf import OmegaConf | |
| from torch.utils.data import DataLoader | |
| from MMaDA.training.data import MixedSpeechTextDataset | |
| def _collate_fn_audio(batch: List[dict[str, Any]]) -> dict[str, List[Any]]: | |
| """Match the collate function used in training for speech flows.""" | |
| return { | |
| "audio_path": [item["audio_path"] for item in batch], | |
| "text": [item["text"] for item in batch], | |
| "audio_tokens": [item.get("audio_tokens") for item in batch], | |
| } | |
def _as_list_of_dicts(cfg_fragment: Any) -> List[dict[str, Any]]:
    """Resolve the ``audio_data`` config fragment into a plain Python list.

    Args:
        cfg_fragment: OmegaConf node describing the speech datasets; expected
            to be a list-valued node.

    Returns:
        The list produced by ``OmegaConf.to_container`` with interpolations
        resolved.

    Raises:
        TypeError: If the fragment does not resolve to a list. A mapping is
            rejected explicitly: it is iterable, but ``list()`` over it would
            yield its keys rather than dataset dicts.
    """
    container = OmegaConf.to_container(cfg_fragment, resolve=True)
    if not isinstance(container, list):
        raise TypeError(
            "audio_data config must be a list of dataset dicts, "
            f"got {type(container).__name__}"
        )
    return container
def _build_dataset(cfg) -> MixedSpeechTextDataset:
    """Instantiate ``MixedSpeechTextDataset`` from ``cfg.dataset.params.audio_data``.

    The ``audio_data`` node is first converted into a plain list via
    ``_as_list_of_dicts`` and then passed as the dataset's only argument.
    """
    return MixedSpeechTextDataset(_as_list_of_dicts(cfg.dataset.params.audio_data))
| def _log_batch_summary(idx: int, batch: dict[str, List[Any]], elapsed: float) -> None: | |
| audio_paths = batch.get("audio_path", []) | |
| sample = audio_paths[0] if audio_paths else "<empty>" | |
| logging.info( | |
| "batch=%d size=%d elapsed=%.2fs sample=%s", | |
| idx, | |
| len(audio_paths), | |
| elapsed, | |
| sample, | |
| ) | |
| def _inspect_items(dataset: MixedSpeechTextDataset, max_items: int) -> None: | |
| logging.info("Inspecting individual dataset items (max=%d)", max_items) | |
| for idx in itertools.islice(range(len(dataset)), max_items): | |
| tick = time.perf_counter() | |
| try: | |
| item = dataset[idx] | |
| except Exception as exc: # pragma: no cover - diagnostic path | |
| logging.error("idx=%d failed: %s", idx, exc) | |
| continue | |
| elapsed = time.perf_counter() - tick | |
| logging.info( | |
| "idx=%d elapsed=%.2fs path=%s text_len=%d tokens=%s", | |
| idx, | |
| elapsed, | |
| item.get("audio_path"), | |
| len(item.get("text", "")), | |
| "cached" if item.get("audio_tokens") is not None else "None", | |
| ) | |
def _non_negative_int(text: str) -> int:
    """``argparse`` type converter accepting integers >= 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected an integer >= 0, got {value}")
    return value


def _positive_int(text: str) -> int:
    """``argparse`` type converter accepting integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected an integer >= 1, got {value}")
    return value


def _non_negative_float(text: str) -> float:
    """``argparse`` type converter accepting finite floats >= 0."""
    value = float(text)
    # Chained comparison is False for NaN and inf, so both are rejected too.
    if not 0 <= value < float("inf"):
        raise argparse.ArgumentTypeError(f"expected a finite number >= 0, got {text!r}")
    return value


def parse_args(argv: List[str]) -> argparse.Namespace:
    """Parse command-line flags for the DataLoader debug script.

    Numeric flags are range-checked here, so bad values fail immediately with
    a usage error (exit status 2). Otherwise they would fail deep inside
    DataLoader or be silently replaced by config values.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        The parsed namespace. Options left as ``None`` fall back to the config.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserve the module docstring's line breaks so the usage example
        # stays readable in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=Path("AIDAS/MMaDA/configs/omada_instruction_tuning.yaml"),
        help="Path to the training config YAML",
    )
    parser.add_argument(
        "--flow",
        choices=["s2t", "t2s"],
        default="s2t",
        help="Which speech flow's batch size defaults to use",
    )
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=None,
        help="Override batch size (defaults to config.training.batch_size_<flow>)",
    )
    parser.add_argument(
        "--num-workers",
        type=_non_negative_int,
        default=None,
        help="Override DataLoader workers (defaults to config.dataset.params.num_workers)",
    )
    parser.add_argument(
        "--persistent-workers",
        action="store_true",
        help="Enable persistent workers regardless of config",
    )
    parser.add_argument(
        "--timeout",
        type=_non_negative_float,
        default=None,
        help=(
            "DataLoader timeout in seconds, 0 disables it "
            "(defaults to config.dataset.params.dataloader_timeout)"
        ),
    )
    parser.add_argument(
        "--max-batches",
        type=_non_negative_int,
        default=10,
        help="Number of batches to iterate (0 means run through the entire dataset)",
    )
    parser.add_argument(
        "--inspect-items",
        type=_non_negative_int,
        default=0,
        help="If >0, bypass the DataLoader and inspect this many individual dataset items first",
    )
    parser.add_argument(
        "--prefetch-factor",
        type=_positive_int,
        default=None,
        help="Optional override for DataLoader prefetch_factor (only used with num_workers > 0)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        help="Logging level",
    )
    return parser.parse_args(argv)
def _build_dataloader(
    args: argparse.Namespace, cfg: Any, dataset: MixedSpeechTextDataset
) -> DataLoader:
    """Build a DataLoader from ``cfg.dataset.params``, applying CLI overrides.

    Optional config keys that are absent fall back to DataLoader's own
    defaults: 0 workers, no pinning, timeout 0, non-persistent workers.
    """
    dataset_params = cfg.dataset.params
    # `is not None` rather than `or`, so an explicit override is never
    # silently replaced by the config value.
    batch_size = (
        args.batch_size
        if args.batch_size is not None
        else getattr(cfg.training, f"batch_size_{args.flow}")
    )
    num_workers = (
        args.num_workers
        if args.num_workers is not None
        else dataset_params.get("num_workers", 0)
    )
    timeout = (
        args.timeout
        if args.timeout is not None
        else (dataset_params.get("dataloader_timeout") or 0)
    )
    has_workers = num_workers > 0
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers = has_workers and (
        args.persistent_workers or bool(dataset_params.get("persistent_workers", False))
    )
    if args.persistent_workers and not has_workers:
        logging.warning("--persistent-workers ignored because num_workers=%d", num_workers)
    dataloader_kwargs = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        "drop_last": True,
        "pin_memory": bool(dataset_params.get("pin_memory", False)),
        "timeout": timeout,
        "persistent_workers": persistent_workers,
        "collate_fn": _collate_fn_audio,
    }
    # prefetch_factor only applies when worker processes exist.
    if args.prefetch_factor is not None:
        if has_workers:
            dataloader_kwargs["prefetch_factor"] = args.prefetch_factor
        else:
            logging.warning("--prefetch-factor ignored because num_workers=%d", num_workers)
    logging.info(
        "Starting DataLoader debug: batch_size=%d num_workers=%d timeout=%s persistent=%s",
        batch_size,
        num_workers,
        timeout,
        persistent_workers,
    )
    return DataLoader(**dataloader_kwargs)


def _time_batches(dataloader: DataLoader, max_batches: int) -> int:
    """Fetch batches one at a time and log how long each ``next()`` call blocks.

    Args:
        dataloader: Loader to iterate.
        max_batches: Stop after this many batches; a falsy value (0) means
            run until the loader is exhausted.

    Returns:
        Exit status:
            0 -- normal completion.
            1 -- fetching a batch raised.
            130 -- interrupted with Ctrl-C while waiting on a batch.
    """
    iterator = iter(dataloader)
    processed = 0
    total_elapsed = 0.0
    slowest_idx, slowest_elapsed = -1, 0.0
    while not max_batches or processed < max_batches:
        tick = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            logging.info("Reached end of DataLoader after %d batches", processed)
            break
        except KeyboardInterrupt:
            # Most useful with --timeout 0, where a stuck sample blocks forever:
            # report which batch was being waited on before exiting.
            logging.warning(
                "Interrupted while waiting for batch=%d (%.2fs elapsed)",
                processed,
                time.perf_counter() - tick,
            )
            return 130
        except Exception as exc:
            # E.g. a DataLoader timeout or a worker crash: keep the traceback
            # and record which batch failed and how long it was waited on.
            logging.exception(
                "batch=%d failed after %.2fs: %s", processed, time.perf_counter() - tick, exc
            )
            return 1
        elapsed = time.perf_counter() - tick
        _log_batch_summary(processed, batch, elapsed)
        total_elapsed += elapsed
        if slowest_idx < 0 or elapsed > slowest_elapsed:
            slowest_idx, slowest_elapsed = processed, elapsed
        processed += 1
    if processed:
        logging.info(
            "Fetched %d batches in %.2fs (mean %.2fs, slowest batch=%d at %.2fs)",
            processed,
            total_elapsed,
            total_elapsed / processed,
            slowest_idx,
            slowest_elapsed,
        )
    return 0


def main(argv: List[str]) -> int:
    """Run the speech DataLoader debug session.

    Loads the config, builds the dataset, optionally sweeps individual items,
    then times batch fetches through a DataLoader configured like training.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        Process exit status (0 on success; see ``_time_batches`` for the
        failure codes).
    """
    args = parse_args(argv)
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), logging.INFO),
        format="%(asctime)s | %(levelname)s | %(message)s",
    )
    cfg = OmegaConf.load(args.config)
    dataset = _build_dataset(cfg)
    if args.inspect_items > 0:
        _inspect_items(dataset, args.inspect_items)
    dataloader = _build_dataloader(args, cfg, dataset)
    return _time_batches(dataloader, args.max_batches)
if __name__ == "__main__":
    # sys.exit raises SystemExit, so main()'s return value becomes the exit status.
    sys.exit(main(sys.argv[1:]))