# coding=utf-8
# Copyright 2025 AIDAS Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import logging

import editdistance

os.environ["TOKENIZERS_PARALLELISM"] = "true"
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from datasets import load_dataset
from transformers import AutoModel, AutoProcessor


# --- Helper Functions ---
def setup_logger(rank):
    """Sets up a logger for each DDP process; only rank 0 emits INFO-level messages."""
    logger = logging.getLogger(__name__)
    if logger.hasHandlers():
        logger.handlers.clear()
    formatter = logging.Formatter(
        f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
    return logger

def calculate_WER(recognized_text_list, groundtruth_text_list):
    """Calculates the Word Error Rate (WER) between predicted and ground truth texts."""
    word_num, scores = 0.0, 0.0
    for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
        # Lowercase and strip punctuation (keeping apostrophes) before comparing.
        recognized_text = re.sub(r"[^\w\s']", "", recognized_text.lower())
        groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text.lower())
        recognized_word_list = recognized_text.split()
        groundtruth_word_list = groundtruth_text.split()
        word_num += len(groundtruth_word_list)
        scores += editdistance.eval(recognized_word_list, groundtruth_word_list)
    WER = scores / word_num if word_num > 0 else 0.0
    return WER, scores, word_num
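
# Quick sanity check, worked by hand:
#   calculate_WER(["hello world"], ["hello there world"])
# The reference has 3 words and the hypothesis needs 1 edit (insert "there"),
# so this returns (0.333..., 1.0, 3.0).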

def get_librispeech_dataset(logger, split="test.clean"):
    """Loads the LibriSpeech ASR dataset from Hugging Face."""
    logger.info(f"Loading librispeech_asr dataset ({split})...")
    dataset = load_dataset("librispeech_asr", split=split, trust_remote_code=True)
    logger.info("Dataset loaded successfully.")
    return dataset

def setup_distributed(rank, world_size):
    """Initializes the NCCL distributed process group."""
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup_distributed():
    """Cleans up the distributed process group."""
    dist.destroy_process_group()

# --- Custom Dataset and Collate Function for EMOVA ---
class LibrispeechAudioDataset(Dataset):
    """A simple dataset that returns the audio file path and ground truth text."""
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        example = self.hf_dataset[idx]
        return {
            "audio_path": example['file'],
            "gt_text": example['text'],
            "sample_id": example['id'],
        }
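
# Note: the Hugging Face librispeech_asr split exposes, among others, the
# fields 'file' (path to the cached audio), 'text' (uppercase transcript),
# and 'id'. If your datasets version has dropped 'file', switch to the
# 'audio' field and pass its 'path' entry instead.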

class EmovaS2TCollateFn:
    """Collate function that prepares batches for the EMOVA model using its processor."""
    def __init__(self, processor):
        self.processor = processor
        self.prompt_text = "Transcribe the given audio."

    def __call__(self, batch):
        audio_paths = [item["audio_path"] for item in batch]
        gt_texts = [item["gt_text"] for item in batch]
        sample_ids = [item["sample_id"] for item in batch]
        # Construct the chat-format text input for each audio file in the batch
        text_inputs = [
            [
                {"role": "user", "content": [{"type": "audio"}, {"type": "text", "text": self.prompt_text}]}
            ]
            for _ in audio_paths
        ]
        # Use the EMOVA processor to prepare the multimodal batch
        inputs = self.processor(
            text=text_inputs,
            audios=audio_paths,
            return_tensors="pt",
            padding=True
        )
        inputs['gt_texts'] = gt_texts
        inputs['sample_ids'] = sample_ids
        return inputs
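
# NOTE (assumption): the EMOVA processor loaded with trust_remote_code=True is
# expected to accept chat-format `text` plus raw audio file paths via `audios`,
# applying its chat template and audio feature extraction internally. If your
# processor revision differs, check its remote code for the expected call
# signature before batching.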

def main():
    """Main function to run the distributed evaluation."""
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    # LOCAL_RANK (set by torchrun) indexes GPUs on this node; RANK is global,
    # so using it as a CUDA index only works on a single node.
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    setup_distributed(rank, world_size)
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    logger = setup_logger(rank)
    if rank == 0:
        wandb.init(project="emova-librispeech-eval")
    # --- 1. Load EMOVA Models and Processors ---
    logger.info("Loading EMOVA models and processors...")
    model_name = "Emova-ollm/emova-qwen-2-5-7b-hf"
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation='flash_attention_2',
        low_cpu_mem_usage=True,
        trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    speech_tokenizer = AutoModel.from_pretrained(
        "Emova-ollm/emova_speech_tokenizer_hf",
        torch_dtype=torch.float32,
        trust_remote_code=True
    ).to(device).eval()
    processor.set_speech_tokenizer(speech_tokenizer)
    # Wrap the main model with DDP. For this inference-only script DDP mainly
    # keeps the launch uniform across ranks; generation goes through model.module.
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
    logger.info("✅ Models loaded and wrapped with DDP successfully!")
    # --- 2. Setup DataLoader ---
    hf_dataset = get_librispeech_dataset(logger, split="test.clean")
    eval_dataset = LibrispeechAudioDataset(hf_dataset)
    sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    collate_fn = EmovaS2TCollateFn(processor)
    dataloader = DataLoader(
        eval_dataset,
        batch_size=4,  # Adjust batch size based on your GPU memory
        sampler=sampler,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )
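    # NOTE: DistributedSampler pads the dataset so every rank receives the same
    # number of samples, which can duplicate a few utterances at the shard
    # boundary. On test.clean (2620 utterances) the effect on WER is tiny, but
    # an exact score would deduplicate by sample_id after gathering.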
    # --- 3. Evaluation Loop ---
    local_results = []
    model.eval()
    progress_bar = tqdm(dataloader, desc="Evaluating on Librispeech", disable=(rank != 0))
    for batch_idx, batch in enumerate(progress_bar):
        gt_texts = batch.pop("gt_texts")
        sample_ids = batch.pop("sample_ids")
        # Move batch tensors to the correct device
        inputs = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model.module.generate(**inputs, max_new_tokens=256, do_sample=False)
        # Slice off the prompt to keep only the newly generated tokens. This assumes
        # left-padded inputs, which decoder-only models need for batched generation.
        generated_ids = outputs[:, inputs['input_ids'].shape[1]:]
        decoded_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        for i, decoded_text in enumerate(decoded_texts):
            local_results.append({
                "sample_id": sample_ids[i],
                "gt_text": gt_texts[i],
                "decoded_text": decoded_text.strip()
            })
        # Log the first sample of every 10th batch on rank 0
        if rank == 0 and batch_idx % 10 == 0:
            logger.info("\n--- Sample ---")
            logger.info(f"  ID: {sample_ids[0]}")
            logger.info(f"  GT: {gt_texts[0]}")
            logger.info(f"  PD: {decoded_texts[0].strip()}")
            logger.info("----------------")
    # --- 4. Gather Results and Calculate Final Score ---
    all_results = [None] * world_size
    dist.all_gather_object(all_results, local_results)
    if rank == 0:
        logger.info("Gathering and processing results from all GPUs...")
        final_results = [item for sublist in all_results for item in sublist]
        gt_list = [res["gt_text"] for res in final_results]
        pred_list = [res["decoded_text"] for res in final_results]
        results_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"])
        for res in final_results:
            results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"])
        wandb.log({"S2T Predictions": results_table})
        wer, errors, words = calculate_WER(pred_list, gt_list)
        logger.info(
            f"Final WER (Librispeech test.clean): {wer:.4f} | "
            f"Word Errors: {int(errors)} | Total Words: {int(words)}"
        )
        wandb.log({"WER": wer, "Total Word Errors": errors, "Total Words": words})
    # --- Cleanup ---
    if rank == 0:
        wandb.finish()
    cleanup_distributed()

if __name__ == '__main__':
    # MASTER_ADDR/MASTER_PORT are normally set by torchrun; uncomment for manual runs
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12355'
    main()
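
# Example launch (single node, 4 GPUs), assuming this file is saved as
# eval_emova_librispeech.py:
#   torchrun --standalone --nproc_per_node=4 eval_emova_librispeech.py
# torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT
# for each spawned process.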