# coding=utf-8
# Copyright 2025 AIDAS Lab
import os
import random
import editdistance
from functools import partial
import re
import soundfile as sf
import numpy as np
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from datasets import load_dataset
from models import OMadaModelLM
from training.data import T2S_INSTRUCTION  # T2S_INSTRUCTION import
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.prompting_utils import UniversalPrompting
from training.utils import get_config, flatten_omega_conf
from models import get_mask_schedule
from transformers import AutoTokenizer, pipeline
import argparse
import logging
# --- (setup_logger, calculate_WER, get_emova_dataset_tts, EMOVATtsEvalDataset, setup_distributed, cleanup_distributed are unchanged from the previous version) ---
def setup_logger(rank):
    logger = logging.getLogger(__name__)
    if logger.hasHandlers():
        logger.handlers.clear()
    formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.WARNING)
    return logger
def calculate_WER(recognized_text_list, groundtruth_text_list):
    word_num = 0.0
    scores = 0.0
    for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
        recognized_text = recognized_text.lower()
        groundtruth_text = groundtruth_text.lower()
        recognized_text = re.sub(r"[^\w\s']", "", recognized_text)
        groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text)
        recognized_word_list = recognized_text.split()
        groundtruth_word_list = groundtruth_text.split()
        word_num += len(groundtruth_word_list)
        scores += editdistance.eval(recognized_word_list, groundtruth_word_list)
    WER = scores / word_num if word_num > 0 else 0.0
    return WER, scores, word_num
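# Quick sanity check for calculate_WER (illustrative only, not executed here):
#   calculate_WER(["hello world"], ["hello there world"])
#   -> the reference has 3 words and the word-level edit distance is 1,
#      so it returns roughly (0.333, 1.0, 3.0).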
def get_emova_dataset_tts(logger):
    logger.info("Loading EMOVA dataset (librispeech-asr-tts config) for TTS...")
    dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts", split='test')
    original_count = len(dataset)
    dataset = dataset.filter(
        lambda example: 'tts' in example['id'] and '<speech' not in example['conversations'][0]['value']
    )
    logger.info(f"Original dataset size: {original_count}. Filtered for TTS tasks. New size: {len(dataset)}")
    return dataset
def setup_distributed(rank, world_size):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup_distributed():
    dist.destroy_process_group()
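# Note on the backend: "gloo" works here because the only collective used is
# dist.all_gather_object, which gathers plain Python objects on the CPU. For
# GPU tensor collectives, "nccl" is the more common choice and would also call
# torch.cuda.set_device(rank) before wrapping the model in DDP.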
class EMOVATtsEvalDataset(Dataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        example = self.hf_dataset[idx]
        gt_text = example['conversations'][0]['value']
        sample_id = example['id']
        return {"gt_text": gt_text, "sample_id": sample_id}

# ### REMOVED ###: Collate function is not needed as logic moved to main loop
# def evaluation_collate_fn_tts(...): ...
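# Since each item is a dict of plain strings, the DataLoader's default collate
# turns a batch into {"gt_text": [str, ...], "sample_id": [str, ...]}, which is
# exactly what the main loop below indexes into, so no custom collate_fn is needed.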
def main():
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    setup_distributed(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    logger = setup_logger(rank)
    parser = argparse.ArgumentParser(description="Run DDP TTS evaluation for OMadaModelLM on the EMOVA dataset.")
    parser.add_argument('--train_step', type=int, required=True, help='Checkpoint step to evaluate')
    parser.add_argument('--guidance_scale', type=float, default=0.0, help='CFG guidance scale')
    parser.add_argument('--timesteps', type=int, default=32, help='Number of generation timesteps for diffusion')
    parser.add_argument('--speech_token_length', type=int, default=250, help='Max length of speech tokens to generate')
    args, unknown = parser.parse_known_args()
    config = get_config()
    if rank == 0:
        wandb.init(
            project="ckpts_grid_libri_emova",
            name=f'{config.experiment.name}-TTS-STEP-{args.train_step}',
            config=vars(args)
        )
    text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left")
    if text_tokenizer.pad_token is None:
        text_tokenizer.pad_token = text_tokenizer.eos_token
    uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                       special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                       ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True)
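    # uni_prompting packs a (text prompts, audio-token placeholder) pair into the
    # token layout expected by the 't2s_gen' task (see training/prompting_utils.py);
    # it is reused below for both the conditional and the unconditional (CFG) inputs.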
    if rank == 0:
        logger.info("Loading Whisper model for evaluation...")
        whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=device)
        logger.info("Whisper model loaded.")
        logger.info("Loading EMOVA VQ model (vocoder)...")
        vq_model = EMOVASpeechTokenizer.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device)
        vq_model.eval()
        logger.info("EMOVA VQ model loaded.")
    # Checkpoint directory is built from --train_step so the evaluated weights match the logged step
    trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1_2nd/checkpoint-{args.train_step}/unwrapped_model"
    if rank == 0:
        logger.info(f"Loading trained MMada model from: {trained_checkpoint_path}")
    model = OMadaModelLM.from_pretrained(
        trained_checkpoint_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json"
    ).to(device)
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    logger.info("Trained MMada model loaded and wrapped with DDP successfully!")
    hf_dataset = get_emova_dataset_tts(logger)
    eval_dataset = EMOVATtsEvalDataset(hf_dataset)
    sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    # ### CORRECTED ###: Remove custom collate_fn, default is sufficient
    dataloader = DataLoader(eval_dataset, batch_size=16, sampler=sampler, num_workers=0)
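    # DistributedSampler with shuffle=False gives each rank a disjoint subset of
    # the eval set, so ranks generate speech for different samples in parallel.
    # It may pad the last round by repeating samples so every rank sees the same
    # number of items.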
    local_results = []
    model.eval()
    mask_token_id = 126336
    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        schedule_args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_schedule(schedule, **schedule_args)
    else:
        mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine"))
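    # Rough intuition for the generation loop below (a sketch of how mask-based
    # diffusion decoding typically works, not a spec of t2s_generate itself): the
    # speech region starts out as all mask tokens, and over `timesteps` rounds the
    # model fills in tokens while the schedule (e.g. cosine) controls how many
    # positions remain masked at each round.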
    progress_bar = tqdm(dataloader, desc="Evaluating TTS on EMOVA", disable=(rank != 0))
    for batch_idx, batch in enumerate(progress_bar):
        if batch is None:
            continue
        gt_texts = batch["gt_text"]
        sample_ids = batch["sample_id"]
        # ### CORRECTED & SIMPLIFIED PROMPT PREPARATION ###
        prompts = []
        for text in gt_texts:
            text = text.rsplit("\n", 1)[-1].strip()
            chosen_prompt = random.choice(T2S_INSTRUCTION)
            full_instruction = f"{text}\n{chosen_prompt}"  # Combine transcript and T2S instruction
            prompts.append(full_instruction)
        if rank == 0 and batch_idx == 0:
            logger.info(f"Example prompt: {prompts[0]}")
        batch_size = len(prompts)
        # Reserve speech_token_length - 1 positions (from args), all initialised to the mask token
        audio_tokens = torch.ones((batch_size, args.speech_token_length - 1), dtype=torch.long, device=device) * mask_token_id
        input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen')
        if args.guidance_scale > 0:
            uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen')
        else:
            uncond_input_ids, uncond_attention_mask = None, None
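        # With guidance_scale > 0, t2s_generate presumably combines the conditional
        # and unconditional (empty-prompt) predictions in the usual classifier-free
        # guidance style, e.g. logits_uncond + scale * (logits_cond - logits_uncond);
        # with scale 0 the unconditional branch is skipped entirely.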
        with torch.no_grad():
            # ### CORRECTED t2s_generate call with proper arguments ###
            output_ids = model.module.t2s_generate(
                input_ids=input_ids,
                uncond_input_ids=uncond_input_ids,
                attention_mask=attention_mask,
                uncond_attention_mask=uncond_attention_mask,
                guidance_scale=args.guidance_scale,
                temperature=1.0,  # Hardcoded temperature as example
                timesteps=args.timesteps,
                noise_schedule=mask_schedule,
                noise_type="mask",
                seq_len=args.speech_token_length,
                uni_prompting=uni_prompting,
                config=config,
            )
        if rank == 0:
            for i in range(batch_size):
                gt = gt_texts[i].rsplit("\n", 1)[-1].strip()
                gen_token_ids = output_ids[i]
                # Clamp to the speech codebook range before decoding
                clamped_ids = torch.clamp(gen_token_ids, max=4096 - 1, min=0)
                id_list = clamped_ids.cpu().tolist()
                speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in id_list])
                output_wav_path = f"/home/work/AIDAS/t2s_logs/tts_output_{sample_ids[i]}.wav"
                condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal'
                vq_model.decode(
                    speech_unit_for_decode,
                    condition=condition,
                    output_wav_file=output_wav_path
                )
                whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"})
                whisper_text = whisper_result.get("text", "")
                local_results.append({
                    "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text
                })
                if i == 0:
                    logger.info(f"\n--- TTS Example (Batch {batch_idx}) ---")
                    logger.info(f"  ID: {sample_ids[i]}; GT: {gt}; Whisper: {whisper_text}")
                    logger.info(f"  (Audio saved to {output_wav_path})")
                    wandb.log({
                        f"Generated Audio/{sample_ids[i]}": wandb.Audio(output_wav_path, caption=f"ID: {sample_ids[i]}\nGT: {gt}\nWhisper: {whisper_text}")
                    })
    all_results = [None] * world_size
    dist.all_gather_object(all_results, local_results)
    if rank == 0:
        # NOTE: only rank 0 decodes and transcribes its own shard above, so the
        # gathered lists from the other ranks are empty and the WER below covers
        # roughly 1/world_size of the filtered eval set.
        final_results = [item for sublist in all_results for item in sublist if item is not None]
        if final_results:
            groundtruth_text_list = [res["gt_text"] for res in final_results]
            recognized_text_list = [res["whisper_text"] for res in final_results]
            results_table = wandb.Table(columns=["ID", "Ground Truth Text", "Whisper Transcription"])
            for res in final_results:
                results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"])
            wandb.log({"Text-to-Speech Whisper Transcriptions": results_table})
            wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list)
            logger.info(f"Final TTS WER (via Whisper): {wer:.4f} | Word Errors: {int(errors)} | Total Words: {int(words)}")
            wandb.log({"TTS WER (via Whisper)": wer, "Total Word Errors": errors, "Total Words": words})
        else:
            logger.warning("No results were generated to calculate WER.")
    # ### CRITICAL FIX: DDP Cleanup MUST be called by all processes ###
    cleanup_distributed()
if __name__ == '__main__':
    main()
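# Example launch (illustrative; the file name and GPU count are assumptions):
#   torchrun --nproc_per_node=4 eval_tts_emova.py --train_step 45000 \
#       --guidance_scale 0.0 --timesteps 32 --speech_token_length 250
# RANK and WORLD_SIZE are read from the environment, so a launcher such as
# torchrun (or an equivalent that sets those variables) is required.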