import os
import itertools
import json
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import torch
import wandb
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from models import MAGVITv2, OMadaModelLM
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.prompting_utils import UniversalPrompting


def load_train_config(path: str):
    cfg = OmegaConf.load(path)
    return cfg


def get_vq_model_image(cfg, device):
    vq_cfg = cfg.model.vq_model_image
    if getattr(vq_cfg, "pretrained_model_path", None):
        model = MAGVITv2().to(device)
        state_dict = torch.load(vq_cfg.pretrained_model_path)["model"]
        model.load_state_dict(state_dict)
        return model.eval()
    else:
        return MAGVITv2.from_pretrained(vq_cfg.vq_model_name).to(device).eval()


def get_vq_model_audio(cfg, device):
    vq_cfg = cfg.model.vq_model_audio
    # Always EMOVA for now
    model = EMOVASpeechTokenizer.from_pretrained(vq_cfg.vq_model_name)
    model = model.to(device)
    model.eval()
    return model


def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.omada.tokenizer_path, padding_side="left")
    uni_prompting = UniversalPrompting(
        tokenizer,
        max_text_len=cfg.dataset.preprocessing.max_seq_length,
        max_audio_len=cfg.dataset.preprocessing.max_aud_length,
        special_tokens=(
            "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>",
            "<|v2v|>", "<|lvg|>", "<|i2i|>", "<|v2s|>", "<|s2s|>", "<|v2t|>", "<|s2t|>",
            "<|t2s|>", "<|soa|>", "<|eoa|>", "<|ti2ti|>", "<|t2ti|>",
        ),
        ignore_id=-100,
        cond_dropout_prob=cfg.training.cond_dropout_prob,
        use_reserved_token=True,
    )
    return uni_prompting, tokenizer


def load_omada_from_checkpoint(ckpt_unwrapped_dir: str, device, torch_dtype=torch.bfloat16) -> OMadaModelLM:
    """Load OMada model weights from an `unwrapped_model` directory.

    The helper used to rely on a hard-coded config path which broke when evaluating
    checkpoints from other training steps. We now detect the config.json co-located
    with the weights so any checkpoint exported by the trainer can be used directly.
    """
    ckpt_path = Path(ckpt_unwrapped_dir)
    if not ckpt_path.is_dir():
        raise FileNotFoundError(f"Expected an 'unwrapped_model' directory, got {ckpt_unwrapped_dir}")

    config_path = ckpt_path / "config.json"
    config_arg = str(config_path) if config_path.exists() else None

    model = OMadaModelLM.from_pretrained(
        ckpt_unwrapped_dir,
        torch_dtype=torch_dtype,
        config=config_arg,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    return model
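
# Example (illustrative only): loading a specific checkpoint for inference.
# The path below is a placeholder, not a checkpoint shipped with this repo.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = load_omada_from_checkpoint(
#         "output/my-exp/checkpoint-5000/unwrapped_model", device
#     )
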
def list_checkpoints(ckpt_root: str) -> List[str]:
    """Return a sorted list of checkpoint 'unwrapped_model' dirs under a training
    output dir or a direct ckpt dir.

    Accepts either:
    - A path that already ends with 'unwrapped_model'
    - A path to 'checkpoint-XXXX' (we append 'unwrapped_model')
    - A path to the experiment output dir that contains many 'checkpoint-*'
    """
    p = Path(ckpt_root)
    if p.name == "unwrapped_model" and p.is_dir():
        return [str(p)]
    if p.name.startswith("checkpoint-") and p.is_dir():
        inner = p / "unwrapped_model"
        return [str(inner)] if inner.is_dir() else []

    # Otherwise, collect children checkpoints.
    outs = []
    for child in p.iterdir():
        if child.is_dir() and child.name.startswith("checkpoint-"):
            inner = child / "unwrapped_model"
            if inner.is_dir():
                outs.append(str(inner))

    # Sort by numeric step if possible.
    def step_key(s: str):
        try:
            return int(Path(s).parent.name.split("-")[-1])
        except Exception:
            return -1

    outs.sort(key=step_key)
    return outs


def grid_dict(product_space: Dict[str, Iterable]) -> List[Dict]:
    """Expand a dict of lists to a list of dict combinations.

    Example: {a: [1, 2], b: ["x"]} -> [{a: 1, b: "x"}, {a: 2, b: "x"}]
    """
    keys = list(product_space.keys())
    values = [list(v) if isinstance(v, (list, tuple)) else [v] for v in product_space.values()]
    combos = []
    for vals in itertools.product(*values):
        combos.append({k: v for k, v in zip(keys, vals)})
    return combos


def init_wandb(infer_cfg: Dict, task: str, ckpt_path: str, hparams: Dict):
    """Start a W&B run named after the checkpoint step and hyper-parameter combo."""
    wcfg = infer_cfg.get("wandb", {})
    project = wcfg.get("project", f"omada-inference-{task}")
    entity = wcfg.get("entity")
    group = wcfg.get("group", f"{task}")
    name_prefix = wcfg.get("name_prefix", f"{task}")
    step_str = Path(ckpt_path).parent.name
    run_name = f"{name_prefix}-{step_str}-" + ",".join([f"{k}={v}" for k, v in hparams.items()])
    tags = wcfg.get("tags", [])
    wandb.init(project=project, entity=entity, group=group, name=run_name, tags=tags, config={
        "task": task,
        "checkpoint": ckpt_path,
        "hparams": hparams,
    })


def safe_log_table(name: str, columns: List[str], rows: List[List]):
    """Log rows as a W&B table, swallowing logging errors so inference keeps running."""
    try:
        table = wandb.Table(columns=columns)
        for r in rows:
            table.add_data(*r)
        wandb.log({name: table})
    except Exception:
        pass
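
if __name__ == "__main__":
    # Minimal sketch of how the helpers compose into an inference sweep. The
    # output dir and hyper-parameter grid below are illustrative assumptions,
    # not values taken from any real experiment config.
    for ckpt in list_checkpoints("output/my-exp"):
        for hparams in grid_dict({"guidance_scale": [1.5, 3.0], "temperature": [1.0]}):
            print(ckpt, json.dumps(hparams))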