# coding=utf-8
# Copyright 2025 AIDAS Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
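"""Distributed evaluation of EMOVA speech-to-text (ASR) on Librispeech test.clean.

Example launch, single node (the script name below is a placeholder; use the
actual filename and set --nproc_per_node to the number of available GPUs):

    torchrun --nproc_per_node=4 eval_emova_librispeech.py

torchrun populates RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, which
main() reads from the environment.
"""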
import os
import re
import logging
import editdistance
from functools import partial
os.environ["TOKENIZERS_PARALLETISM"] = "true"
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from datasets import load_dataset
from transformers import AutoModel, AutoProcessor
# --- Helper Functions ---
def setup_logger(rank):
"""Sets up a logger for each DDP process."""
logger = logging.getLogger(__name__)
if logger.hasHandlers():
logger.handlers.clear()
formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
return logger
def calculate_WER(recognized_text_list, groundtruth_text_list):
"""Calculates the Word Error Rate (WER) between predicted and ground truth texts."""
word_num, scores = 0.0, 0.0
for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):
recognized_text = re.sub(r"[^\w\s']", "", recognized_text.lower())
groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text.lower())
recognized_word_list = recognized_text.split()
groundtruth_word_list = groundtruth_text.split()
current_word_num = len(groundtruth_word_list)
word_num += current_word_num
scores += editdistance.eval(recognized_word_list, groundtruth_word_list)
WER = scores / word_num if word_num > 0 else 0.0
return WER, scores, word_num
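# Quick sanity check for calculate_WER (illustrative strings, not dataset data):
# predictions ["hello wrld"] vs. references ["hello world"] contain one word
# substitution out of two reference words, so the call returns (0.5, 1.0, 2.0).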
def get_librispeech_dataset(logger, split="test.clean"):
"""Loads the Librispeech ASR dataset from Hugging Face."""
logger.info(f"Loading librispeech_asr dataset ({split})...")
dataset = load_dataset("librispeech_asr", split=split, trust_remote_code=True)
logger.info("Dataset loaded successfully.")
return dataset
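# Each librispeech_asr example exposes, among other fields, the three used below:
#   'file' - path to the utterance's .flac audio file
#   'text' - uppercase reference transcript
#   'id'   - unique utterance id in speaker-chapter-utterance form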
def setup_distributed(rank, world_size):
    """Initializes the distributed process group and binds this rank to its GPU."""
    torch.cuda.set_device(rank)  # ensure NCCL collectives target the right device
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup_distributed():
"""Cleans up the distributed process group."""
dist.destroy_process_group()
# --- Custom Dataset and Collate Function for EMOVA ---
class LibrispeechAudioDataset(Dataset):
"""A simple dataset that returns audio file path and ground truth text."""
def __init__(self, hf_dataset):
self.hf_dataset = hf_dataset
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, idx):
example = self.hf_dataset[idx]
return {
"audio_path": example['file'],
"gt_text": example['text'],
"sample_id": example['id']
}
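# The dataset deliberately returns file paths rather than decoded waveforms;
# the collate function below forwards them to the EMOVA processor, which is
# expected to handle audio loading and resampling itself.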
class EmovaS2TCollateFn:
"""
Collate function to prepare batches for the EMOVA model using its processor.
"""
def __init__(self, processor):
self.processor = processor
self.prompt_text = "Transcribe the given audio."
def __call__(self, batch):
audio_paths = [item["audio_path"] for item in batch]
gt_texts = [item["gt_text"] for item in batch]
sample_ids = [item["sample_id"] for item in batch]
# Construct the text input for each audio file in the batch
text_inputs = [
[
{"role": "user", "content": [{"type": "audio"}, {"type": "text", "text": self.prompt_text}]}
]
for _ in audio_paths
]
# Use the EMOVA processor to prepare the multimodal batch
inputs = self.processor(
text=text_inputs,
audios=audio_paths,
return_tensors="pt",
padding=True
)
inputs['gt_texts'] = gt_texts
inputs['sample_ids'] = sample_ids
return inputs
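# Illustrative single-item usage of the collate function (hypothetical path and
# id; the exact set of returned keys depends on the EMOVA processor version,
# but 'input_ids' and 'attention_mask' are standard):
#   collate = EmovaS2TCollateFn(processor)
#   inputs = collate([{"audio_path": "/path/to/utt.flac",
#                      "gt_text": "HELLO WORLD", "sample_id": "0-0-0"}])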
def main():
"""Main function to run the distributed evaluation."""
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
setup_distributed(rank, world_size)
device = torch.device(f"cuda:{rank}")
logger = setup_logger(rank)
if rank == 0:
wandb.init(project="emova-librispeech-eval")
# --- 1. Load EMOVA Models and Processors ---
logger.info("Loading EMOVA models and processors...")
model_name = "Emova-ollm/emova-qwen-2-5-7b-hf"
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2',
low_cpu_mem_usage=True,
trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
speech_tokenizer = AutoModel.from_pretrained(
"Emova-ollm/emova_speech_tokenizer_hf",
torch_dtype=torch.float32,
trust_remote_code=True
).to(device).eval()
processor.set_speech_tokenizer(speech_tokenizer)
# Wrap the main model with DDP
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
logger.info("✅ Models loaded and wrapped with DDP successfully!")
# --- 2. Setup DataLoader ---
hf_dataset = get_librispeech_dataset(logger, split="test.clean")
eval_dataset = LibrispeechAudioDataset(hf_dataset)
sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False)
collate_fn = EmovaS2TCollateFn(processor)
dataloader = DataLoader(
eval_dataset,
batch_size=4, # Adjust batch size based on your GPU memory
sampler=sampler,
num_workers=4,
collate_fn=collate_fn,
pin_memory=True
)
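    # Note: the effective batch size is batch_size * world_size, and
    # DistributedSampler pads every rank to equal length by repeating samples;
    # those duplicates are dropped by sample_id when results are gathered below.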
# --- 3. Evaluation Loop ---
local_results = []
model.eval()
progress_bar = tqdm(dataloader, desc="Evaluating on Librispeech", disable=(rank != 0))
for batch in progress_bar:
gt_texts = batch.pop("gt_texts")
sample_ids = batch.pop("sample_ids")
# Move batch tensors to the correct device
inputs = {k: v.to(device) for k, v in batch.items()}
with torch.no_grad():
outputs = model.module.generate(**inputs, max_new_tokens=256, do_sample=False)
# Slice to get only the generated tokens
generated_ids = outputs[:, inputs['input_ids'].shape[1]:]
decoded_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
for i in range(len(decoded_texts)):
local_results.append({
"sample_id": sample_ids[i],
"gt_text": gt_texts[i],
"decoded_text": decoded_texts[i].strip()
})
            if rank == 0 and i == 0 and (len(local_results) - 1) % (10 * dataloader.batch_size) == 0:
                # Log the first sample of every 10th batch on rank 0
                logger.info("\n--- Sample ---")
                logger.info(f"  ID: {sample_ids[i]}")
                logger.info(f"  GT: {gt_texts[i]}")
                logger.info(f"  PD: {decoded_texts[i].strip()}")
                logger.info("----------------")
# --- 4. Gather Results and Calculate Final Score ---
all_results = [None] * world_size
dist.all_gather_object(all_results, local_results)
if rank == 0:
logger.info("Gathering and processing results from all GPUs...")
        # Flatten, then deduplicate by sample_id: DistributedSampler pads ranks
        # to equal length by repeating samples, which would otherwise skew WER.
        merged = [item for sublist in all_results for item in sublist]
        final_results = list({res["sample_id"]: res for res in merged}.values())
gt_list = [res["gt_text"] for res in final_results]
pred_list = [res["decoded_text"] for res in final_results]
results_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"])
for res in final_results:
results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"])
wandb.log({"S2T Predictions": results_table})
wer, errors, words = calculate_WER(pred_list, gt_list)
logger.info(f"Final WER (Librispeech test.clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}")
wandb.log({"WER": wer, "Total Word Errors": errors, "Total Words": words})
# --- Cleanup ---
if rank == 0:
wandb.finish()
cleanup_distributed()
if __name__ == '__main__':
    # torchrun sets MASTER_ADDR/MASTER_PORT (and RANK/WORLD_SIZE) automatically;
    # uncomment below only when spawning processes manually.
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12355'
main()