import os
import argparse
import random
import sys

# Make the repo root importable when this script is run directly
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from functools import partial
from typing import Callable, List
import re

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import wandb
from omegaconf import OmegaConf

from training.data import S2T_INSTRUCTION
from inference.common import (
    load_train_config,
    get_vq_model_audio,
    build_uni_prompting,
    load_omada_from_checkpoint,
    list_checkpoints,
    grid_dict,
    init_wandb,
    safe_log_table,
)
# Angle-bracket special tokens, e.g. <|eoa|>
_ANGLE_TOKEN_RE = re.compile(r"<[^>]+>")
# Literal "EXCLAMATIONPOINT" markers (case-insensitive)
_EXCLAMATIONPOINT_RE = re.compile(r"EXCLAMATIONPOINT", flags=re.IGNORECASE)
# Any character that is not a word character, whitespace, or apostrophe
_PUNCT_RE = re.compile(r"[^\w\s']")
def _strip_custom_markers(text: str) -> str:
    """Strip angle-bracket tokens and EXCLAMATIONPOINT markers, then squeeze whitespace."""
    had_exclamationpoint = bool(_EXCLAMATIONPOINT_RE.search(text))
    text = _ANGLE_TOKEN_RE.sub(" ", text)
    if had_exclamationpoint:
        text = _EXCLAMATIONPOINT_RE.sub(" ", text)
        text = text.replace(".", "")
    text = _PUNCT_RE.sub(" ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
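
# Illustrative example (not from the original source):
#   _strip_custom_markers("<|eoa|> HELLO EXCLAMATIONPOINT world.")
#   -> "HELLO world"
# Angle-bracket tokens and the EXCLAMATIONPOINT marker are dropped, and
# periods are removed only when such a marker was present.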
def _basic_normalize(text: str) -> str:
    text = _strip_custom_markers(text)
    text = text.lower()
    text = re.sub(r"[^\w\s']", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
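
# Illustrative example (not from the original source):
#   _basic_normalize("Don't <unk> STOP!")  ->  "don't stop"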
def build_normalize_fn(mode: str) -> Callable[[str], str]:
    mode = (mode or "basic").strip().lower()
    if mode in {"off", "none", "no"}:
        return lambda s: s
    if mode in {"english", "whisper", "whisper_en"}:
        try:
            from normalizer.normalizer import EnglishTextNormalizer

            n = EnglishTextNormalizer()

            def _fn(s: str) -> str:
                return re.sub(r"\s+", " ", n(s)).strip()

            return _fn
        except Exception:
            # Fall back to basic normalization if the normalizer package cannot be imported
            return _basic_normalize
    # Default: basic normalization
    return _basic_normalize
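
# Usage sketch (illustrative): build_normalize_fn("whisper") returns a
# Whisper-style English normalizer when the optional `normalizer` package is
# importable, and otherwise falls back to _basic_normalize.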
def calculate_wer(predictions: List[str], references: List[str], normalize: Callable[[str], str] = _basic_normalize):
    import editdistance

    # Normalize both sides before computing WER
    predictions = [normalize(p) for p in predictions]
    references = [normalize(r) for r in references]
    total_errors = 0
    total_words = 0
    for pred, ref in zip(predictions, references):
        pred_words = pred.split()
        ref_words = ref.split()
        total_errors += editdistance.eval(pred_words, ref_words)
        total_words += len(ref_words)
    wer = total_errors / total_words if total_words > 0 else 0.0
    return wer, total_errors, total_words
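
# Worked example (illustrative):
#   calculate_wer(["the cat sat"], ["the cat sat down"])
#   -> 1 word-level edit against 4 reference words, i.e. (0.25, 1, 4)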
class S2TEvalDataset(Dataset):
    def __init__(self, hf_dataset, root_path: str):
        self.hf_dataset = hf_dataset
        self.root_path = root_path

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        ex = self.hf_dataset[idx]
        sample_id = ex["id"]
        speaker_id, chapter_id, _ = sample_id.split("-")
        audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac")
        return {"audio_path": audio_path, "gt_text": ex["text"], "sample_id": sample_id}
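
# LibriSpeech sample ids encode speaker, chapter, and utterance, so an id such
# as "1089-134686-0000" resolves to
#   <root_path>/1089/134686/1089-134686-0000.flac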
def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, cfg):
    # Encode each audio file and shift its tokens into the audio codebook range
    audio_tokens_batch = []
    offset = len(uni_prompting.text_tokenizer) + cfg.model.omada.codebook_size
    for item in batch:
        path = item['audio_path']
        tokens = vq_model_audio.encode(path)
        tokens_with_offset = tokens + offset
        audio_tokens_batch.append(tokens_with_offset)

    sptids = uni_prompting.sptids_dict
    device = audio_tokens_batch[0].device
    batched_input_ids = []
    for audio_tokens in audio_tokens_batch:
        task_tensor = sptids['<|s2t|>'].to(device).unsqueeze(0)
        soa_tensor = sptids['<|soa|>'].to(device).unsqueeze(0)
        eoa_tensor = sptids['<|eoa|>'].to(device).unsqueeze(0)
        audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1)
        prompt_text = random.choice(S2T_INSTRUCTION)
        full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device)
        final_seq = torch.cat([audio_block, prompt_tensor], dim=1)
        batched_input_ids.append(final_seq.squeeze(0))

    # Left-pad all sequences to the batch max length
    max_len = max(seq.size(0) for seq in batched_input_ids)
    pad_token_id = 126093  # hardcoded pad id for this tokenizer
    final_batch_input_ids = torch.full(
        (len(batched_input_ids), max_len),
        pad_token_id,
        dtype=torch.long,
        device=device,
    )
    for i, seq in enumerate(batched_input_ids):
        final_batch_input_ids[i, -len(seq):] = seq

    return {
        "input_ids": final_batch_input_ids,
        "gt_texts": [item['gt_text'] for item in batch],
        "sample_ids": [item['sample_id'] for item in batch],
    }
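
# Each collated row has the layout
#   [pad ...] <|s2t|> <|soa|> [offset audio tokens] <|eoa|> [chat-formatted prompt]
# Left-padding keeps every prompt flush with the right edge, so generation
# starts at the same position for all rows in the batch.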
def run_once(ckpt_path: str, hparams: dict, train_cfg, device):
    # Models and prompting
    uni_prompting, tokenizer = build_uni_prompting(train_cfg)
    vq_audio = get_vq_model_audio(train_cfg, device)
    model = load_omada_from_checkpoint(ckpt_path, device)

    # Dataset
    dcfg = hparams.get("dataset", {})
    subset = dcfg.get("subset", "clean")
    split = dcfg.get("split", "test")
    limit = int(dcfg.get("limit", 128))
    root_path = dcfg.get("root_path", "/home/work/AIDAS/data/audio/LibriSpeech/test-clean")
    ds_raw = load_dataset("librispeech_asr", subset, split=split)
    if limit > 0:
        ds_raw = ds_raw.select(range(min(limit, len(ds_raw))))
    ds = S2TEvalDataset(ds_raw, root_path=root_path)
    collate = partial(
        s2t_eval_collate_fn,
        vq_model_audio=vq_audio,
        tokenizer=uni_prompting.text_tokenizer,
        uni_prompting=uni_prompting,
        cfg=train_cfg,
    )
    batch_size = int(hparams.get("batch_size", train_cfg.training.batch_size_s2t))
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    # Generation hyperparameters
    steps = int(hparams.get("steps", 128))
    block_length = int(hparams.get("block_length", 64))
    max_new_tokens = int(hparams.get("max_new_tokens", 256))
    remasking = hparams.get("remasking", "low_confidence")

    # W&B
    init_wandb(hparams.get("_infer_cfg", {}), "s2t", ckpt_path, {
        "steps": steps,
        "block_length": block_length,
        "max_new_tokens": max_new_tokens,
        "remasking": remasking,
        "batch_size": batch_size,
    })

    preds, refs, rows = [], [], []
    norm_mode = str(hparams.get("text_norm", "basic"))
    normalize_fn = build_normalize_fn(norm_mode)
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        gt_texts = batch["gt_texts"]
        sample_ids = batch["sample_ids"]
        with torch.no_grad():
            output_ids = model.mmu_generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                steps=steps,
                block_length=block_length,
                remasking=remasking,
            )
        # Decode only the newly generated tokens that follow the prompt
        decoded = uni_prompting.text_tokenizer.batch_decode(
            output_ids[:, input_ids.shape[1]:], skip_special_tokens=True
        )
        clean_gts = [_strip_custom_markers(gt) for gt in gt_texts]
        clean_preds = [_strip_custom_markers(pred) for pred in decoded]
        print(clean_preds)
        for sid, clean_gt, clean_pred in zip(sample_ids, clean_gts, clean_preds):
            refs.append(clean_gt)
            preds.append(clean_pred)
            rows.append([sid, clean_gt, clean_pred])

    wer, errors, words = calculate_wer(preds, refs, normalize=normalize_fn)
    wandb.log({
        "metrics/s2t_wer": wer,
        "metrics/s2t_word_errors": errors,
        "metrics/s2t_total_words": words,
    })
    safe_log_table("samples/s2t", ["ID", "GT", "PRED"], rows[:64])
    wandb.finish()
def main():
    parser = argparse.ArgumentParser(description="S2T inference with CLI overrides or config grids")
    parser.add_argument("--train_config", required=True, help="Path to the training YAML used to build tokenizers and VQ models")
    parser.add_argument("--ckpt_root", required=True, help="Experiment output dir or a specific checkpoint path")
    parser.add_argument("--infer_config", required=False, help="Optional YAML for W&B settings and hyperparameter grids")
    parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Each may be '.../unwrapped_model', '.../checkpoint-XXXX', or an experiment dir")
    # Generation overrides
    parser.add_argument("--steps", type=int)
    parser.add_argument("--block_length", type=int)
    parser.add_argument("--max_new_tokens", type=int)
    parser.add_argument("--remasking")
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--text_norm", choices=["off", "basic", "english", "whisper", "whisper_en"], help="Text normalization for WER")
    # Dataset overrides
    parser.add_argument("--subset")
    parser.add_argument("--split")
    parser.add_argument("--root_path")
    parser.add_argument("--limit", type=int)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_cfg = load_train_config(args.train_config)

    infer_cfg = {}
    if args.infer_config:
        infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True)

    # Build checkpoint list, in order of precedence: --checkpoint > infer_config.checkpoints > --ckpt_root
    if args.checkpoint:
        ckpt_list = []
        for p in args.checkpoint:
            ckpt_list.extend(list_checkpoints(p))
    else:
        ckpts = infer_cfg.get("checkpoints") if infer_cfg else None
        if ckpts:
            ckpt_list = []
            for p in ckpts:
                ckpt_list.extend(list_checkpoints(p))
        else:
            ckpt_list = list_checkpoints(args.ckpt_root)
    if not ckpt_list:
        raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.")

    override_present = any([
        args.steps is not None, args.block_length is not None, args.max_new_tokens is not None,
        args.remasking is not None, args.batch_size is not None,
        args.text_norm is not None,
        args.subset is not None, args.split is not None, args.root_path is not None, args.limit is not None,
    ])
    if override_present or not infer_cfg:
        # A single run assembled from CLI overrides and defaults
        single = {
            "steps": args.steps if args.steps is not None else 128,
            "block_length": args.block_length if args.block_length is not None else 64,
            "max_new_tokens": args.max_new_tokens if args.max_new_tokens is not None else 256,
            "remasking": args.remasking if args.remasking is not None else "low_confidence",
            "batch_size": args.batch_size if args.batch_size is not None else int(train_cfg.training.batch_size_s2t),
        }
        if args.text_norm is not None:
            single["text_norm"] = args.text_norm
        dcfg = {
            "subset": args.subset or "clean",
            "split": args.split or "test",
            "root_path": args.root_path or "/home/work/AIDAS/data/audio/LibriSpeech/test-clean",
            "limit": args.limit if args.limit is not None else 128,
        }
        single["dataset"] = dcfg
        single["_infer_cfg"] = infer_cfg
        combos = [single]
    else:
        # Expand the generation grid from the infer config
        gen_grid = infer_cfg.get("generation", {
            "steps": [128],
            "block_length": [64],
            "max_new_tokens": [256],
            "remasking": ["low_confidence"],
            "batch_size": [int(train_cfg.training.batch_size_s2t)],
        })
        combos = grid_dict(gen_grid)
        dcfg = infer_cfg.get("dataset", {
            "subset": "clean",
            "split": "test",
            "root_path": "/home/work/AIDAS/data/audio/LibriSpeech/test-clean",
            "limit": 128,
        })
        # Apply CLI overrides on top of the config grid, if provided
        if args.subset is not None:
            dcfg["subset"] = args.subset
        if args.split is not None:
            dcfg["split"] = args.split
        if args.root_path is not None:
            dcfg["root_path"] = args.root_path
        if args.limit is not None:
            dcfg["limit"] = args.limit
        for c in combos:
            if args.steps is not None:
                c["steps"] = args.steps
            if args.block_length is not None:
                c["block_length"] = args.block_length
            if args.max_new_tokens is not None:
                c["max_new_tokens"] = args.max_new_tokens
            if args.remasking is not None:
                c["remasking"] = args.remasking
            if args.batch_size is not None:
                c["batch_size"] = args.batch_size
            if args.text_norm is not None:
                c["text_norm"] = args.text_norm
            c["dataset"] = dcfg
            c["_infer_cfg"] = infer_cfg

    for ckpt in ckpt_list:
        for hp in combos:
            run_once(ckpt, hp, train_cfg, device)


if __name__ == "__main__":
    main()
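
# Example invocation (illustrative; config and checkpoint paths are placeholders):
#   python s2t_infer.py \
#       --train_config configs/train.yaml \
#       --ckpt_root outputs/exp1 \
#       --steps 128 --block_length 64 --max_new_tokens 256 \
#       --text_norm whisper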