File size: 6,808 Bytes
7bfbdc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#!/usr/bin/env python3
"""
체크 방법
=========
python check_audio_tokens.py \
  --config configs/omada_instruction_tuning.yaml \
  --samples 20
"""

import argparse
import os
import random
import sys
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union

import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import AutoTokenizer

# Make the repository root importable BEFORE the project-local imports below;
# inserting it after them (as before) had no effect on their resolution.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.data import MixedSpeechTextDataset, VideoSpeechDataset
from training.prompting_utils import UniversalPrompting
from training.utils import image_transform

def _to_tensor(entry: Union[torch.Tensor, np.ndarray, list, tuple, str],
               vq_model: EMOVASpeechTokenizer) -> torch.Tensor:
    """entry가 경로면 encode, 이미 토큰이면 long tensor로 변환."""
    if isinstance(entry, torch.Tensor):
        tokens = entry.clone().long()
    elif isinstance(entry, np.ndarray):
        tokens = torch.from_numpy(entry).long()
    elif isinstance(entry, (list, tuple)):
        tokens = torch.as_tensor(entry, dtype=torch.long)
    elif isinstance(entry, str):
        # EMOVA encode는 (1, L) 반환 → 1D로 변환
        tokens = vq_model.encode(entry).squeeze(0).long()
    else:
        raise TypeError(f"Unsupported token entry type: {type(entry)}")
    return tokens.view(-1)


def _log_stats(flow: str, path: str, tokens: torch.Tensor,
               codebook_size: int = 4096) -> Tuple[int, int]:
    max_id = int(tokens.max().item())
    min_id = int(tokens.min().item())
    over = int((tokens >= codebook_size).sum().item())
    under = int((tokens < 0).sum().item())

    print(
        f"[{flow}] path={path} "
        f"shape={tuple(tokens.shape)} "
        f"min={min_id} max={max_id} "
        f"<0={under} >=4096={over}"
    )
    return over, under


def build_prompting(config) -> UniversalPrompting:
    """Create the UniversalPrompting helper described by ``config``.

    The text tokenizer is loaded from ``config.model.omada.tokenizer_path``
    with left padding; length limits come from ``config.dataset.preprocessing``
    and the conditional-dropout probability from ``config.training``.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(
        config.model.omada.tokenizer_path,
        padding_side="left",
    )
    # Task / modality delimiter tokens registered with the prompting helper.
    task_tokens = (
        "<|soi|>", "<|eoi|>",
        "<|sov|>", "<|eov|>",
        "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
        "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>",
        "<|t2s|>", "<|s2s|>",
        "<|soa|>", "<|eoa|>",
    )
    preprocessing = config.dataset.preprocessing
    return UniversalPrompting(
        text_tokenizer,
        max_text_len=preprocessing.max_seq_length,
        max_audio_len=preprocessing.max_aud_length,
        max_audio_len_short=preprocessing.max_aud_length_short,
        ignore_id=-100,
        cond_dropout_prob=config.training.cond_dropout_prob,
        special_tokens=task_tokens,
        use_reserved_token=True,
    )


def sample_indices(length: int, num: Optional[int]) -> Tuple[Iterable[int], int]:
    """Choose which dataset indices to inspect.

    Args:
        length: Dataset size.
        num: Requested number of samples. ``None``, ``<= 0`` or ``>= length``
            all mean "inspect the whole dataset".

    Returns:
        ``(indices, count)``: either ``range(length)`` or a list of ``num``
        distinct indices drawn with ``random.sample`` (reproducible once the
        ``random`` module is seeded), together with the number of indices that
        will be yielded (convenient as ``tqdm(total=...)``).
    """
    if num is None or num <= 0 or num >= length:
        return range(length), length
    indices = random.sample(range(length), num)
    return indices, len(indices)


@torch.no_grad()
def inspect_v2s(config, prompting, vq_model, num_samples: int):
    """Check audio token ranges for the video-to-speech (v2s) dataset.

    Builds a ``VideoSpeechDataset`` from
    ``config.dataset.params.video_speech_dataset`` (using the hard-coded
    defaults below for any missing key). For each inspected sample it uses the
    dataset's precomputed tokens when available and otherwise encodes the
    speech file with ``vq_model``, then prints per-sample and total counts of
    out-of-range ids.

    Args:
        config: Full OmegaConf training config.
        prompting: Not used by this function; kept so both ``inspect_*``
            functions share one signature.
        vq_model: Speech tokenizer used when no precomputed tokens exist.
        num_samples: Number of random samples to inspect (``<= 0``: all).
    """
    raw_cfg = config.dataset.params.get("video_speech_dataset")
    # OmegaConf.to_container only accepts config objects, so a missing or null
    # section must fall back to {} explicitly instead of raising ValueError.
    speech_cfg = (
        OmegaConf.to_container(raw_cfg, resolve=True)
        if OmegaConf.is_config(raw_cfg)
        else None
    ) or {}
    dataset = VideoSpeechDataset(
        transform=image_transform,
        resolution=config.dataset.preprocessing.resolution,
        num_frames=speech_cfg.get("num_frames_speech", 4),
        video_root=speech_cfg.get(
            "video_root", "/home/work/AIDAS/data/video/openvid1m/video/video"
        ),
        audio_root=speech_cfg.get(
            "audio_root", "/home/work/AIDAS/data/video-speech"
        ),
        speech_dir_name=speech_cfg.get("speech_dir_name", "openvid-speech-trunc"),
        index_path=speech_cfg.get(
            "index_path", "/home/work/AIDAS/data/video-speech/openvid-speech.csv"
        ),
        sample_method=speech_cfg.get("sample_method", "uniform"),
        precomputed_tokens_root=speech_cfg.get("precomputed_tokens_root"),
    )

    print(f"\n=== VideoSpeechDataset (v2s) | total={len(dataset)} ===")
    total_over = total_under = 0
    indices, total = sample_indices(len(dataset), num_samples)
    for idx in tqdm(indices, total=total, desc="v2s audio", unit="sample"):
        speech_path = dataset.data[idx]["speech"]
        tokens = dataset._load_precomputed_tokens(speech_path)
        if tokens is None:
            # No precomputed tokens available: encode the audio with vq_model.
            tokens = vq_model.encode(speech_path).squeeze(0).long()
        else:
            # as_tensor is a no-op for tensors and also accepts array-likes
            # (e.g. numpy arrays) that have no .long() method.
            tokens = torch.as_tensor(tokens).long()
        over, under = _log_stats("v2s", speech_path, tokens)
        total_over += over
        total_under += under

    print(f"[v2s] total >=4096: {total_over} | total <0: {total_under}")


@torch.no_grad()
def inspect_t2s(config, prompting, vq_model, num_samples: int):
    """Check audio token ranges for the dataset shared by t2s and s2t.

    Each inspected sample's ``audio_path`` entry may be a file path (encoded
    with ``vq_model``), a precomputed numpy array, or another sequence of ids;
    per-sample and total out-of-range counts are printed.
    ``prompting`` is not used by this function.
    """
    dataset = MixedSpeechTextDataset(config.dataset.params.audio_data)

    print(f"\n=== MixedSpeechTextDataset (t2s/s2t 공용) | total={len(dataset)} ===")
    over_sum, under_sum = 0, 0
    picked, count = sample_indices(len(dataset), num_samples)
    for i in tqdm(picked, total=count, desc="t2s/s2t audio", unit="sample"):
        audio_entry = dataset[i]["audio_path"]
        if isinstance(audio_entry, str):
            label = audio_entry
            ids = vq_model.encode(audio_entry).squeeze(0).long()
        elif isinstance(audio_entry, np.ndarray):
            label = "<precomputed-array>"
            ids = torch.from_numpy(audio_entry).long()
        else:
            label = "<sequence>"
            ids = torch.as_tensor(audio_entry, dtype=torch.long)
        n_over, n_under = _log_stats("t2s/s2t-source", label, ids)
        over_sum += n_over
        under_sum += n_under

    print(f"[t2s] total >=4096: {over_sum} | total <0: {under_sum}")


def main():
    """Parse CLI arguments, load the tokenizers and run the selected checks.

    ``--flows`` selects which datasets to inspect. It defaults to ``v2s``
    only, matching the previous behavior; ``t2s`` enables the
    MixedSpeechTextDataset check that used to be commented out.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True,
                        help="학습에 사용한 YAML 설정 파일")
    parser.add_argument(
        "--samples",
        type=int,
        default=-1,
        help="각 데이터셋에서 검사할 샘플 수 (<=0이면 전체 검사)",
    )
    parser.add_argument(
        "--flows",
        nargs="+",
        choices=("v2s", "t2s"),
        default=["v2s"],
        help="Datasets to inspect: v2s (VideoSpeechDataset) and/or t2s "
             "(MixedSpeechTextDataset, shared by t2s/s2t). Default: v2s.",
    )
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    prompting = build_prompting(config)

    vq_model = EMOVASpeechTokenizer.from_pretrained(
        config.model.vq_model_audio.vq_model_name
    )
    vq_model.eval()

    inspectors = {"v2s": inspect_v2s, "t2s": inspect_t2s}
    # dict.fromkeys drops duplicate flow names while keeping the user's order.
    for flow in dict.fromkeys(args.flows):
        inspectors[flow](config, prompting, vq_model, args.samples)


if __name__ == "__main__":
    torch.manual_seed(0)
    random.seed(0)
    main()