#!/usr/bin/env bash
set -euo pipefail
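# Thin wrapper around MMaDA/inference/t2s_infer.py (text-to-speech inference).
# Usage: bash MMaDA/inference/run_t2s.sh [KEY=VALUE ...] [extra args forwarded to t2s_infer.py]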
# Editable defaults (can be overridden by KEY=VALUE pairs or by passing CLI flags)
TRAIN_CONFIG=${TRAIN_CONFIG:-MMaDA/configs/omada_pretraining_stage1-3.yaml}
CKPT_ROOT=${CKPT_ROOT:-ckpts/omada/omada-training-stage1_3rd}
INFER_CONFIG=${INFER_CONFIG:-}
CHECKPOINTS=${CHECKPOINTS:-}
# Generation params
MODE=${MODE:-}
GUIDANCE_SCALE=${GUIDANCE_SCALE:-}
TEMPERATURE=${TEMPERATURE:-}
TIMESTEPS=${TIMESTEPS:-}
SEQ_LEN=${SEQ_LEN:-}
NOISE_SCHEDULE=${NOISE_SCHEDULE:-}
NOISE_TYPE=${NOISE_TYPE:-}
BATCH_SIZE=${BATCH_SIZE:-}
OUTPUT_DIR=${OUTPUT_DIR:-}
WER_ASR_MODEL=${WER_ASR_MODEL:-}
WER_LANGUAGE=${WER_LANGUAGE:-}
WER_MAX_SAMPLES=${WER_MAX_SAMPLES:-}
TEXT_NORM=${TEXT_NORM:-}
BLOCK_LENGTH=${BLOCK_LENGTH:-}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-}
AUDIO_CODEBOOK_SIZE=${AUDIO_CODEBOOK_SIZE:-}
# Aliases for convenience
CFG=${CFG:-}
TIMESTEP=${TIMESTEP:-}
WER=${WER:-}
ASR=${ASR:-}
WER_LANG=${WER_LANG:-}   # deliberately not LANG, which would collide with the locale env var
WER_SAMPLES=${WER_SAMPLES:-}
# Dataset params
SUBSET=${SUBSET:-}
SPLIT=${SPLIT:-}
LIMIT=${LIMIT:-}
# Split arguments: KEY=VALUE pairs set the variables above; everything else
# is collected and forwarded to t2s_infer.py untouched.
REST_ARGS=()
for arg in "$@"; do
  case "$arg" in
    *=*) declare "$arg" ;;   # safer than eval: values containing spaces stay intact
    *) REST_ARGS+=("$arg") ;;
  esac
done
# Apply alias variables if provided
[[ -n "$CFG" ]] && TRAIN_CONFIG="$CFG"
[[ -n "$TIMESTEP" ]] && TIMESTEPS="$TIMESTEP"
# Enable WER evaluation with sensible defaults when the WER toggle is set
if [[ -n "$WER" && "$WER" != "0" ]]; then
  : "${WER_ASR_MODEL:=${ASR:-openai/whisper-large-v3}}"
  : "${WER_LANGUAGE:=${WER_LANG:-english}}"
  : "${WER_MAX_SAMPLES:=${WER_SAMPLES:-64}}"
fi
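# For example, passing just WER=1 yields the effective flags:
#   --wer_asr_model openai/whisper-large-v3 --wer_language english --wer_max_samples 64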
ARGS=(
  --train_config "$TRAIN_CONFIG"
  --ckpt_root "$CKPT_ROOT"
)
if [[ -n "$INFER_CONFIG" ]]; then
ARGS+=(--infer_config "$INFER_CONFIG")
fi
if [[ -n "$CHECKPOINTS" ]]; then
IFS=',' read -r -a CK_ARR <<< "$CHECKPOINTS"
for c in "${CK_ARR[@]}"; do
ARGS+=(--checkpoint "$c")
done
fi
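# Optional flags: each is appended only when the corresponding variable is
# non-empty, so unset values are simply omitted from the invocation below.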
[[ -n "$MODE" ]] && ARGS+=(--mode "$MODE")
[[ -n "$GUIDANCE_SCALE" ]] && ARGS+=(--guidance_scale "$GUIDANCE_SCALE")
[[ -n "$TEMPERATURE" ]] && ARGS+=(--temperature "$TEMPERATURE")
[[ -n "$TIMESTEPS" ]] && ARGS+=(--timesteps "$TIMESTEPS")
[[ -n "$SEQ_LEN" ]] && ARGS+=(--seq_len "$SEQ_LEN")
[[ -n "$NOISE_SCHEDULE" ]] && ARGS+=(--noise_schedule "$NOISE_SCHEDULE")
[[ -n "$NOISE_TYPE" ]] && ARGS+=(--noise_type "$NOISE_TYPE")
[[ -n "$BATCH_SIZE" ]] && ARGS+=(--batch_size "$BATCH_SIZE")
[[ -n "$OUTPUT_DIR" ]] && ARGS+=(--output_dir "$OUTPUT_DIR")
[[ -n "$WER_ASR_MODEL" ]] && ARGS+=(--wer_asr_model "$WER_ASR_MODEL")
[[ -n "$WER_LANGUAGE" ]] && ARGS+=(--wer_language "$WER_LANGUAGE")
[[ -n "$WER_MAX_SAMPLES" ]] && ARGS+=(--wer_max_samples "$WER_MAX_SAMPLES")
[[ -n "$TEXT_NORM" ]] && ARGS+=(--text_norm "$TEXT_NORM")
[[ -n "$BLOCK_LENGTH" ]] && ARGS+=(--block_length "$BLOCK_LENGTH")
[[ -n "$MAX_NEW_TOKENS" ]] && ARGS+=(--max_new_tokens "$MAX_NEW_TOKENS")
[[ -n "$AUDIO_CODEBOOK_SIZE" ]] && ARGS+=(--audio_codebook_size "$AUDIO_CODEBOOK_SIZE")
[[ -n "$SUBSET" ]] && ARGS+=(--subset "$SUBSET")
[[ -n "$SPLIT" ]] && ARGS+=(--split "$SPLIT")
[[ -n "$LIMIT" ]] && ARGS+=(--limit "$LIMIT")
# The ${REST_ARGS[@]+...} guard avoids an "unbound variable" error under
# `set -u` when no extra args were given (empty-array expansion fails in bash < 4.4)
python -u MMaDA/inference/t2s_infer.py "${ARGS[@]}" ${REST_ARGS[@]+"${REST_ARGS[@]}"}
# Examples:
# bash MMaDA/inference/run_t2s.sh MODE=free GUIDANCE_SCALE=1.5 TIMESTEPS=60 SEQ_LEN=786 BATCH_SIZE=1 OUTPUT_DIR=inference/outputs/t2s_cli TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml CKPT_ROOT=ckpts/omada/omada-training-stage1_3rd/checkpoint-315000/unwrapped_model
#
# bash MMaDA/inference/run_t2s.sh \
#   WER=1 ASR=openai/whisper-large-v3 WER_LANG=english TEXT_NORM=basic WER_SAMPLES=256 \
#   TIMESTEPS=256 BLOCK_LENGTH=256 MODE=mmu GUIDANCE_SCALE=3.0 SEQ_LEN=512 BATCH_SIZE=4 \
#   OUTPUT_DIR=inference/outputs/t2s_cli \
#   TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml \
#   CKPT_ROOT=ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model