import itertools
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import torch
import wandb
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from models import MAGVITv2, OMadaModelLM
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.prompting_utils import UniversalPrompting
def load_train_config(path: str):
    """Load the training-time OmegaConf YAML used to build models and prompting."""
    return OmegaConf.load(path)
def get_vq_model_image(cfg, device):
    """Load the MAGVIT-v2 image tokenizer from a local checkpoint or the Hub."""
    vq_cfg = cfg.model.vq_model_image
    if getattr(vq_cfg, "pretrained_model_path", None):
        model = MAGVITv2().to(device)
        # Load onto CPU first so checkpoints saved on another device still resolve.
        state_dict = torch.load(vq_cfg.pretrained_model_path, map_location="cpu")["model"]
        model.load_state_dict(state_dict)
        return model.eval()
    return MAGVITv2.from_pretrained(vq_cfg.vq_model_name).to(device).eval()
def get_vq_model_audio(cfg, device):
    """Load the audio tokenizer (always EMOVA for now)."""
    vq_cfg = cfg.model.vq_model_audio
    model = EMOVASpeechTokenizer.from_pretrained(vq_cfg.vq_model_name).to(device)
    model.eval()
    return model
def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
    """Build the UniversalPrompting wrapper and its underlying tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.omada.tokenizer_path, padding_side="left")
uni_prompting = UniversalPrompting(
tokenizer,
max_text_len=cfg.dataset.preprocessing.max_seq_length,
max_audio_len=cfg.dataset.preprocessing.max_aud_length,
special_tokens=(
"<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>",
"<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
"<|i2i|>", "<|v2s|>", "<|s2s|>",
"<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>",
"<|ti2ti|>", "<|t2ti|>",
),
ignore_id=-100,
cond_dropout_prob=cfg.training.cond_dropout_prob,
use_reserved_token=True,
)
return uni_prompting, tokenizer
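
# Hedged usage sketch: how the helpers above are typically wired together.
# The YAML path is hypothetical; any training config exposing the fields
# referenced above (model.omada.tokenizer_path, dataset.preprocessing.*) works.
def _example_build_prompting():
    cfg = load_train_config("configs/omada_train.yaml")  # hypothetical path
    uni_prompting, tokenizer = build_uni_prompting(cfg)
    return uni_prompting, tokenizer
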
def load_omada_from_checkpoint(ckpt_unwrapped_dir: str, device, torch_dtype=torch.bfloat16) -> OMadaModelLM:
"""Load OMada model weights from an `unwrapped_model` directory.
The helper used to rely on a hard-coded config path which broke when
evaluating checkpoints from other training steps. We now detect the
config.json co-located with the weights so any checkpoint exported by the
trainer can be used directly.
"""
ckpt_path = Path(ckpt_unwrapped_dir)
if not ckpt_path.is_dir():
raise FileNotFoundError(f"Expected an 'unwrapped_model' directory, got {ckpt_unwrapped_dir}")
    config_path = ckpt_path / "config.json"
    # Fall back to from_pretrained's default config resolution when no
    # config.json sits next to the weights.
    config_arg = str(config_path) if config_path.exists() else None
model = OMadaModelLM.from_pretrained(
ckpt_unwrapped_dir,
torch_dtype=torch_dtype,
config=config_arg,
trust_remote_code=True,
).to(device)
model.eval()
return model
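
# Hedged usage sketch (the checkpoint path is hypothetical): load a single
# exported checkpoint onto the available device in bfloat16.
def _example_load_checkpoint():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return load_omada_from_checkpoint(
        "outputs/exp1/checkpoint-5000/unwrapped_model",  # hypothetical path
        device,
    )
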
def list_checkpoints(ckpt_root: str) -> List[str]:
"""Return a sorted list of checkpoint 'unwrapped_model' dirs under a training output dir or a direct ckpt dir.
Accepts either:
- A path that already ends with 'unwrapped_model'
- A path to 'checkpoint-XXXX' (we append 'unwrapped_model')
- A path to the experiment output dir that contains many 'checkpoint-*'
"""
p = Path(ckpt_root)
if p.name == "unwrapped_model" and p.is_dir():
return [str(p)]
if p.name.startswith("checkpoint-") and p.is_dir():
inner = p / "unwrapped_model"
return [str(inner)] if inner.is_dir() else []
# otherwise, collect children checkpoints
outs = []
for child in p.iterdir():
if child.is_dir() and child.name.startswith("checkpoint-"):
inner = child / "unwrapped_model"
if inner.is_dir():
outs.append(str(inner))
# sort by numeric step if possible
def step_key(s: str):
try:
return int(Path(s).parent.name.split("-")[-1])
except Exception:
return -1
outs.sort(key=step_key)
return outs
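
# Hedged sketch of the three accepted inputs (all paths hypothetical); each
# call below resolves to the same 'unwrapped_model' directories.
def _example_list_checkpoints():
    all_ckpts = list_checkpoints("outputs/exp1")
    one_ckpt = list_checkpoints("outputs/exp1/checkpoint-5000")
    direct = list_checkpoints("outputs/exp1/checkpoint-5000/unwrapped_model")
    return all_ckpts, one_ckpt, direct
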
def grid_dict(product_space: Dict[str, Iterable]) -> List[Dict]:
    """Expand a dict of lists into a list of dict combinations.

    Example: {a: [1, 2], b: ["x"]} -> [{a: 1, b: "x"}, {a: 2, b: "x"}]
    """
    keys = list(product_space.keys())
    # Broadcast scalar values to singleton lists so they combine with the rest.
    values = [list(v) if isinstance(v, (list, tuple)) else [v] for v in product_space.values()]
    return [dict(zip(keys, vals)) for vals in itertools.product(*values)]
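
# Worked example mirroring the docstring: lists are expanded and scalars are
# broadcast, yielding one dict per combination.
def _example_grid_dict():
    combos = grid_dict({"a": [1, 2], "b": "x"})
    assert combos == [{"a": 1, "b": "x"}, {"a": 2, "b": "x"}]
    return combos
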
def init_wandb(infer_cfg: Dict, task: str, ckpt_path: str, hparams: Dict):
    wcfg = infer_cfg.get("wandb", {})
    project = wcfg.get("project", f"omada-inference-{task}")
    entity = wcfg.get("entity")
    group = wcfg.get("group", task)
    name_prefix = wcfg.get("name_prefix", task)
    # e.g. ".../checkpoint-5000/unwrapped_model" -> "checkpoint-5000"
    step_str = Path(ckpt_path).parent.name
    run_name = f"{name_prefix}-{step_str}-" + ",".join(f"{k}={v}" for k, v in hparams.items())
    tags = wcfg.get("tags", [])
    wandb.init(
        project=project,
        entity=entity,
        group=group,
        name=run_name,
        tags=tags,
        config={"task": task, "checkpoint": ckpt_path, "hparams": hparams},
    )
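
# Hedged sketch of a checkpoint-by-hyperparameter sweep (grid values and the
# task name are hypothetical): one wandb run per (checkpoint, hparams) pair.
def _example_sweep(infer_cfg: Dict):
    space = {"temperature": [0.7, 1.0], "top_k": 50}  # hypothetical grid
    for ckpt in list_checkpoints("outputs/exp1"):  # hypothetical path
        for hparams in grid_dict(space):
            init_wandb(infer_cfg, task="t2s", ckpt_path=ckpt, hparams=hparams)
            # ... run inference for this combination, log metrics ...
            wandb.finish()
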
def safe_log_table(name: str, columns: List[str], rows: List[List]):
    """Log a wandb.Table, swallowing errors so logging never kills an eval run."""
    try:
        table = wandb.Table(columns=columns)
        for r in rows:
            table.add_data(*r)
        wandb.log({name: table})
    except Exception:
        # Best-effort logging: a malformed row should not abort inference.
        pass
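
# Hedged usage sketch: log a small per-sample results table; the columns and
# rows here are purely illustrative.
def _example_log_results():
    safe_log_table(
        "samples",
        columns=["prompt", "output"],
        rows=[["a red cube", "<sample 0>"], ["a blue sphere", "<sample 1>"]],
    )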