# AIDAS-Omni-Modal-Diffusion / MMaDA / inference_t2s_emova.py
# coding=utf-8
# Copyright 2025 AIDAS Lab
import os
import random
import editdistance
from functools import partial
import re
import soundfile as sf
import numpy as np
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from datasets import load_dataset
from models import OMadaModelLM
from training.data import T2S_INSTRUCTION  # pool of text-to-speech instruction templates
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.prompting_utils import UniversalPrompting
from training.utils import get_config, flatten_omega_conf
from models import get_mask_schedule
from transformers import AutoTokenizer, pipeline
import argparse
import logging
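# This script runs a distributed text-to-speech evaluation: every rank generates speech
# tokens with a trained OMada checkpoint, decodes them to waveforms with the EMOVA speech
# tokenizer, transcribes the audio with Whisper, and rank 0 aggregates a WER over the
# EMOVA LibriSpeech TTS split.
#
# A minimal launch sketch (GPU count, step value, and config handling are placeholders,
# not values prescribed by this repo; get_config() reads its own config argument, see
# training/utils.py for the exact form):
#
#   torchrun --nproc_per_node=4 inference_t2s_emova.py \
#       --train_step 45000 --guidance_scale 0.0 --timesteps 32 --speech_token_length 250
#
# torchrun (or any launcher that sets RANK and WORLD_SIZE) is required because main()
# reads both variables from the environment.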
# --- (setup_logger, calculate_WER, get_emova_dataset_tts, EMOVATtsEvalDataset, setup_distributed, cleanup_distributed are unchanged from the previous version) ---
def setup_logger(rank):
logger = logging.getLogger(__name__)
if logger.hasHandlers():
logger.handlers.clear()
formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
if rank == 0:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.WARNING)
return logger
def calculate_WER(recognized_text_list, groundtruth_text_list):
word_num = 0.0
scores = 0.0
for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
        recognized_text = recognized_text.lower()
        groundtruth_text = groundtruth_text.lower()
        recognized_text = re.sub(r"[^\w\s']", "", recognized_text)
        groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text)
        recognized_word_list = recognized_text.split()
        groundtruth_word_list = groundtruth_text.split()
        word_num += len(groundtruth_word_list)
        scores += editdistance.eval(recognized_word_list, groundtruth_word_list)
WER = scores / word_num if word_num > 0 else 0.0
return WER, scores, word_num
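# Quick worked example of the metric above (informal, not executed anywhere in this file):
#   calculate_WER(["hello word"], ["hello world"]) -> (0.5, 1.0, 2.0)
# one of the two reference words is substituted, so WER = 1 / 2.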
def get_emova_dataset_tts(logger):
logger.info("Loading EMOVA dataset (librispeech-asr-tts config) for TTS...")
dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts", split='test')
original_count = len(dataset)
dataset = dataset.filter(
lambda example: 'tts' in example['id'] and '<speech' not in example['conversations'][0]['value']
)
logger.info(f"Original dataset size: {original_count}. Filtered for TTS tasks. New size: {len(dataset)}")
return dataset
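# The filter above assumes records shaped roughly like the following (inferred from how
# 'id' and 'conversations' are used later in this script, not taken from the dataset card):
#   {"id": "...tts...", "conversations": [{"value": "<instruction>\n<transcript>"}, ...]}
# i.e. the transcript to synthesize and score is the text after the last newline of
# conversations[0]["value"].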
def setup_distributed(rank, world_size):
    # NOTE: the gloo backend is enough for this eval script (no gradient synchronization,
    # and the final all_gather_object works on CPU-side Python objects); nccl is the usual
    # choice for CUDA training runs.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup_distributed():
dist.destroy_process_group()
class EMOVATtsEvalDataset(Dataset):
def __init__(self, hf_dataset):
self.hf_dataset = hf_dataset
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, idx):
example = self.hf_dataset[idx]
gt_text = example['conversations'][0]['value']
sample_id = example['id']
return {"gt_text": gt_text, "sample_id": sample_id}
# No custom collate_fn is needed: prompt construction happens directly in the main loop.
def main():
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
setup_distributed(rank, world_size)
    device = torch.device(f"cuda:{rank}")  # assumes a single-node run (global rank == local GPU index)
logger = setup_logger(rank)
    parser = argparse.ArgumentParser(description="Run DDP TTS evaluation for OMadaModelLM on the EMOVA dataset.")
parser.add_argument('--train_step', type=int, required=True, help='Checkpoint step to evaluate')
parser.add_argument('--guidance_scale', type=float, default=0.0, help='CFG guidance scale')
parser.add_argument('--timesteps', type=int, default=32, help='Number of generation timesteps for diffusion')
parser.add_argument('--speech_token_length', type=int, default=250, help='Max length of speech tokens to generate')
args, unknown = parser.parse_known_args()
config = get_config()
if rank == 0:
wandb.init(
project="ckpts_grid_libri_emova",
name=f'{config.experiment.name}-TTS-STEP-{args.train_step}',
config=vars(args)
)
text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left")
if text_tokenizer.pad_token is None:
text_tokenizer.pad_token = text_tokenizer.eos_token
uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True)
if rank == 0:
logger.info("Loading Whisper model for evaluation...")
whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=device)
logger.info("Whisper model loaded.")
logger.info("Loading EMOVA VQ model (vocoder)...")
vq_model = EMOVASpeechTokenizer.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device)
vq_model.eval()
logger.info("EMOVA VQ model loaded.")
    trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1_2nd/checkpoint-{args.train_step}/unwrapped_model"
if rank == 0:
        logger.info(f"Loading trained OMada model from: {trained_checkpoint_path}")
model = OMadaModelLM.from_pretrained(
trained_checkpoint_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json"
).to(device)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    logger.info("✅ Trained OMada model loaded and wrapped with DDP successfully!")
hf_dataset = get_emova_dataset_tts(logger)
eval_dataset = EMOVATtsEvalDataset(hf_dataset)
sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    # The default collate_fn is sufficient: each item is a plain dict of strings.
dataloader = DataLoader(eval_dataset, batch_size=16, sampler=sampler, num_workers=0)
local_results = []
model.eval()
    mask_token_id = 126336  # hardcoded [MASK] token id (matches the LLaDA/MMaDA vocabulary)
if config.get("mask_schedule", None) is not None:
schedule = config.mask_schedule.schedule
schedule_args = config.mask_schedule.get("params", {})
mask_schedule = get_mask_schedule(schedule, **schedule_args)
else:
mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine"))
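    # Rough intuition for the default "cosine" schedule (an assumption about the callable
    # returned by get_mask_schedule, not something defined in this file): at sampling
    # progress t in [0, 1] the fraction of still-masked speech tokens is about
    # cos(pi * t / 2), so nearly everything is masked at the first step and nothing at the
    # last, letting early steps commit only a few confident tokens.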
progress_bar = tqdm(dataloader, desc="Evaluating TTS on EMOVA", disable=(rank != 0))
for batch_idx, batch in enumerate(progress_bar):
if batch is None:
continue
gt_texts = batch["gt_text"]
sample_ids = batch["sample_id"]
        # Build the T2S prompts: keep only the text after the last newline of the GT
        # value (the transcript) and append a randomly chosen T2S instruction.
        prompts = []
        for text in gt_texts:
            text = text.rsplit("\n", 1)[-1].strip()
            chosen_prompt = random.choice(T2S_INSTRUCTION)
            prompts.append(f"{text}\n{chosen_prompt}")
        if rank == 0 and batch_idx == 0:
            logger.info(f"Example prompt: {prompts[0]}")
        batch_size = len(prompts)
        # Fully masked placeholder speech-token sequence (length speech_token_length - 1)
        # to be filled in by the diffusion sampler.
        audio_tokens = torch.ones((batch_size, args.speech_token_length - 1), dtype=torch.long, device=device) * mask_token_id
input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen')
if args.guidance_scale > 0:
uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen')
else:
uncond_input_ids, uncond_attention_mask = None, None
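        # With guidance_scale > 0, t2s_generate is expected to blend the conditional and
        # unconditional (empty-prompt) predictions in the usual classifier-free-guidance
        # fashion, roughly logits = uncond + (1 + guidance_scale) * (cond - uncond); the
        # exact formula lives inside t2s_generate, not here.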
with torch.no_grad():
            # Diffusion-style generation of speech tokens (optionally with classifier-free guidance).
output_ids = model.module.t2s_generate(
input_ids=input_ids,
uncond_input_ids=uncond_input_ids,
attention_mask=attention_mask,
uncond_attention_mask=uncond_attention_mask,
guidance_scale=args.guidance_scale,
                temperature=1.0,  # sampling temperature, kept fixed for this evaluation
timesteps=args.timesteps,
noise_schedule=mask_schedule,
noise_type="mask",
seq_len=args.speech_token_length,
uni_prompting=uni_prompting,
config=config,
)
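        # output_ids is assumed to hold one row of generated speech-token ids per prompt
        # (roughly batch_size x speech_token_length); only rank 0 decodes and scores them.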
if rank == 0:
for i in range(batch_size):
gt = gt_texts[i].rsplit("\n", 1)[-1].strip()
                gen_token_ids = output_ids[i]
                # Clamp to the speech-unit range assumed by this script (4096 units) and
                # build the <|speech_N|> unit string expected by the EMOVA vocoder.
                clamped_ids = torch.clamp(gen_token_ids, min=0, max=4096 - 1)
                id_list = clamped_ids.cpu().tolist()
                speech_unit_for_decode = "".join(f"<|speech_{unit}|>" for unit in id_list)
output_wav_path = f"/home/work/AIDAS/t2s_logs/tts_output_{sample_ids[i]}.wav"
condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal'
vq_model.decode(
speech_unit_for_decode,
condition=condition,
output_wav_file=output_wav_path
)
whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"})
whisper_text = whisper_result.get("text", "")
local_results.append({
"sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text
})
if i == 0:
logger.info(f"\n--- TTS Example (Batch {batch_idx}) ---")
logger.info(f" ID: {sample_ids[i]}; GT: {gt}; Whisper: {whisper_text}")
logger.info(f" (Audio saved to {output_wav_path})")
wandb.log({
f"Generated Audio/{sample_ids[i]}": wandb.Audio(output_wav_path, caption=f"ID: {sample_ids[i]}\nGT: {gt}\nWhisper: {whisper_text}")
})
all_results = [None] * world_size
dist.all_gather_object(all_results, local_results)
if rank == 0:
final_results = [item for sublist in all_results for item in sublist if item is not None]
if final_results:
groundtruth_text_list = [res["gt_text"] for res in final_results]
recognized_text_list = [res["whisper_text"] for res in final_results]
results_table = wandb.Table(columns=["ID", "Ground Truth Text", "Whisper Transcription"])
for res in final_results: results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"])
wandb.log({"Text-to-Speech Whisper Transcriptions": results_table})
wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list)
logger.info(f"Final TTS WER (via Whisper): {wer:.4f} | Word Errors: {int(errors)} | Total Words: {int(words)}")
wandb.log({"TTS WER (via Whisper)": wer, "Total Word Errors": errors, "Total Words": words})
else:
logger.warning("No results were generated to calculate WER.")
    # Process-group cleanup must be called by every rank.
cleanup_distributed()
if __name__ == '__main__':
main()