# AIDAS-Omni-Modal-Diffusion / MMaDA / precompute_instructs2s_tokens.py
# (scrape header: author jaeikkim, commit 7bfbdc3 "Reinit Space without binary assets")
#!/usr/bin/env python3
"""
Pre-compute EMOVA speech tokenizer codes for InstructS2S (๋˜๋Š” ๊ธฐํƒ€ ๋‹จ์ผ ์˜ค๋””์˜ค ํด๋”).
์˜ˆ์‹œ:
python /home/work/AIDAS/MMaDA/precompute_instructs2s_tokens.py \
--audio-root /home/work/AIDAS/data/InstructS2S-200K/en/wav \
--output-root /home/work/AIDAS/cache/instructs2s_tokens \
--pairs-file /home/work/AIDAS/data/InstructS2S-200K/en/wav/pairs.txt
sha1(์ ˆ๋Œ€๊ฒฝ๋กœ) ๊ธฐ๋ฐ˜ ์บ์‹œ ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ, ํ•™์Šต ์ฝ”๋“œ์—์„œ ๊ธฐ๋Œ€ํ•˜๋Š” ๋””๋ ‰ํ„ฐ๋ฆฌ
(`MixedSpeechTextDataset`, `Speech2SpeechDataset`)์™€ ๋™์ผํ•˜๊ฒŒ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.
"""
from __future__ import annotations

import argparse
import hashlib
import os
import sys
import tempfile
from pathlib import Path
from typing import Iterable, Iterator, Optional, Sequence

import soundfile as sf
import torch
from tqdm import tqdm
# Ensure the project root is importable so `models.*` resolves when this file
# is run as a standalone script. Use `sys` directly rather than the
# undocumented `os.sys` alias.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer # noqa: E402
def iter_instructs2s_audio(
    audio_root: Path, pairs_file: Optional[Path] = None
) -> Iterator[Path]:
    """
    Walk an InstructS2S root and yield every user/assistant wav path.

    A `pairs.txt` manifest takes precedence when available (either passed
    explicitly or found at `audio_root/pairs.txt`); otherwise conversation
    directories are scanned by the `<dir>/<dir>-<k>-{user,assistant}.wav`
    naming convention, stopping at the first missing pair.
    """
    root = audio_root.expanduser().resolve()

    manifest = pairs_file
    if manifest is None:
        default_manifest = root / "pairs.txt"
        if default_manifest.exists():
            manifest = default_manifest

    if manifest is not None:
        with manifest.open("r") as handle:
            for raw_line in handle:
                fields = raw_line.split()
                if len(fields) < 2:
                    # Blank or malformed line — skip it.
                    continue
                # First two fields are the user and assistant clip paths;
                # relative entries are resolved against the audio root.
                for field in fields[:2]:
                    candidate = Path(field)
                    if not candidate.is_absolute():
                        candidate = root / candidate
                    if candidate.is_file():
                        yield candidate
        return

    # No manifest: scan each conversation directory for numbered turn pairs.
    for conv_dir in sorted(root.iterdir()):
        if not conv_dir.is_dir():
            continue
        stem = conv_dir.name
        turn = 1
        while True:
            user_wav = conv_dir / f"{stem}-{turn}-user.wav"
            assistant_wav = conv_dir / f"{stem}-{turn}-assistant.wav"
            if not (user_wav.is_file() and assistant_wav.is_file()):
                break
            yield user_wav
            yield assistant_wav
            turn += 1
def hash_path(path: Path) -> str:
    """Return the 40-character hex sha1 digest of *path*'s absolute form."""
    return hashlib.sha1(os.path.abspath(path).encode("utf-8")).hexdigest()


def token_output_path(output_root: Path, audio_path: Path) -> Path:
    """Cache location for a clip: `<root>/<d[0:2]>/<d[2:4]>/<digest>.pt` (two-level fan-out)."""
    digest = hash_path(audio_path)
    return output_root / digest[:2] / digest[2:4] / f"{digest}.pt"
def encode_audio(tokenizer: EMOVASpeechTokenizer, audio_path: Path) -> torch.Tensor:
    """
    Tokenize one audio clip with the EMOVA speech tokenizer.

    WAV files are handed to the tokenizer directly; any other format is
    transcoded to a temporary WAV first.

    Args:
        tokenizer: loaded EMOVA speech tokenizer exposing `encode(path)`.
        audio_path: path to the audio clip.

    Returns:
        CPU tensor of speech-token codes as produced by `tokenizer.encode`.
    """
    if audio_path.suffix.lower() == ".wav":
        return tokenizer.encode(str(audio_path)).cpu()

    # Non-WAV input: round-trip through a temporary WAV file.
    data, sample_rate = sf.read(str(audio_path))
    # Close the handle immediately so sf.write can reopen the file by name;
    # keeping it open breaks on Windows (the original code wrote while open).
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp_name = tmp.name
    tmp.close()
    try:
        sf.write(tmp_name, data, sample_rate)
        tokens = tokenizer.encode(tmp_name).cpu()
    finally:
        # Best-effort cleanup of the temp file, even if encoding raised.
        try:
            os.remove(tmp_name)
        except OSError:
            pass
    return tokens
def gather_audio_paths(audio_root: Path, pairs_file: Optional[Path]) -> list[Path]:
    """
    Collect all audio paths under *audio_root*, dropping duplicates while
    preserving first-seen order.

    Replaces the manual seen-set/append loop with `dict.fromkeys`, which is
    the idiomatic order-preserving de-duplication (dicts keep insertion order
    on Python 3.7+).
    """
    return list(dict.fromkeys(iter_instructs2s_audio(audio_root, pairs_file)))
def main():
    """CLI entry point: encode every InstructS2S clip and cache its tokens."""
    parser = argparse.ArgumentParser(description="Pre-compute EMOVA speech tokens for InstructS2S.")
    parser.add_argument(
        "--audio-root",
        type=Path,
        default=Path("/home/work/AIDAS/data/InstructS2S-200K/en/wav"),
        help="user/assistant WAV๊ฐ€ ์œ„์น˜ํ•œ ๋ฃจํŠธ ๋””๋ ‰ํ„ฐ๋ฆฌ",
    )
    parser.add_argument(
        "--pairs-file",
        type=Path,
        default=None,
        help="์„ ํƒ ์‚ฌํ•ญ: pairs.txt ๊ฒฝ๋กœ (์ง€์ •ํ•˜์ง€ ์•Š์œผ๋ฉด audio-root/pairs.txt ํƒ์ƒ‰)",
    )
    parser.add_argument(
        "--output-root",
        type=Path,
        default=Path("/home/work/AIDAS/cache/instructs2s_tokens"),
        help="ํ† ํฐ์„ ์ €์žฅํ•  ๋””๋ ‰ํ„ฐ๋ฆฌ",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="Emova-ollm/emova_speech_tokenizer_hf",
        help="EMOVA speech tokenizer ์ฒดํฌํฌ์ธํŠธ",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="์ธ์ฝ”๋”ฉ์— ์‚ฌ์šฉํ•  ๋””๋ฐ”์ด์Šค",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="์ด๋ฏธ ์กด์žฌํ•˜๋Š” ํ† ํฐ์„ ๋‹ค์‹œ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.",
    )
    args = parser.parse_args()

    # Validate paths up front so argparse reports errors consistently.
    audio_root = args.audio_root.expanduser().resolve()
    if not audio_root.exists():
        parser.error(f"Audio root not found: {audio_root}")
    pairs_file = None
    if args.pairs_file:
        pairs_file = args.pairs_file.expanduser().resolve()
        if not pairs_file.exists():
            parser.error(f"pairs-file not found: {pairs_file}")
    output_root = args.output_root.expanduser().resolve()
    output_root.mkdir(parents=True, exist_ok=True)

    audio_paths = gather_audio_paths(audio_root, pairs_file)
    if not audio_paths:
        print("No audio files found. Nothing to encode.")
        return

    device = torch.device(args.device)
    if device.type == "cuda":
        torch.cuda.set_device(device)
    tokenizer = EMOVASpeechTokenizer.from_pretrained(args.tokenizer).to(device)
    tokenizer.eval()

    encoded = 0
    skipped = 0
    failures: list[Path] = []
    for clip in tqdm(audio_paths, desc="Encoding InstructS2S clips"):
        clip = clip.expanduser().resolve()
        target = token_output_path(output_root, clip)
        if target.exists() and not args.overwrite:
            skipped += 1
            continue
        target.parent.mkdir(parents=True, exist_ok=True)
        try:
            tokens = encode_audio(tokenizer, clip)
        except Exception as exc:
            tqdm.write(f"[WARN] Failed to encode {clip}: {exc}")
            failures.append(clip)
            continue
        # Write atomically: stage to a .tmp sibling, then rename into place so
        # readers never observe a partially-written token file.
        staging = target.with_suffix(target.suffix + ".tmp")
        torch.save(tokens, staging)
        os.replace(staging, target)
        encoded += 1

    if failures:
        failed_log = output_root / "failed_paths.log"
        with failed_log.open("a") as fh:
            fh.writelines(f"{path}\n" for path in failures)
        print(f"Failed to encode {len(failures)} files. See {failed_log}")
    print(f"Done. Encoded {encoded} files. Skipped {skipped} existing entries.")


if __name__ == "__main__":
    main()