# coding=utf-8
# Copyright 2025 AIDAS Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import editdistance
from functools import partial
import re

from normalizer import data_utils

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from datasets import load_dataset

from models import OMadaModelLM
from training.data import S2T_INSTRUCTION
from training.prompting_utils import UniversalPrompting
from training.utils import get_config, flatten_omega_conf
from transformers import AutoTokenizer
import argparse
import logging

def setup_logger(rank):
    """Creates a per-rank logger; only rank 0 logs at INFO level, other ranks at WARNING."""
    logger = logging.getLogger(__name__)
    if logger.hasHandlers():
        logger.handlers.clear()
    formatter = logging.Formatter(
        f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.WARNING)
    return logger

def calculate_WER(recognized_text_list, groundtruth_text_list):
    """Calculates the Word Error Rate (WER) between predicted and ground truth texts."""
    word_num = 0.0
    scores = 0.0
    for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
        recognized_text = recognized_text.lower()
        groundtruth_text = groundtruth_text.lower()
        recognized_text = re.sub(r"[^\w\s']", "", recognized_text)
        groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text)
        recognized_word_list = recognized_text.split()
        groundtruth_word_list = groundtruth_text.split()
        current_word_num = len(groundtruth_word_list)
        word_num += current_word_num
        current_score = editdistance.eval(recognized_word_list, groundtruth_word_list)
        scores += current_score
    WER = scores / word_num if word_num > 0 else 0.0
    return WER, scores, word_num

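# Quick sanity check (illustrative strings, not taken from the dataset):
#   calculate_WER(["hello world"], ["hello there world"])
#   -> one word-level edit over three reference words, i.e. WER = 1/3 ≈ 0.33
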
# ### REMOVED ###: No longer need this function
# def get_vq_model_class(model_type): ...

def get_emova_dataset(logger):
    """Loads the EMOVA ASR/TTS evaluation dataset from Hugging Face."""
    logger.info("Loading EMOVA dataset (librispeech-asr-tts config)...")
    dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts", split='test')
    dataset = dataset.filter(lambda example: 'asr' in example['id'])
    logger.info(f"Dataset loaded successfully. Found {len(dataset)} ASR examples.")
    return dataset

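# Each example returned by get_emova_dataset is assumed to follow this rough layout
# (field values below are illustrative, not copied from the real dataset):
#   {
#       "id": "librispeech_asr_0001",
#       "conversations": [
#           {"from": "human", "value": "... <|speech_12|><|speech_345|> ..."},
#           {"from": "gpt",   "value": "the reference transcript"},
#       ],
#   }
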
def setup_distributed(rank, world_size):
    """Initializes the distributed process group."""
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup_distributed():
    """Cleans up the distributed process group."""
    dist.destroy_process_group()

# ### MODIFIED ###: Dataset class now parses speech tokens from string
class EMOVAAsrEvalDataset(Dataset):
    """Wraps the EMOVA ASR split and converts speech-token strings into shifted token ids."""

    def __init__(self, hf_dataset, text_vocab_size, image_vocab_size):
        self.hf_dataset = hf_dataset
        self.text_vocab_size = text_vocab_size
        self.image_vocab_size = image_vocab_size
        # Pre-compile the regex for efficiency
        self.speech_token_pattern = re.compile(r'<\|speech_(\d+)\|>')

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        example = self.hf_dataset[idx]
        # Ground truth text is from the 'gpt' turn
        gt_text = example['conversations'][-1]['value']
        sample_id = example['id']
        # Audio tokens are in the 'human' turn as a string
        audio_token_string = example['conversations'][0]['value']
        # Parse the string to extract integer token IDs
        speech_token_ids_str = self.speech_token_pattern.findall(audio_token_string)
        if not speech_token_ids_str:
            return None  # Handle cases with no speech tokens
        speech_token_ids = torch.tensor([int(s) for s in speech_token_ids_str], dtype=torch.long)
        # Shift audio token IDs to the correct range for the multimodal model's vocabulary
        speech_token_ids += self.text_vocab_size + self.image_vocab_size
        return {
            # Unsqueeze to add a batch dimension (consistent with original vq_model.encode output)
            "speech_token_ids": speech_token_ids.unsqueeze(0),
            "gt_text": gt_text,
            "sample_id": sample_id
        }

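# Illustrative mapping performed by __getitem__ (token values are made up):
#   "<|speech_12|><|speech_7|>" -> [12, 7]
#   -> shifted into the shared multimodal vocabulary as
#      [12 + text_vocab_size + image_vocab_size, 7 + text_vocab_size + image_vocab_size]
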
def evaluation_collate_fn(batch, text_tokenizer, uni_prompting, config):
    """Builds left-padded input_ids of the form [pad..., <|s2t|>, <|soa|>, audio tokens, <|eoa|>, text prompt]."""
    batch = [b for b in batch if b is not None]
    if not batch:
        return None
    max_text_len = config.dataset.preprocessing.max_seq_length
    max_audio_len = config.dataset.preprocessing.max_aud_length + 1
    # Hard-coded, model-specific pad token id (also reused for text padding below)
    audio_pad_id = 126093
    sptids_dict = uni_prompting.sptids_dict
    batched_input_ids = []
    gt_texts = [item["gt_text"] for item in batch]
    sample_ids = [item["sample_id"] for item in batch]
    for item in batch:
        current_audio_tokens = item["speech_token_ids"]
        task_tensor = sptids_dict['<|s2t|>'].to('cpu').unsqueeze(0)
        soa_tensor = sptids_dict['<|soa|>'].to('cpu').unsqueeze(0)
        eoa_tensor = sptids_dict['<|eoa|>'].to('cpu').unsqueeze(0)
        # Reserve room for the <|s2t|>, <|soa|> and <|eoa|> special tokens
        effective_max_audio = max_audio_len - 3
        if current_audio_tokens.shape[1] > effective_max_audio:
            current_audio_tokens = current_audio_tokens[:, :effective_max_audio]
        audio_block = torch.cat([task_tensor, soa_tensor, current_audio_tokens, eoa_tensor], dim=1)
        num_padding = max_audio_len - audio_block.shape[1]
        if num_padding > 0:
            padding_tensor = torch.full((1, num_padding), audio_pad_id, dtype=torch.long)
            padded_audio_block = torch.cat([padding_tensor, audio_block], dim=1)
        else:
            padded_audio_block = audio_block
        chosen_prompt = random.choice(S2T_INSTRUCTION)
        prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{chosen_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        prompt_encoding = text_tokenizer(
            prompt_text,
            max_length=max_text_len,
            truncation=True,
            return_tensors="pt"
        )
        prompt_tensor = prompt_encoding.input_ids
        final_sequence = torch.cat([padded_audio_block, prompt_tensor], dim=1)
        batched_input_ids.append(final_sequence.squeeze(0))
    pad_token_id = 126093
    max_len = max(seq.size(0) for seq in batched_input_ids)
    final_batch = torch.full((len(batched_input_ids), max_len),
                             pad_token_id,
                             dtype=torch.long)
    for i, seq in enumerate(batched_input_ids):
        # Left-pad so every sequence ends at the same position for generation
        final_batch[i, -len(seq):] = seq
    return {
        "input_ids": final_batch,
        "gt_texts": gt_texts,
        "sample_ids": sample_ids
    }

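# Illustrative collate output for a batch of two samples (shapes only, values made up):
#   input_ids:  LongTensor of shape (2, L), where L is the longest
#               [padded audio block + prompt] sequence in the batch, left-padded with 126093
#   gt_texts:   ["first reference transcript", "second reference transcript"]
#   sample_ids: ["..._asr_0", "..._asr_1"]
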
def main():
    """Main function to run the distributed evaluation."""
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    setup_distributed(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    logger = setup_logger(rank)
    parser = argparse.ArgumentParser(description="Run DDP evaluation for OMadaModelLM on the EMOVA ASR dataset.")
    parser.add_argument('--train_step', type=int, required=True, help='Training step of the checkpoint to evaluate.')
    parser.add_argument('--remasking', type=str, default='random', help='Remasking strategy passed to mmu_generate.')
    parser.add_argument('--generation_step', type=int, default=512, help='Number of generation steps passed to mmu_generate.')
    parser.add_argument('--new_tok', type=int, default=128, help='Maximum number of new tokens to generate.')
    parser.add_argument('--block_length', type=int, default=64, help='Block length used for block-wise generation.')
    # parser.add_argument('--ckpt_path', type=str, required=True, help='Path to the checkpoint to evaluate.')
    args, unknown = parser.parse_known_args()
    config = get_config()
    if rank == 0:
        run_id = config.wandb.get("run_id", None) or wandb.util.generate_id()
        config.wandb.run_id = run_id
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb.init(
            project="merging_grid",
            name=f'{config.experiment.name}-STEP-{args.train_step}-Remasking-{args.remasking}-GS-{args.generation_step}-NT-{args.new_tok}',
            config=wandb_config,
        )
    text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left")
    uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                       special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                       ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True)
    # ### REMOVED ###: VQ Model is not needed anymore
    # vq_model_class = get_vq_model_class(config.model.vq_model_audio.type)
    # vq_model = vq_model_class.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device)
    # vq_model.requires_grad_(False)
    # vq_model.eval()
    train_step = args.train_step
    # trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model/"
    trained_checkpoint_path = "/home/work/AIDAS/ckpts/omada/omada-training-stage1_2nd/checkpoint-50000/unwrapped_model"
    # trained_checkpoint_path = args.ckpt_path
    if rank == 0:
        logger.info(f"Loading trained model from: {trained_checkpoint_path}")
    model = OMadaModelLM.from_pretrained(
        trained_checkpoint_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json"
    ).to(device)
    print("BEFORE DDP")
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    print("AFTER DDP")
    if rank == 0:
        logger.info("✅ Trained model loaded and wrapped with DDP successfully!")
    text_vocab_size = len(uni_prompting.text_tokenizer)
    image_vocab_size = config.model.omada.codebook_size
    # --- Setup DataLoader ---
    hf_dataset = get_emova_dataset(logger)
    # ### MODIFIED ###: Pass only necessary arguments to the dataset class
    eval_dataset = EMOVAAsrEvalDataset(hf_dataset, text_vocab_size, image_vocab_size)
    sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    collate_for_eval = partial(
        evaluation_collate_fn,
        text_tokenizer=text_tokenizer,
        uni_prompting=uni_prompting,
        config=config
    )
    dataloader = DataLoader(
        eval_dataset,
        batch_size=16,
        sampler=sampler,
        num_workers=0,
        collate_fn=collate_for_eval,
        pin_memory=True
    )
    # --- Evaluation Loop ---
    local_results = []
    model.eval()
    progress_bar = tqdm(dataloader, desc="Evaluating on EMOVA ASR", disable=(rank != 0))
    for batch in progress_bar:
        if batch is None:
            continue
        input_ids = batch["input_ids"].to(device)
        gt_texts = batch["gt_texts"]
        sample_ids = batch["sample_ids"]
        with torch.no_grad():
            output_ids = model.module.mmu_generate(input_ids, max_new_tokens=args.new_tok, steps=args.generation_step, block_length=args.block_length, remasking=args.remasking)
        # Decode only the newly generated tokens (everything after the prompt)
        decoded_texts = text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)
        for i in range(len(decoded_texts)):
            local_results.append({
                "sample_id": sample_ids[i],
                "gt_text": gt_texts[i],
                "decoded_text": decoded_texts[i]
            })
            if rank == 0 and i == 0 and len(local_results) % 10 == 1:
                logger.info("\n--- Example ---")
                logger.info(f" ID: {sample_ids[i]}")
                logger.info(f" GT: {gt_texts[i]}")
                logger.info(f" PD: {decoded_texts[i]}")
                logger.info("-----------------\n")
    # --- Gather Results from All GPUs ---
    all_results = [None] * world_size
    dist.all_gather_object(all_results, local_results)
    # --- Final Processing and Logging (only on rank 0) ---
    if rank == 0:
        logger.info("Gathering and processing results from all GPUs...")
        final_results = [item for sublist in all_results for item in sublist]
        groundtruth_text_list = [data_utils.normalizer(res["gt_text"]) for res in final_results]
        recognized_text_list = [data_utils.normalizer(res["decoded_text"]) for res in final_results]
        results_table = wandb.Table(columns=["ID", "Ground Truth", "Response"])
        for res in final_results:
            results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"])
        wandb.log({"Speech-to-Text Response Examples": results_table})
        wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list)
        logger.info(f"Final WER (EMOVA test): {wer:.4f} | Word Errors: {int(errors)} | Total Words: {int(words)}")
        wandb.log({
            "WER": wer,
            "Total Word Errors": errors,
            "Total Words": words
        })
    # --- Cleanup ---
    if rank == 0:
        wandb.finish()
    cleanup_distributed()


if __name__ == '__main__':
    main()
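
# Example launch (illustrative; the script name, GPU count, and any extra config
# arguments consumed by get_config() depend on your setup, and RANK/WORLD_SIZE are
# expected to be set by the launcher, e.g. torchrun):
#   torchrun --nproc_per_node=4 eval_emova_asr.py \
#       --train_step 50000 --remasking random \
#       --generation_step 512 --new_tok 128 --block_length 64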