# coding=utf-8
# Copyright 2025 AIDAS Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import editdistance
from functools import partial
import re

from normalizer import data_utils

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from datasets import load_dataset

from models import OMadaModelLM
from training.data import S2T_INSTRUCTION
from training.prompting_utils import UniversalPrompting
from training.utils import get_config, flatten_omega_conf
from transformers import AutoTokenizer
import argparse
import logging


def setup_logger(rank):
    logger = logging.getLogger(__name__)
    if logger.hasHandlers():
        logger.handlers.clear()
    formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.WARNING)
    return logger


def calculate_WER(recognized_text_list, groundtruth_text_list):
    """Calculates the Word Error Rate (WER) between predicted and ground truth texts."""
    word_num = 0.0
    scores = 0.0
    for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
        recognized_text = recognized_text.lower()
        groundtruth_text = groundtruth_text.lower()
        recognized_text = re.sub(r"[^\w\s']", "", recognized_text)
        groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text)
        recognized_word_list = recognized_text.split()
        groundtruth_word_list = groundtruth_text.split()
        current_word_num = len(groundtruth_word_list)
        word_num += current_word_num
        current_score = editdistance.eval(recognized_word_list, groundtruth_word_list)
        scores += current_score
    WER = scores / word_num if word_num > 0 else 0.0
    return WER, scores, word_num


# ### REMOVED ###: No longer need this function
# def get_vq_model_class(model_type): ...


def get_emova_dataset(logger):
    """Loads the EMOVA ASR/TTS evaluation dataset from Hugging Face."""
    logger.info("Loading EMOVA dataset (librispeech-asr-tts config)...")
    dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts", split='test')
    dataset = dataset.filter(lambda example: 'asr' in example['id'])
    logger.info(f"Dataset loaded successfully. Found {len(dataset)} ASR examples.")
    return dataset
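

# A minimal, self-contained sanity check for calculate_WER (a sketch; not called anywhere
# in the evaluation flow). The helper name `_wer_sanity_check` and the sentence pair are
# made up for illustration; the expected numbers follow from the word-level edit distance
# (one deleted word out of six reference words).
def _wer_sanity_check():
    groundtruth = ["The cat sat on the mat."]
    recognized = ["the cat sat on mat"]  # one word missing -> 1 error over 6 reference words
    wer, errors, words = calculate_WER(recognized, groundtruth)
    assert abs(wer - 1.0 / 6.0) < 1e-9
    assert errors == 1 and words == 6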


def setup_distributed(rank, world_size):
    """Initializes the distributed process group."""
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup_distributed():
    """Cleans up the distributed process group."""
    dist.destroy_process_group()


# ### MODIFIED ###: Dataset class now parses speech tokens from string
class EMOVAAsrEvalDataset(Dataset):
    def __init__(self, hf_dataset, text_vocab_size, image_vocab_size):
        self.hf_dataset = hf_dataset
        self.text_vocab_size = text_vocab_size
        self.image_vocab_size = image_vocab_size
        # Pre-compile the regex for efficiency
        self.speech_token_pattern = re.compile(r'<\|speech_(\d+)\|>')

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        example = self.hf_dataset[idx]
        # Ground truth text is from the 'gpt' turn
        gt_text = example['conversations'][-1]['value']
        sample_id = example['id']
        # Audio tokens are in the 'human' turn as a string
        audio_token_string = example['conversations'][0]['value']
        # Parse the string to extract integer token IDs
        speech_token_ids_str = self.speech_token_pattern.findall(audio_token_string)
        if not speech_token_ids_str:
            return None  # Handle cases with no speech tokens
        speech_token_ids = torch.tensor([int(s) for s in speech_token_ids_str], dtype=torch.long)
        # Shift audio token IDs to the correct range for the multimodal model's vocabulary
        speech_token_ids += self.text_vocab_size + self.image_vocab_size
        return {
            # Unsqueeze to add a batch dimension (consistent with original vq_model.encode output)
            "speech_token_ids": speech_token_ids.unsqueeze(0),
            "gt_text": gt_text,
            "sample_id": sample_id
        }


def evaluation_collate_fn(batch, text_tokenizer, uni_prompting, config):
    # Drop samples that contained no speech tokens
    batch = [b for b in batch if b is not None]
    if not batch:
        return None

    max_text_len = config.dataset.preprocessing.max_seq_length
    max_audio_len = config.dataset.preprocessing.max_aud_length + 1
    audio_pad_id = 126093
    sptids_dict = uni_prompting.sptids_dict

    batched_input_ids = []
    gt_texts = [item["gt_text"] for item in batch]
    sample_ids = [item["sample_id"] for item in batch]

    for item in batch:
        current_audio_tokens = item["speech_token_ids"]
        task_tensor = sptids_dict['<|s2t|>'].to('cpu').unsqueeze(0)
        soa_tensor = sptids_dict['<|soa|>'].to('cpu').unsqueeze(0)
        eoa_tensor = sptids_dict['<|eoa|>'].to('cpu').unsqueeze(0)

        # Truncate so that the <|s2t|>/<|soa|>/<|eoa|> tokens plus audio tokens fit in max_audio_len
        effective_max_audio = max_audio_len - 3
        if current_audio_tokens.shape[1] > effective_max_audio:
            current_audio_tokens = current_audio_tokens[:, :effective_max_audio]

        audio_block = torch.cat([task_tensor, soa_tensor, current_audio_tokens, eoa_tensor], dim=1)

        # Left-pad the audio block to a fixed length
        num_padding = max_audio_len - audio_block.shape[1]
        if num_padding > 0:
            padding_tensor = torch.full((1, num_padding), audio_pad_id, dtype=torch.long)
            padded_audio_block = torch.cat([padding_tensor, audio_block], dim=1)
        else:
            padded_audio_block = audio_block

        chosen_prompt = random.choice(S2T_INSTRUCTION)
        prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{chosen_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        prompt_encoding = text_tokenizer(
            prompt_text,
            max_length=max_text_len,
            truncation=True,
            return_tensors="pt"
        )
        prompt_tensor = prompt_encoding.input_ids

        final_sequence = torch.cat([padded_audio_block, prompt_tensor], dim=1)
        batched_input_ids.append(final_sequence.squeeze(0))

    # Left-pad all sequences in the batch to the same length
    pad_token_id = 126093
    max_len = max(seq.size(0) for seq in batched_input_ids)
    final_batch = torch.full((len(batched_input_ids), max_len), pad_token_id, dtype=torch.long)
    for i, seq in enumerate(batched_input_ids):
        final_batch[i, -len(seq):] = seq

    return {
        "input_ids": final_batch,
        "gt_texts": gt_texts,
        "sample_ids": sample_ids
    }
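

# Minimal sketch of the speech-token parsing and ID shift done in
# EMOVAAsrEvalDataset.__getitem__, shown on a synthetic 'human' turn (not called in the
# evaluation flow). The token values and vocabulary sizes below are hypothetical; at
# runtime the offsets come from len(uni_prompting.text_tokenizer) and
# config.model.omada.codebook_size.
def _speech_token_parse_example():
    pattern = re.compile(r'<\|speech_(\d+)\|>')
    fake_turn = "Please transcribe this audio: <|speech_12|><|speech_7|><|speech_305|>"
    raw_ids = [int(s) for s in pattern.findall(fake_turn)]   # -> [12, 7, 305]
    text_vocab_size, image_vocab_size = 1000, 500            # hypothetical sizes
    shifted = torch.tensor(raw_ids, dtype=torch.long) + text_vocab_size + image_vocab_size
    return shifted.unsqueeze(0)                              # shape (1, 3): [[1512, 1507, 1805]]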


def main():
    """Main function to run the distributed evaluation."""
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    setup_distributed(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    logger = setup_logger(rank)

    parser = argparse.ArgumentParser(description="Run DDP evaluation for OMadaModelLM on the EMOVA dataset.")
    parser.add_argument('--train_step', type=int, required=True, help='Training step of the checkpoint to evaluate (used in the W&B run name).')
    parser.add_argument('--remasking', type=str, default='random', help='Remasking strategy passed to mmu_generate.')
    parser.add_argument('--generation_step', type=int, default=512, help='Number of generation steps passed to mmu_generate.')
    parser.add_argument('--new_tok', type=int, default=128, help='Maximum number of new tokens to generate.')
    parser.add_argument('--block_length', type=int, default=64, help='Block length passed to mmu_generate.')
    # parser.add_argument('--ckpt_path', type=str, required=True, help='Path to the checkpoint to evaluate.')
    args, unknown = parser.parse_known_args()

    config = get_config()

    if rank == 0:
        run_id = config.wandb.get("run_id", None) or wandb.util.generate_id()
        config.wandb.run_id = run_id
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb.init(
            project="merging_grid",
            name=f'{config.experiment.name}-STEP-{args.train_step}-Remasking-{args.remasking}-GS-{args.generation_step}-NT-{args.new_tok}',
            config=wandb_config,
        )

    text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left")
    uni_prompting = UniversalPrompting(text_tokenizer,
                                       max_text_len=config.dataset.preprocessing.max_seq_length,
                                       special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>",
                                                       "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>",
                                                       "<|v2v|>", "<|lvg|>"),
                                       ignore_id=-100,
                                       cond_dropout_prob=config.training.cond_dropout_prob,
                                       use_reserved_token=True)

    # ### REMOVED ###: VQ Model is not needed anymore
    # vq_model_class = get_vq_model_class(config.model.vq_model_audio.type)
    # vq_model = vq_model_class.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device)
    # vq_model.requires_grad_(False)
    # vq_model.eval()

    train_step = args.train_step
    # NOTE: the checkpoint path is currently hardcoded; the commented-out variants select
    # the checkpoint by --train_step or by a --ckpt_path argument instead.
    # trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model/"
    trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1_2nd/checkpoint-50000/unwrapped_model"
    # trained_checkpoint_path = args.ckpt_path

    if rank == 0:
        logger.info(f"Loading trained model from: {trained_checkpoint_path}")

    model = OMadaModelLM.from_pretrained(
        trained_checkpoint_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json"
    ).to(device)
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if rank == 0:
        logger.info("✅ Trained model loaded and wrapped with DDP successfully!")

    text_vocab_size = len(uni_prompting.text_tokenizer)
    image_vocab_size = config.model.omada.codebook_size

    # --- Setup DataLoader ---
    hf_dataset = get_emova_dataset(logger)
    # ### MODIFIED ###: Pass only necessary arguments to the dataset class
    eval_dataset = EMOVAAsrEvalDataset(hf_dataset, text_vocab_size, image_vocab_size)
    sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    collate_for_eval = partial(
        evaluation_collate_fn,
        text_tokenizer=text_tokenizer,
        uni_prompting=uni_prompting,
        config=config
    )
    dataloader = DataLoader(
        eval_dataset,
        batch_size=16,
        sampler=sampler,
        num_workers=0,
        collate_fn=collate_for_eval,
        pin_memory=True
    )
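    # Note: with the default drop_last=False, DistributedSampler pads each rank's shard by
    # repeating samples so that every rank sees the same number of batches; a few examples
    # can therefore be counted twice in the gathered results and nudge the final WER slightly.
    # num_workers=0 runs collation in the main process rather than in worker subprocesses.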
    # --- Evaluation Loop ---
    local_results = []
    model.eval()

    progress_bar = tqdm(dataloader, desc="Evaluating on EMOVA ASR", disable=(rank != 0))
    for batch in progress_bar:
        if batch is None:
            continue

        input_ids = batch["input_ids"].to(device)
        gt_texts = batch["gt_texts"]
        sample_ids = batch["sample_ids"]

        with torch.no_grad():
            output_ids = model.module.mmu_generate(input_ids,
                                                   max_new_tokens=args.new_tok,
                                                   steps=args.generation_step,
                                                   block_length=args.block_length,
                                                   remasking=args.remasking)

        # Decode only the newly generated tokens (everything after the prompt)
        decoded_texts = text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)

        for i in range(len(decoded_texts)):
            local_results.append({
                "sample_id": sample_ids[i],
                "gt_text": gt_texts[i],
                "decoded_text": decoded_texts[i]
            })
            # Periodically log one example prediction on rank 0
            if rank == 0 and i == 0 and len(local_results) % 10 == 1:
                logger.info("\n--- Example ---")
                logger.info(f" ID: {sample_ids[i]}")
                logger.info(f" GT: {gt_texts[i]}")
                logger.info(f" PD: {decoded_texts[i]}")
                logger.info("-----------------\n")

    # --- Gather Results from All GPUs ---
    all_results = [None] * world_size
    dist.all_gather_object(all_results, local_results)

    # --- Final Processing and Logging (only on rank 0) ---
    if rank == 0:
        logger.info("Gathering and processing results from all GPUs...")
        final_results = [item for sublist in all_results for item in sublist]

        groundtruth_text_list = [data_utils.normalizer(res["gt_text"]) for res in final_results]
        recognized_text_list = [data_utils.normalizer(res["decoded_text"]) for res in final_results]

        results_table = wandb.Table(columns=["ID", "Ground Truth", "Response"])
        for res in final_results:
            results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"])
        wandb.log({"Speech-to-Text Response Examples": results_table})

        wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list)
        logger.info(f"Final WER (EMOVA test): {wer:.4f} | Word Errors: {int(errors)} | Total Words: {int(words)}")
        wandb.log({
            "WER": wer,
            "Total Word Errors": errors,
            "Total Words": words
        })

    # --- Cleanup ---
    if rank == 0:
        wandb.finish()
    cleanup_distributed()


if __name__ == '__main__':
    main()
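
# Example launch (hypothetical script filename; adjust to this file's actual path).
# torchrun populates the RANK and WORLD_SIZE environment variables that main() reads; how
# the OmegaConf config is supplied depends on training.utils.get_config and is omitted here.
#
#   torchrun --nproc_per_node=4 eval_emova_asr_ddp.py \
#       --train_step 50000 --remasking random \
#       --generation_step 512 --new_tok 128 --block_length 64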