# coding=utf-8
# Copyright 2025 AIDAS Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import random
import re
from functools import partial

os.environ["TOKENIZERS_PARALLELISM"] = "true"

import editdistance
import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from normalizer import data_utils
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import AutoTokenizer

from models import MMadaModelLM
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.data import S2T_INSTRUCTION
from training.prompting_utils import UniversalPrompting
from training.utils import get_config, flatten_omega_conf
| os.environ["TOKENIZERS_PARALLELISM"] = "true" | |
| from tqdm import tqdm | |
| import torch | |
| import torch.distributed as dist | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| import wandb | |
| from datasets import load_dataset | |
| from models import MMadaModelLM | |
| from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer | |
| from training.data import S2T_INSTRUCTION | |
| from training.prompting_utils import UniversalPrompting | |
| from training.utils import get_config, flatten_omega_conf | |
| from transformers import AutoTokenizer | |

def setup_logger(rank):
    """Creates a per-rank logger; only rank 0 logs at INFO level."""
    logger = logging.getLogger(__name__)
    # Prevent duplicate handlers when the logger is configured more than once.
    if logger.hasHandlers():
        logger.handlers.clear()
    formatter = logging.Formatter(
        f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.WARNING)
    return logger

def calculate_WER(recognized_text_list, groundtruth_text_list):
    """Calculates the Word Error Rate (WER) between predicted and ground truth texts."""
    word_num = 0.0
    scores = 0.0
    for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
        recognized_text = recognized_text.lower()
        groundtruth_text = groundtruth_text.lower()
        recognized_text = re.sub(r"[^\w\s']", "", recognized_text)
        groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text)
        recognized_word_list = recognized_text.split()
        groundtruth_word_list = groundtruth_text.split()
        current_word_num = len(groundtruth_word_list)
        word_num += current_word_num
        current_score = editdistance.eval(recognized_word_list, groundtruth_word_list)
        scores += current_score
    WER = scores / word_num if word_num > 0 else 0.0
    return WER, scores, word_num
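
# WER is the word-level edit distance (substitutions + insertions + deletions) divided by the
# number of reference words, so the reference transcripts must be passed as the second argument.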

def get_vq_model_class(model_type):
    """Returns the speech tokenizer model class based on the model type."""
    if model_type == "magvitv2":
        raise NotImplementedError("MAGVITv2 is not implemented in this script.")
    elif model_type == "emova":
        return EMOVASpeechTokenizer
    else:
        raise ValueError(f"model_type {model_type} not supported.")

def get_librispeech_dataset(logger):
    """Loads the LibriSpeech test-clean ASR evaluation set from the EMOVA Hugging Face repo."""
    logger.info("Loading EMOVA ASR/TTS evaluation dataset (librispeech-asr-tts, test split)...")
    dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts")['test']
    logger.info("Dataset loaded successfully.")
    return dataset

def form_ann_rst_list(ann, results, key):
    """Builds id -> text dicts for annotations and results whose ids contain the given key."""
    ann_dict = {}
    for item in ann:
        if key in item['id']:
            ann_dict[item['id']] = item['conversations'][-1]['value']
    rst_dict = {}
    for item in results:
        if key in item['id']:
            rst_dict[item['id']] = item['text']
    return ann_dict, rst_dict

# --- DDP Setup and Cleanup Functions ---
def setup_distributed(rank, world_size):
    """Initializes the distributed process group."""
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup_distributed():
    """Cleans up the distributed process group."""
    dist.destroy_process_group()
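
# init_process_group defaults to the env:// rendezvous, so MASTER_ADDR and MASTER_PORT must be
# provided by the launcher (e.g. torchrun); rank and world size are passed explicitly here.
# The "gloo" backend supports the all_gather_object call used below, though "nccl" is the usual
# choice for GPU tensor collectives.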

# --- Custom Dataset and Collate Function ---
class LibrispeechEvalDataset(Dataset):
    def __init__(self, hf_dataset, root_path, vq_model, text_vocab_size, image_vocab_size):
        self.hf_dataset = hf_dataset
        self.root_path = root_path
        self.vq_model = vq_model
        self.text_vocab_size = text_vocab_size
        self.image_vocab_size = image_vocab_size

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        example = self.hf_dataset[idx]
        gt_text = example['text']
        sample_id = example['id']
        speaker_id, chapter_id, _ = sample_id.split('-')
        audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac")
        if not os.path.exists(audio_path):
            return None
        speech_token_ids = self.vq_model.encode(audio_path)
        speech_token_ids += self.text_vocab_size + self.image_vocab_size
        return {
            "speech_token_ids": speech_token_ids,
            "gt_text": gt_text,
            "sample_id": sample_id,
        }
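
# Speech token ids from the EMOVA tokenizer are offset by text_vocab_size + image_vocab_size so
# that text, image, and audio tokens occupy disjoint id ranges in the model's shared vocabulary.
# Samples whose audio file is missing return None and are filtered out by the collate function.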

def evaluation_collate_fn(batch, text_tokenizer, uni_prompting, config):
    batch = [b for b in batch if b is not None]
    if not batch:
        return None
    device = batch[0]["speech_token_ids"].device
    max_text_len = config.dataset.preprocessing.max_seq_length
    max_audio_len = config.dataset.preprocessing.max_aud_length + 1
    audio_pad_id = 126093
    sptids_dict = uni_prompting.sptids_dict
    batched_input_ids = []
    gt_texts = [item["gt_text"] for item in batch]
    sample_ids = [item["sample_id"] for item in batch]
    for item in batch:
        current_audio_tokens = item["speech_token_ids"].to(device)
        task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0)
        soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0)
        eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0)
        effective_max_audio = max_audio_len - 3
        if current_audio_tokens.shape[1] > effective_max_audio:
            current_audio_tokens = current_audio_tokens[:, :effective_max_audio]
        audio_block = torch.cat([task_tensor, soa_tensor, current_audio_tokens, eoa_tensor], dim=1)
        num_padding = max_audio_len - audio_block.shape[1]
        if num_padding > 0:
            padding_tensor = torch.full((1, num_padding), audio_pad_id, dtype=torch.long, device=device)
            padded_audio_block = torch.cat([padding_tensor, audio_block], dim=1)
        else:
            padded_audio_block = audio_block
        chosen_prompt = random.choice(S2T_INSTRUCTION)
        prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{chosen_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        prompt_encoding = text_tokenizer(
            prompt_text,
            max_length=max_text_len,
            truncation=True,
            return_tensors="pt",
        )
        prompt_tensor = prompt_encoding.input_ids.to(device)
        final_sequence = torch.cat([padded_audio_block, prompt_tensor], dim=1)
        batched_input_ids.append(final_sequence.squeeze(0))
    pad_token_id = 126093
    max_len = max(seq.size(0) for seq in batched_input_ids)
    final_batch = torch.full((len(batched_input_ids), max_len),
                             pad_token_id,
                             dtype=torch.long,
                             device=device)
    for i, seq in enumerate(batched_input_ids):
        final_batch[i, -len(seq):] = seq
    return {
        "input_ids": final_batch,
        "gt_texts": gt_texts,
        "sample_ids": sample_ids,
    }
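
# Each collated row is laid out as:
#   [audio pad ...] <|s2t|> <|soa|> speech tokens <|eoa|> + chat-formatted instruction prompt
# Rows are then left-padded to a common length, so every sequence ends with the assistant header
# and generated tokens are appended directly after it.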

def main():
    """Main function to run the distributed evaluation."""
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    setup_distributed(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    logger = setup_logger(rank)
    parser = argparse.ArgumentParser(description="Run DDP evaluation for MMadaModelLM.")
    parser.add_argument('--train_step', type=int, required=True, help='The training step of the checkpoint to evaluate.')
    parser.add_argument('--remasking', type=str, default='random', help='Remasking strategy used during generation.')
    parser.add_argument('--generation_step', type=int, default=512, help='Number of generation steps passed to mmu_generate.')
    parser.add_argument('--new_tok', type=int, default=256, help='Maximum number of new tokens to generate (also used as block_length).')
    args, unknown = parser.parse_known_args()
    config = get_config()
    if rank == 0:
        run_id = config.wandb.get("run_id", None) or wandb.util.generate_id()
        config.wandb.run_id = run_id
        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
        wandb.init(
            project="librispeech_test-clean",
            name=config.experiment.name + f'_STEP-{args.train_step}_Remasking-{args.remasking}_GS-{args.generation_step}_NT-{args.new_tok}',
            config=wandb_config,
        )
    text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left")
    uni_prompting = UniversalPrompting(
        text_tokenizer,
        max_text_len=config.dataset.preprocessing.max_seq_length,
        special_tokens=(
            "<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>",
            "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
        ),
        ignore_id=-100,
        cond_dropout_prob=config.training.cond_dropout_prob,
        use_reserved_token=True,
    )
    vq_model_class = get_vq_model_class(config.model.vq_model_audio.type)
    vq_model = vq_model_class.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device)
    vq_model.requires_grad_(False)
    vq_model.eval()
    train_step = args.train_step
    trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model/"
    if rank == 0:
        logger.info(f"Loading trained model from: {trained_checkpoint_path}")
    model = MMadaModelLM.from_pretrained(
        trained_checkpoint_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json",
    ).to(device)
    model = DDP(model, device_ids=[rank])
    if rank == 0:
        logger.info("✅ Trained model loaded and wrapped with DDP successfully!")
    text_vocab_size = len(uni_prompting.text_tokenizer)
    image_vocab_size = config.model.omada.codebook_size
    # --- Setup DataLoader ---
    hf_dataset = get_librispeech_dataset(logger)
    root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean"
    eval_dataset = LibrispeechEvalDataset(hf_dataset, root_path, vq_model, text_vocab_size, image_vocab_size)
    sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    collate_for_eval = partial(
        evaluation_collate_fn,
        text_tokenizer=text_tokenizer,
        uni_prompting=uni_prompting,
        config=config,
    )
    dataloader = DataLoader(
        eval_dataset,
        batch_size=16,
        sampler=sampler,
        num_workers=0,
        collate_fn=collate_for_eval,
        pin_memory=True,
    )
    # --- Evaluation Loop ---
    local_results = []
    model.eval()
    progress_bar = tqdm(dataloader, desc="Evaluating on Librispeech", disable=(rank != 0))
    for batch_idx, batch in enumerate(progress_bar):
        # if batch_idx > 1:
        #     break
        if batch is None:
            continue
        input_ids = batch["input_ids"].to(device)
        gt_texts = batch["gt_texts"]
        sample_ids = batch["sample_ids"]
        with torch.no_grad():
            output_ids = model.module.mmu_generate(
                input_ids,
                max_new_tokens=args.new_tok,
                steps=args.generation_step,
                block_length=args.new_tok,
                remasking=args.remasking,
            )
        decoded_texts = text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)
        # print(decoded_texts)
        for i in range(len(decoded_texts)):
            local_results.append({
                "sample_id": sample_ids[i],
                "gt_text": gt_texts[i],
                "decoded_text": decoded_texts[i],
            })
            if rank == 0 and i == 0:
                logger.info(f" ID: {sample_ids[i]}")
                logger.info(f" GT: {data_utils.normalizer(gt_texts[i])}")
                logger.info(f" PD: {data_utils.normalizer(decoded_texts[i])}")
    # --- Gather Results from All GPUs ---
    all_results = [None] * world_size
    dist.all_gather_object(all_results, local_results)
    # --- Final Processing and Logging (only on rank 0) ---
    if rank == 0:
        logger.info("Gathering and processing results from all GPUs...")
        final_results = [item for sublist in all_results for item in sublist]
        groundtruth_text_list = [data_utils.normalizer(res["gt_text"]) for res in final_results]
        recognized_text_list = [data_utils.normalizer(res["decoded_text"]) for res in final_results]
        results_table = wandb.Table(columns=["ID", "Ground Truth", "Response"])
        for res in final_results:
            results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"])
        wandb.log({"Speech-to-Text Response Examples": results_table})
        # Hypotheses first, references second, matching calculate_WER's signature so the
        # denominator is the number of ground-truth words.
        wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list)
        logger.info(f"Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}")
        wandb.log({
            "WER": wer,
            "Total Word Errors": errors,
            "Total Words": words,
        })
    # --- Cleanup ---
    if rank == 0:
        wandb.finish()
    cleanup_distributed()


if __name__ == '__main__':
    main()
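
# Example launch (illustrative sketch: the script file name is an assumption, and the trailing
# config=... argument assumes get_config() parses OmegaConf-style overrides from the remaining
# command-line arguments left over by parse_known_args):
#   torchrun --nproc_per_node=4 eval_librispeech_s2t.py --train_step 10000 --remasking random \
#       --generation_step 512 --new_tok 256 config=path/to/experiment.yaml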