# (removed: Hugging Face Spaces page banner — "Spaces: Running on Zero")
#!/usr/bin/env python3
"""
Pre-compute EMOVA speech tokenizer codes for InstructS2S (๋๋ ๊ธฐํ ๋จ์ผ ์ค๋์ค ํด๋).

์์:
    python /home/work/AIDAS/MMaDA/precompute_instructs2s_tokens.py \
        --audio-root /home/work/AIDAS/data/InstructS2S-200K/en/wav \
        --output-root /home/work/AIDAS/cache/instructs2s_tokens \
        --pairs-file /home/work/AIDAS/data/InstructS2S-200K/en/wav/pairs.txt

sha1(์ ๋๊ฒฝ๋ก) ๊ธฐ๋ฐ ์บ์ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ฏ๋ก, ํ์ต ์ฝ๋์์ ๊ธฐ๋ํ๋ ๋๋ ํฐ๋ฆฌ
(`MixedSpeechTextDataset`, `Speech2SpeechDataset`)์ ๋์ผํ๊ฒ ๋์ํฉ๋๋ค.
"""
from __future__ import annotations

import argparse
import hashlib
import os
import sys
import tempfile
from pathlib import Path
from typing import Iterable, Iterator, Optional, Sequence

import soundfile as sf
import torch
from tqdm import tqdm
# Ensure the project root is importable when this script is run directly.
# NOTE: the original reached sys via `os.sys` — an accidental dependency on
# os's internal import of sys; use the real `sys` module instead.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer  # noqa: E402
def iter_instructs2s_audio(
    audio_root: Path, pairs_file: Optional[Path] = None
) -> Iterator[Path]:
    """
    Walk the InstructS2S root and yield every user/assistant wav path.

    If a ``pairs.txt`` manifest is available (passed explicitly, or found at
    ``audio_root/pairs.txt``), it takes precedence; otherwise the per-dialog
    directory layout ``<dir>/<dir>-<k>-{user,assistant}.wav`` is scanned.
    """
    root = audio_root.expanduser().resolve()

    # Prefer an explicit manifest; fall back to root/pairs.txt if present.
    manifest = pairs_file
    if manifest is None:
        default_manifest = root / "pairs.txt"
        if default_manifest.exists():
            manifest = default_manifest

    if manifest is not None:
        with manifest.open("r") as fh:
            for raw in fh:
                fields = raw.split()
                if len(fields) < 2:
                    # Blank or malformed line — skip.
                    continue
                # First field is the user clip, second the assistant clip;
                # relative entries are resolved against the audio root.
                for field in fields[:2]:
                    candidate = Path(field)
                    if not candidate.is_absolute():
                        candidate = root / candidate
                    if candidate.is_file():
                        yield candidate
        return

    # No manifest: scan numbered turn files inside each dialog directory.
    for entry in sorted(root.iterdir()):
        if not entry.is_dir():
            continue
        index = 1
        while True:
            user_clip = entry / f"{entry.name}-{index}-user.wav"
            assistant_clip = entry / f"{entry.name}-{index}-assistant.wav"
            if not (user_clip.is_file() and assistant_clip.is_file()):
                # Stop at the first missing pair — turns are numbered densely.
                break
            yield user_clip
            yield assistant_clip
            index += 1
def hash_path(path: Path) -> str:
    """Return the 40-char hex SHA-1 digest of the absolute form of *path*."""
    return hashlib.sha1(os.path.abspath(path).encode("utf-8")).hexdigest()
def token_output_path(output_root: Path, audio_path: Path) -> Path:
    """Return the sharded cache file output_root/xx/yy/<sha1>.pt for *audio_path*."""
    digest = hash_path(audio_path)
    # Two-level fan-out on the first digest bytes keeps directories small.
    shard_a, shard_b = digest[:2], digest[2:4]
    return output_root / shard_a / shard_b / (digest + ".pt")
def encode_audio(tokenizer: EMOVASpeechTokenizer, audio_path: Path) -> torch.Tensor:
    """
    Tokenize one audio clip with the EMOVA speech tokenizer.

    WAV files are handed to the tokenizer directly.  Any other format is
    first re-written to a temporary WAV file (the tokenizer takes WAV paths),
    which is always removed afterwards.

    Returns the token tensor moved to CPU.
    """
    if audio_path.suffix.lower() == ".wav":
        return tokenizer.encode(str(audio_path)).cpu()

    data, sample_rate = sf.read(str(audio_path))
    # mkstemp + immediate close instead of NamedTemporaryFile(delete=False):
    # the original kept the handle open while sf.write re-opened tmp.name,
    # which fails on Windows and holds an extra fd during encoding.
    fd, tmp_name = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(tmp_name, data, sample_rate)
        tokens = tokenizer.encode(tmp_name).cpu()
    finally:
        try:
            os.remove(tmp_name)
        except OSError:
            # Best-effort cleanup; a leftover temp file is not fatal.
            pass
    return tokens
def gather_audio_paths(audio_root: Path, pairs_file: Optional[Path]) -> list[Path]:
    """Collect all clip paths, dropping duplicates while keeping first-seen order."""
    # dict.fromkeys preserves insertion order, so this is an ordered dedupe.
    return list(dict.fromkeys(iter_instructs2s_audio(audio_root, pairs_file)))
def _build_parser() -> argparse.ArgumentParser:
    """CLI definition; defaults match the training-time cache layout."""
    parser = argparse.ArgumentParser(description="Pre-compute EMOVA speech tokens for InstructS2S.")
    parser.add_argument(
        "--audio-root",
        type=Path,
        default=Path("/home/work/AIDAS/data/InstructS2S-200K/en/wav"),
        help="user/assistant WAV๊ฐ ์์นํ ๋ฃจํธ ๋๋ ํฐ๋ฆฌ",
    )
    parser.add_argument(
        "--pairs-file",
        type=Path,
        default=None,
        help="์ ํ ์ฌํญ: pairs.txt ๊ฒฝ๋ก (์ง์ ํ์ง ์์ผ๋ฉด audio-root/pairs.txt ํ์)",
    )
    parser.add_argument(
        "--output-root",
        type=Path,
        default=Path("/home/work/AIDAS/cache/instructs2s_tokens"),
        help="ํ ํฐ์ ์ ์ฅํ ๋๋ ํฐ๋ฆฌ",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="Emova-ollm/emova_speech_tokenizer_hf",
        help="EMOVA speech tokenizer ์ฒดํฌํฌ์ธํธ",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="์ธ์ฝ๋ฉ์ ์ฌ์ฉํ ๋๋ฐ์ด์ค",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="์ด๋ฏธ ์กด์ฌํ๋ ํ ํฐ์ ๋ค์ ๊ณ์ฐํฉ๋๋ค.",
    )
    return parser


def main():
    """Encode every InstructS2S clip and cache the tokens under --output-root."""
    parser = _build_parser()
    args = parser.parse_args()

    # Validate inputs before loading the (heavy) tokenizer.
    audio_root = args.audio_root.expanduser().resolve()
    if not audio_root.exists():
        parser.error(f"Audio root not found: {audio_root}")
    pairs_file = args.pairs_file.expanduser().resolve() if args.pairs_file else None
    if pairs_file is not None and not pairs_file.exists():
        parser.error(f"pairs-file not found: {pairs_file}")

    output_root = args.output_root.expanduser().resolve()
    output_root.mkdir(parents=True, exist_ok=True)

    audio_paths = gather_audio_paths(audio_root, pairs_file)
    if not audio_paths:
        print("No audio files found. Nothing to encode.")
        return

    device = torch.device(args.device)
    if device.type == "cuda":
        torch.cuda.set_device(device)
    tokenizer = EMOVASpeechTokenizer.from_pretrained(args.tokenizer).to(device)
    tokenizer.eval()

    total = 0
    skipped = 0
    failed: list[Path] = []
    for audio_path in tqdm(audio_paths, desc="Encoding InstructS2S clips"):
        audio_path = audio_path.expanduser().resolve()
        out_path = token_output_path(output_root, audio_path)
        if out_path.exists() and not args.overwrite:
            skipped += 1
            continue
        out_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            tokens = encode_audio(tokenizer, audio_path)
        except Exception as exc:
            # Best-effort batch job: log, remember, and keep going.
            tqdm.write(f"[WARN] Failed to encode {audio_path}: {exc}")
            failed.append(audio_path)
            continue
        # Write atomically (tmp file + rename) so readers never see partials.
        tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
        torch.save(tokens, tmp_path)
        os.replace(tmp_path, out_path)
        total += 1

    if failed:
        # Append so repeated runs accumulate a full failure history.
        failed_log = output_root / "failed_paths.log"
        with failed_log.open("a") as fh:
            for path in failed:
                fh.write(f"{path}\n")
        print(f"Failed to encode {len(failed)} files. See {failed_log}")
    print(f"Done. Encoded {total} files. Skipped {skipped} existing entries.")


if __name__ == "__main__":
    main()