#!/usr/bin/env python3
"""
Pre-compute EMOVA speech tokenizer codes for audio datasets.

Supported dataset types:
- video-speech : CSV index with truncated WAV clips (e.g., OpenVid speech)
- librispeech : LibriSpeech directory structure with FLAC audio
- instructs2s : InstructS2S-200K style user/assistant WAV pairs

Examples
--------
# VideoSpeech
python MMaDA/precompute_video_speech_tokens.py \\
    --dataset-type video-speech \\
    --index /home/work/AIDAS/data/video-speech/openvid-speech.csv \\
    --audio-root /home/work/AIDAS/data/video-speech/openvid-speech-trunc \\
    --output /home/work/AIDAS/cache/video_speech_tokens

# LibriSpeech
python MMaDA/precompute_video_speech_tokens.py \\
    --dataset-type librispeech \\
    --audio-root /home/work/AIDAS/data/audio/LibriSpeech \\
    --librispeech-subsets train-clean-360 train-clean-100 \\
    --output /home/work/AIDAS/cache/librispeech_tokens

# InstructS2S (pairs.txt assumed under audio root)
python MMaDA/precompute_video_speech_tokens.py \\
    --dataset-type instructs2s \\
    --audio-root /home/work/AIDAS/data/InstructS2S-200K/en/wav \\
    --output /home/work/AIDAS/cache/instructs2s_tokens
"""
import argparse
import csv
import hashlib
import os
import sys
import tempfile
from pathlib import Path
from typing import Iterable, Iterator, List, Set

import soundfile as sf
import torch
from tqdm import tqdm

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer  # noqa: E402


def iter_video_speech_audio(index_path: Path, audio_root: Path) -> Iterator[Path]:
    """
    Yields audio paths from the VideoSpeech index CSV.
    """
    with index_path.open("r", newline="") as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if not row:
                continue
            base = row[0].strip().removesuffix(".wav")
            if not base:
                continue
            audio_path = audio_root / f"{base}.wav"
            if audio_path.is_file():
                yield audio_path


def iter_librispeech_audio(audio_root: Path, subsets: Iterable[str]) -> Iterator[Path]:
    """
    Iterates through LibriSpeech FLAC files for the provided subsets.
    """
    for subset in subsets:
        subset_dir = audio_root / subset
        if not subset_dir.exists():
            raise FileNotFoundError(f"LibriSpeech subset not found: {subset_dir}")
        speakers = sorted(p for p in subset_dir.iterdir() if p.is_dir())
        for speaker_dir in speakers:
            chapters = sorted(p for p in speaker_dir.iterdir() if p.is_dir())
            for chapter_dir in chapters:
                for flac_path in sorted(chapter_dir.glob("*.flac")):
                    yield flac_path


def iter_instructs2s_audio(audio_root: Path, pairs_file: Path | None = None) -> Iterator[Path]:
    """
    Yields audio paths from an InstructS2S root directory.

    If pairs_file is provided (or found under audio_root), it's expected to contain
    two space-separated paths per line: user assistant.
    Otherwise, the directory tree is scanned similarly to Speech2SpeechDataset.
    """
    resolved_root = audio_root.expanduser().resolve()
    if pairs_file is None:
        candidate = resolved_root / "pairs.txt"
        if candidate.exists():
            pairs_file = candidate
    if pairs_file is not None:
        with Path(pairs_file).open("r") as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                if len(parts) >= 2:
                    user_path = Path(parts[0])
                    if not user_path.is_absolute():
                        user_path = resolved_root / user_path
                    asst_path = Path(parts[1])
                    if not asst_path.is_absolute():
                        asst_path = resolved_root / asst_path
                    if user_path.is_file():
                        yield user_path
                    if asst_path.is_file():
                        yield asst_path
        return
    # Fallback: scan <root>/<name>/<name>-<k>-{user,assistant}.wav pairs, stopping at
    # the first missing index in each directory.
    dirs = [p for p in resolved_root.glob("*") if p.is_dir()]
    for dir_path in dirs:
        dir_name = dir_path.name
        k = 1
        while True:
            user_wav = dir_path / f"{dir_name}-{k}-user.wav"
            assistant_wav = dir_path / f"{dir_name}-{k}-assistant.wav"
            if user_wav.is_file() and assistant_wav.is_file():
                yield user_wav
                yield assistant_wav
                k += 1
                continue
            break


def hash_path(path: Path) -> str:
    """Returns a SHA-1 hex digest for the absolute path."""
    abs_path = os.path.abspath(path)
    return hashlib.sha1(abs_path.encode("utf-8")).hexdigest()


def token_output_path(output_root: Path, audio_path: Path) -> Path:
    """Resolves the on-disk location for cached tokens corresponding to audio_path."""
    digest = hash_path(audio_path)
    return output_root / digest[:2] / digest[2:4] / f"{digest}.pt"


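# Hypothetical consumer-side sketch (not part of this script): the cache written above
# can be read back by recomputing the same hashed location, e.g.
#   cached = token_output_path(Path(cache_root), Path(wav_path))
#   tokens = torch.load(cached) if cached.exists() else None
# cache_root and wav_path are placeholders for the caller's own paths.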


def encode_audio(tokenizer: EMOVASpeechTokenizer, audio_path: Path) -> torch.Tensor:
    """
    Encodes an audio file to discrete tokens, converting non-WAV inputs on the fly.
    """
    suffix = audio_path.suffix.lower()
    if suffix == ".wav":
        return tokenizer.encode(str(audio_path)).cpu()
    # Non-WAV input (e.g., LibriSpeech FLAC): round-trip through a temporary WAV file.
    data, sample_rate = sf.read(str(audio_path))
    tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        sf.write(tmp_file.name, data, sample_rate)
        tokens = tokenizer.encode(tmp_file.name).cpu()
    finally:
        tmp_file.close()
        try:
            os.remove(tmp_file.name)
        except OSError:
            pass
    return tokens


def gather_audio_paths(args) -> List[Path]:
    if args.dataset_type == "video-speech":
        return list(iter_video_speech_audio(args.index, args.audio_root))
    if args.dataset_type == "librispeech":
        return list(iter_librispeech_audio(args.audio_root, args.librispeech_subsets))
    # instructs2s
    paths = list(iter_instructs2s_audio(args.audio_root, args.pairs_file))
    # Deduplicate while preserving order
    seen: Set[Path] = set()
    unique_paths: List[Path] = []
    for path in paths:
        if path not in seen:
            seen.add(path)
            unique_paths.append(path)
    return unique_paths


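# Contiguous sharding: e.g., 10 paths across 3 workers yields shard sizes [4, 4, 2];
# trailing shards can be empty when there are more workers than paths, which
# process_shard handles by returning immediately.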
def split_into_shards(items: List[Path], shard_count: int) -> List[List[Path]]:  # pragma: no cover - simple helper
    shard_count = max(1, shard_count)
    shard_size = (len(items) + shard_count - 1) // shard_count
    return [items[i * shard_size : (i + 1) * shard_size] for i in range(shard_count)]


def process_shard(
    shard_id: int,
    audio_paths: List[Path],
    device: str,
    tokenizer_name: str,
    output_root: Path,
    overwrite: bool,
    dataset_type: str,
) -> tuple[int, int, List[Path]]:
    if not audio_paths:
        return 0, 0, []
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        # Pin this worker to its assigned GPU before the tokenizer is loaded.
        torch.cuda.set_device(device_obj)
    tokenizer = EMOVASpeechTokenizer.from_pretrained(tokenizer_name).to(device_obj)
    tokenizer.eval()
    total = 0
    skipped = 0
    desc = f"{dataset_type} worker {shard_id}"
    failed_paths: List[Path] = []
    for audio_path in tqdm(audio_paths, desc=desc, position=shard_id, leave=False):
        out_path = token_output_path(output_root, audio_path)
        if out_path.exists() and not overwrite:
            skipped += 1
            continue
        out_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            tokens = encode_audio(tokenizer, audio_path)
        except Exception as exc:  # pragma: no cover - runtime diagnostics
            tqdm.write(f"[WARN][worker {shard_id}] Failed to encode {audio_path}: {exc}")
            failed_paths.append(audio_path)
            continue
        # Write atomically: save to a temp file, then rename over the final path.
        tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
        torch.save(tokens, tmp_path)
        os.replace(tmp_path, out_path)
        total += 1
    return total, skipped, failed_paths


def main():
    parser = argparse.ArgumentParser(description="Pre-compute speech tokens for audio datasets.")
    parser.add_argument(
        "--dataset-type",
        "--dataset_type",
        dest="dataset_type",
        choices=["video-speech", "librispeech", "instructs2s"],
        default="video-speech",
        help="Dataset type to process.",
    )
    parser.add_argument(
        "--index",
        type=Path,
        help="CSV index for video-speech datasets (required for dataset-type=video-speech).",
    )
    parser.add_argument(
        "--audio-root",
        type=Path,
        required=True,
        help="Root directory containing audio files. For LibriSpeech this should be the LibriSpeech root.",
    )
    parser.add_argument(
        "--librispeech-subsets",
        "--librispeech_subsets",
        nargs="+",
        default=None,
        help="LibriSpeech subsets to process (e.g., train-clean-360). Required when dataset-type=librispeech.",
    )
    parser.add_argument(
        "--pairs-file",
        "--pairs_file",
        type=Path,
        default=None,
        help="Optional pairs.txt to use for instructs2s dataset.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Directory to store the precomputed token tensors.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="Emova-ollm/emova_speech_tokenizer_hf",
        help="Name or path of the EMOVA speech tokenizer checkpoint to use.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Recompute and overwrite existing token files.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for running the tokenizer encoder.",
    )
    parser.add_argument(
        "--devices",
        nargs="+",
        default=None,
        help="Optional list of devices per worker (e.g., cuda:0 cuda:1 ...). Overrides --device/--num-workers.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=1,
        help="Number of parallel workers (ignored if --devices is provided).",
    )
    args = parser.parse_args()

    if args.index is not None:
        args.index = args.index.expanduser().resolve()
        if not args.index.exists():
            parser.error(f"Index file not found: {args.index}")
    if args.dataset_type == "video-speech" and args.index is None:
        parser.error("--index is required when dataset-type=video-speech.")
    if args.dataset_type == "librispeech" and not args.librispeech_subsets:
        parser.error("--librispeech-subsets must be provided when dataset-type=librispeech.")

    args.audio_root = args.audio_root.expanduser().resolve()
    args.output = args.output.expanduser().resolve()
    if args.pairs_file is not None:
        args.pairs_file = Path(args.pairs_file).expanduser().resolve()
        if not args.pairs_file.exists():
            parser.error(f"pairs-file not found: {args.pairs_file}")
    args.output.mkdir(parents=True, exist_ok=True)

    audio_paths = gather_audio_paths(args)
    if not audio_paths:
        print("No audio files found. Nothing to encode.")
        return

    if args.devices:
        worker_devices = args.devices
    else:
        worker_devices = [args.device] * max(1, args.num_workers)

    if len(worker_devices) == 1:
        # Single-worker path: encode everything in this process.
        device = torch.device(worker_devices[0])
        tokenizer = EMOVASpeechTokenizer.from_pretrained(args.tokenizer).to(device)
        tokenizer.eval()
        total = 0
        skipped = 0
        failed_paths: List[Path] = []
        for audio_path in tqdm(audio_paths, desc="Encoding clips"):
            out_path = token_output_path(args.output, audio_path)
            if out_path.exists() and not args.overwrite:
                skipped += 1
                continue
            out_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                tokens = encode_audio(tokenizer, audio_path)
            except Exception as exc:
                tqdm.write(f"[WARN] Failed to encode {audio_path}: {exc}")
                failed_paths.append(audio_path)
                continue
            # Write atomically: save to a temp file, then rename over the final path.
            tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
            torch.save(tokens, tmp_path)
            os.replace(tmp_path, out_path)
            total += 1
        if failed_paths:
            failed_log = args.output / "failed_paths.log"
            with failed_log.open("a") as fh:
                for path in failed_paths:
                    fh.write(f"{path}\n")
            print(f"Wrote {len(failed_paths)} failed paths to {failed_log}")
        print(f"Done. Encoded {total} clips. Skipped {skipped} existing entries.")
        return
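
    # Multi-worker path: split the paths into contiguous shards and run one spawned
    # process per device entry. The "spawn" start method avoids inheriting CUDA state
    # from the parent process, which is the safe choice for per-GPU workers.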
    shards = split_into_shards(audio_paths, len(worker_devices))

    from multiprocessing import get_context

    ctx = get_context("spawn")
    futures = []
    with ctx.Pool(len(worker_devices)) as pool:
        for shard_id, (device_str, shard_paths) in enumerate(zip(worker_devices, shards)):
            futures.append(
                pool.apply_async(
                    process_shard,
                    (
                        shard_id,
                        shard_paths,
                        device_str,
                        args.tokenizer,
                        args.output,
                        args.overwrite,
                        args.dataset_type,
                    ),
                )
            )
        pool.close()
        pool.join()

    total = 0
    skipped = 0
    failed_paths: List[Path] = []
    for fut in futures:
        shard_total, shard_skipped, shard_failed = fut.get()
        total += shard_total
        skipped += shard_skipped
        failed_paths.extend(shard_failed)

    if failed_paths:
        failed_log = args.output / "failed_paths.log"
        with failed_log.open("a") as fh:
            for path in failed_paths:
                fh.write(f"{path}\n")
        print(f"Wrote {len(failed_paths)} failed paths to {failed_log}")
    print(f"Done. Encoded {total} clips. Skipped {skipped} existing entries.")


if __name__ == "__main__":
    main()