#!/usr/bin/env python3
"""
Usage
=====
python check_audio_tokens.py \
    --config configs/omada_instruction_tuning.yaml \
    --samples 20
"""

import argparse
import os
import random
import sys
from typing import Iterable, Tuple, Union

# The repo root must be on sys.path *before* the local packages are imported.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import AutoTokenizer

from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.data import MixedSpeechTextDataset, VideoSpeechDataset
from training.prompting_utils import UniversalPrompting
from training.utils import image_transform


def _to_tensor(entry: Union[torch.Tensor, np.ndarray, list, tuple, str],
               vq_model: EMOVASpeechTokenizer) -> torch.Tensor:
    """If entry is a file path, encode it; otherwise convert the existing tokens to a long tensor."""
    if isinstance(entry, torch.Tensor):
        tokens = entry.clone().long()
    elif isinstance(entry, np.ndarray):
        tokens = torch.from_numpy(entry).long()
    elif isinstance(entry, (list, tuple)):
        tokens = torch.as_tensor(entry, dtype=torch.long)
    elif isinstance(entry, str):
        # EMOVA encode returns shape (1, L); flattened to 1-D below.
        tokens = vq_model.encode(entry).squeeze(0).long()
    else:
        raise TypeError(f"Unsupported token entry type: {type(entry)}")
    return tokens.view(-1)
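
# Example (hypothetical inputs): every branch above normalizes to a 1-D
# torch.long tensor, so downstream range checks can treat sources uniformly:
#   _to_tensor(torch.zeros(1, 8).int(), vq)  -> shape (8,)
#   _to_tensor(np.array([1, 2, 3]), vq)      -> shape (3,)
#   _to_tensor("clip_0001.wav", vq)          -> encoded, then flattened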


def _log_stats(flow: str, path: str, tokens: torch.Tensor,
               codebook_size: int = 4096) -> Tuple[int, int]:
    """Print token-range stats for one sample; return (count >= codebook_size, count < 0)."""
    max_id = int(tokens.max().item())
    min_id = int(tokens.min().item())
    over = int((tokens >= codebook_size).sum().item())
    under = int((tokens < 0).sum().item())
    print(
        f"[{flow}] path={path} "
        f"shape={tuple(tokens.shape)} "
        f"min={min_id} max={max_id} "
        f"<0={under} >={codebook_size}={over}"
    )
    return over, under
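
# Illustrative output for a healthy sample (values are made up):
#   [v2s] path=/data/.../clip.wav shape=(512,) min=0 max=4091 <0=0 >=4096=0
# Any nonzero <0 / >=4096 count means tokens fall outside the speech
# codebook and would corrupt embedding lookups during training.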


def build_prompting(config) -> UniversalPrompting:
    """Build the UniversalPrompting helper with the same special tokens used in training."""
    tokenizer = AutoTokenizer.from_pretrained(
        config.model.omada.tokenizer_path,
        padding_side="left",
    )
    special_tokens = (
        "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>",
        "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
        "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>",
        "<|t2s|>", "<|s2s|>", "<|soa|>", "<|eoa|>",
    )
    prompt = UniversalPrompting(
        tokenizer,
        max_text_len=config.dataset.preprocessing.max_seq_length,
        max_audio_len=config.dataset.preprocessing.max_aud_length,
        max_audio_len_short=config.dataset.preprocessing.max_aud_length_short,
        ignore_id=-100,
        cond_dropout_prob=config.training.cond_dropout_prob,
        special_tokens=special_tokens,
        use_reserved_token=True,
    )
    return prompt
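
# Note: `prompting` is constructed only to mirror the training-time tokenizer
# setup (special tokens registered, lengths configured). The inspectors below
# read raw audio tokens directly, so it is passed through unused.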


def sample_indices(length: int, num: int) -> Tuple[Iterable[int], int]:
    """
    Return an iterable of indices and the total count that will be iterated.
    If num <= 0 or num >= length, iterate over the whole dataset.
    """
    if num is None or num <= 0 or num >= length:
        return range(length), length
    indices = random.sample(range(length), num)
    return indices, len(indices)
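
# Example: with len(dataset) == 1000, --samples 20 draws 20 random indices,
# while --samples -1 (the default) yields range(1000), i.e. a full scan.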


def inspect_v2s(config, prompting, vq_model, num_samples: int):
    speech_cfg = OmegaConf.to_container(
        config.dataset.params.get("video_speech_dataset", {}),
        resolve=True
    ) or {}
    dataset = VideoSpeechDataset(
        transform=image_transform,
        resolution=config.dataset.preprocessing.resolution,
        num_frames=speech_cfg.get("num_frames_speech", 4),
        video_root=speech_cfg.get(
            "video_root", "/home/work/AIDAS/data/video/openvid1m/video/video"
        ),
        audio_root=speech_cfg.get(
            "audio_root", "/home/work/AIDAS/data/video-speech"
        ),
        speech_dir_name=speech_cfg.get("speech_dir_name", "openvid-speech-trunc"),
        index_path=speech_cfg.get(
            "index_path", "/home/work/AIDAS/data/video-speech/openvid-speech.csv"
        ),
        sample_method=speech_cfg.get("sample_method", "uniform"),
        precomputed_tokens_root=speech_cfg.get("precomputed_tokens_root"),
    )
    print(f"\n=== VideoSpeechDataset (v2s) | total={len(dataset)} ===")
    total_over = total_under = 0
    indices, total = sample_indices(len(dataset), num_samples)
    for idx in tqdm(indices, total=total, desc="v2s audio", unit="sample"):
        sample = dataset.data[idx]
        speech_path = sample["speech"]
        # Prefer precomputed tokens; fall back to encoding the raw audio.
        tokens = dataset._load_precomputed_tokens(speech_path)
        if tokens is not None:
            tokens = tokens.long()
        else:
            tokens = vq_model.encode(speech_path).squeeze(0).long()
        over, under = _log_stats("v2s", speech_path, tokens)
        total_over += over
        total_under += under
    print(f"[v2s] total >=4096: {total_over} | total <0: {total_under}")


def inspect_t2s(config, prompting, vq_model, num_samples: int):
    dataset = MixedSpeechTextDataset(config.dataset.params.audio_data)
    print(f"\n=== MixedSpeechTextDataset (shared by t2s/s2t) | total={len(dataset)} ===")
    total_over = total_under = 0
    indices, total = sample_indices(len(dataset), num_samples)
    for idx in tqdm(indices, total=total, desc="t2s/s2t audio", unit="sample"):
        sample = dataset[idx]
        entry = sample["audio_path"]
        tokens = _to_tensor(entry, vq_model)
        if isinstance(entry, str):
            path_repr = entry
        elif isinstance(entry, np.ndarray):
            path_repr = "<precomputed-array>"
        else:
            path_repr = "<sequence>"
        over, under = _log_stats("t2s/s2t-source", path_repr, tokens)
        total_over += over
        total_under += under
    print(f"[t2s] total >=4096: {total_over} | total <0: {total_under}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True,
                        help="YAML config file used for training")
    parser.add_argument(
        "--samples",
        type=int,
        default=-1,
        help="number of samples to inspect per dataset (<= 0 inspects everything)",
    )
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    prompting = build_prompting(config)
    vq_model = EMOVASpeechTokenizer.from_pretrained(
        config.model.vq_model_audio.vq_model_name
    )
    vq_model.eval()
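    # Optional speed-up sketch (assumption: whether EMOVASpeechTokenizer.encode
    # runs on the model's device depends on its implementation, so verify first):
    # if torch.cuda.is_available():
    #     vq_model = vq_model.cuda()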

    inspect_v2s(config, prompting, vq_model, args.samples)
    # Uncomment to also check the mixed speech/text (t2s/s2t) dataset:
    # inspect_t2s(config, prompting, vq_model, args.samples)


if __name__ == "__main__":
    torch.manual_seed(0)
    random.seed(0)
    main()