File size: 7,716 Bytes
7bfbdc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#!/usr/bin/env python3
"""
Pre-compute EMOVA speech tokenizer codes for InstructS2S (๋˜๋Š” ๊ธฐํƒ€ ๋‹จ์ผ ์˜ค๋””์˜ค ํด๋”).

์˜ˆ์‹œ:
    python /home/work/AIDAS/MMaDA/precompute_instructs2s_tokens.py \
        --audio-root /home/work/AIDAS/data/InstructS2S-200K/en/wav \
        --output-root /home/work/AIDAS/cache/instructs2s_tokens \
        --pairs-file /home/work/AIDAS/data/InstructS2S-200K/en/wav/pairs.txt

sha1(์ ˆ๋Œ€๊ฒฝ๋กœ) ๊ธฐ๋ฐ˜ ์บ์‹œ ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ, ํ•™์Šต ์ฝ”๋“œ์—์„œ ๊ธฐ๋Œ€ํ•˜๋Š” ๋””๋ ‰ํ„ฐ๋ฆฌ
(`MixedSpeechTextDataset`, `Speech2SpeechDataset`)์™€ ๋™์ผํ•˜๊ฒŒ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.
"""

from __future__ import annotations

import argparse
import hashlib
import os
import tempfile
from pathlib import Path
from typing import Iterable, Iterator, Optional, Sequence

import soundfile as sf
import torch
from tqdm import tqdm

# Ensure project root on path
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in os.sys.path:
    os.sys.path.append(str(REPO_ROOT))

from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer  # noqa: E402


def iter_instructs2s_audio(
    audio_root: Path, pairs_file: Optional[Path] = None
) -> Iterator[Path]:
    """
    InstructS2S ๋ฃจํŠธ๋ฅผ ์ˆœํšŒํ•˜๋ฉฐ user/assistant wav ๊ฒฝ๋กœ๋ฅผ ๋ชจ๋‘ yield ํ•ฉ๋‹ˆ๋‹ค.

    `pairs.txt`๊ฐ€ ์ œ๊ณต๋˜๋ฉด ๊ฐ€์žฅ ์šฐ์„ ์œผ๋กœ ์‚ฌ์šฉํ•˜๊ณ , ์—†์œผ๋ฉด ๋””๋ ‰ํ„ฐ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ์ˆœํšŒํ•ฉ๋‹ˆ๋‹ค.
    """
    resolved_root = audio_root.expanduser().resolve()
    pairs_candidate = pairs_file
    if pairs_candidate is None:
        candidate = resolved_root / "pairs.txt"
        if candidate.exists():
            pairs_candidate = candidate

    if pairs_candidate is not None:
        with pairs_candidate.open("r") as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                if len(parts) < 2:
                    continue
                user_path = Path(parts[0])
                if not user_path.is_absolute():
                    user_path = resolved_root / user_path
                assistant_path = Path(parts[1])
                if not assistant_path.is_absolute():
                    assistant_path = resolved_root / assistant_path
                if user_path.is_file():
                    yield user_path
                if assistant_path.is_file():
                    yield assistant_path
        return

    # pairs.txt๊ฐ€ ์—†์œผ๋ฉด ๋””๋ ‰ํ„ฐ๋ฆฌ ์ˆœํšŒ
    for dir_path in sorted(resolved_root.iterdir()):
        if not dir_path.is_dir():
            continue
        dir_name = dir_path.name
        k = 1
        while True:
            user_wav = dir_path / f"{dir_name}-{k}-user.wav"
            assistant_wav = dir_path / f"{dir_name}-{k}-assistant.wav"
            if user_wav.is_file() and assistant_wav.is_file():
                yield user_wav
                yield assistant_wav
                k += 1
                continue
            break


def hash_path(path: Path) -> str:
    """์ ˆ๋Œ€ ๊ฒฝ๋กœ๋ฅผ sha1์œผ๋กœ ํ•ด์‹œํ•œ 40๊ธ€์ž hex ๋ฐ˜ํ™˜."""
    abs_path = os.path.abspath(path)
    return hashlib.sha1(abs_path.encode("utf-8")).hexdigest()


def token_output_path(output_root: Path, audio_path: Path) -> Path:
    digest = hash_path(audio_path)
    return output_root / digest[:2] / digest[2:4] / f"{digest}.pt"


def encode_audio(tokenizer: EMOVASpeechTokenizer, audio_path: Path) -> torch.Tensor:
    """
    EMOVA ํ† ํฌ๋‚˜์ด์ €๋กœ ์˜ค๋””์˜ค๋ฅผ ํ† ํฐํ™”.
    ๋น„-WAV ํฌ๋งท์€ ์ž„์‹œ ํŒŒ์ผ๋กœ ๋ณ€ํ™˜ ํ›„ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
    """
    suffix = audio_path.suffix.lower()
    if suffix == ".wav":
        return tokenizer.encode(str(audio_path)).cpu()

    data, sample_rate = sf.read(str(audio_path))
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        sf.write(tmp.name, data, sample_rate)
        tokens = tokenizer.encode(tmp.name).cpu()
    finally:
        tmp.close()
        try:
            os.remove(tmp.name)
        except OSError:
            pass
    return tokens


def gather_audio_paths(audio_root: Path, pairs_file: Optional[Path]) -> list[Path]:
    paths = list(iter_instructs2s_audio(audio_root, pairs_file))
    # ์ค‘๋ณต ์ œ๊ฑฐ
    seen = set()
    unique: list[Path] = []
    for path in paths:
        if path not in seen:
            seen.add(path)
            unique.append(path)
    return unique


def main():
    parser = argparse.ArgumentParser(description="Pre-compute EMOVA speech tokens for InstructS2S.")
    parser.add_argument(
        "--audio-root",
        type=Path,
        default=Path("/home/work/AIDAS/data/InstructS2S-200K/en/wav"),
        help="user/assistant WAV๊ฐ€ ์œ„์น˜ํ•œ ๋ฃจํŠธ ๋””๋ ‰ํ„ฐ๋ฆฌ",
    )
    parser.add_argument(
        "--pairs-file",
        type=Path,
        default=None,
        help="์„ ํƒ ์‚ฌํ•ญ: pairs.txt ๊ฒฝ๋กœ (์ง€์ •ํ•˜์ง€ ์•Š์œผ๋ฉด audio-root/pairs.txt ํƒ์ƒ‰)",
    )
    parser.add_argument(
        "--output-root",
        type=Path,
        default=Path("/home/work/AIDAS/cache/instructs2s_tokens"),
        help="ํ† ํฐ์„ ์ €์žฅํ•  ๋””๋ ‰ํ„ฐ๋ฆฌ",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="Emova-ollm/emova_speech_tokenizer_hf",
        help="EMOVA speech tokenizer ์ฒดํฌํฌ์ธํŠธ",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="์ธ์ฝ”๋”ฉ์— ์‚ฌ์šฉํ•  ๋””๋ฐ”์ด์Šค",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="์ด๋ฏธ ์กด์žฌํ•˜๋Š” ํ† ํฐ์„ ๋‹ค์‹œ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.",
    )
    args = parser.parse_args()

    audio_root = args.audio_root.expanduser().resolve()
    if not audio_root.exists():
        parser.error(f"Audio root not found: {audio_root}")

    pairs_file = args.pairs_file.expanduser().resolve() if args.pairs_file else None
    if pairs_file is not None and not pairs_file.exists():
        parser.error(f"pairs-file not found: {pairs_file}")

    output_root = args.output_root.expanduser().resolve()
    output_root.mkdir(parents=True, exist_ok=True)

    audio_paths = gather_audio_paths(audio_root, pairs_file)
    if not audio_paths:
        print("No audio files found. Nothing to encode.")
        return

    device = torch.device(args.device)
    if device.type == "cuda":
        torch.cuda.set_device(device)
    tokenizer = EMOVASpeechTokenizer.from_pretrained(args.tokenizer).to(device)
    tokenizer.eval()

    total = 0
    skipped = 0
    failed: list[Path] = []

    for audio_path in tqdm(audio_paths, desc="Encoding InstructS2S clips"):
        audio_path = audio_path.expanduser().resolve()
        out_path = token_output_path(output_root, audio_path)
        if out_path.exists() and not args.overwrite:
            skipped += 1
            continue

        out_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            tokens = encode_audio(tokenizer, audio_path)
        except Exception as exc:
            tqdm.write(f"[WARN] Failed to encode {audio_path}: {exc}")
            failed.append(audio_path)
            continue

        tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
        torch.save(tokens, tmp_path)
        os.replace(tmp_path, out_path)
        total += 1

    if failed:
        failed_log = output_root / "failed_paths.log"
        with failed_log.open("a") as fh:
            for path in failed:
                fh.write(f"{path}\n")
        print(f"Failed to encode {len(failed)} files. See {failed_log}")

    print(f"Done. Encoded {total} files. Skipped {skipped} existing entries.")


if __name__ == "__main__":
    main()