Spaces: Running on Zero
set -euo pipefail

# Editable defaults. Each may be overridden by an environment variable of the
# same name or by a NAME=VALUE argument; every other argument (e.g. `--flag`,
# `--flag=value`, or a bare value) is forwarded verbatim to t2s_infer.py.
TRAIN_CONFIG=${TRAIN_CONFIG:-MMaDA/configs/omada_pretraining_stage1-3.yaml}
CKPT_ROOT=${CKPT_ROOT:-ckpts/omada/omada-training-stage1_3rd}
INFER_CONFIG=${INFER_CONFIG:-}
# Comma-separated; each entry is forwarded as its own --checkpoint flag.
CHECKPOINTS=${CHECKPOINTS:-}

# Generation params (an empty value means the flag is not passed at all).
MODE=${MODE:-}
GUIDANCE_SCALE=${GUIDANCE_SCALE:-}
TEMPERATURE=${TEMPERATURE:-}
TIMESTEPS=${TIMESTEPS:-}
SEQ_LEN=${SEQ_LEN:-}
NOISE_SCHEDULE=${NOISE_SCHEDULE:-}
NOISE_TYPE=${NOISE_TYPE:-}
BATCH_SIZE=${BATCH_SIZE:-}
OUTPUT_DIR=${OUTPUT_DIR:-}
WER_ASR_MODEL=${WER_ASR_MODEL:-}
WER_LANGUAGE=${WER_LANGUAGE:-}
WER_MAX_SAMPLES=${WER_MAX_SAMPLES:-}
TEXT_NORM=${TEXT_NORM:-}
BLOCK_LENGTH=${BLOCK_LENGTH:-}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-}
AUDIO_CODEBOOK_SIZE=${AUDIO_CODEBOOK_SIZE:-}

# Aliases for convenience (applied by resolve_aliases below).
CFG=${CFG:-}
TIMESTEP=${TIMESTEP:-}
WER=${WER:-}
ASR=${ASR:-}
# A LANG=... argument is stored in LANG_ALIAS, never in LANG itself. LANG is
# the locale variable (normally exported): reading it from the environment
# would make e.g. "en_US.UTF-8" the default WER language, and overwriting it
# would hand t2s_infer.py a bogus locale such as "english".
LANG_ALIAS=""
WER_SAMPLES=${WER_SAMPLES:-}

# Dataset params
SUBSET=${SUBSET:-}
SPLIT=${SPLIT:-}
LIMIT=${LIMIT:-}

#######################################
# Split command-line arguments into variable assignments and pass-through args.
# Globals:   REST_ARGS (appended); any NAME given as NAME=VALUE (written);
#            LANG_ALIAS (written for LANG=VALUE)
# Arguments: any mix of NAME=VALUE pairs and other arguments
# Returns:   0
# An argument is treated as an assignment only when the text before its first
# "=" is a valid shell identifier, so `--flag=value` is forwarded instead of
# being executed. VALUE is assigned literally with printf -v (never eval'd):
# spaces, quotes, ";" or "$(...)" in a value are kept, not run.
#######################################
apply_cli_args() {
  local _kv_arg _kv_name
  for _kv_arg in "$@"; do
    _kv_name=${_kv_arg%%=*}
    if [[ "$_kv_arg" == *=* && "$_kv_name" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
      # Keep the LANG alias out of the real locale variable (see LANG_ALIAS).
      if [[ "$_kv_name" == LANG ]]; then
        _kv_name=LANG_ALIAS
      fi
      printf -v "$_kv_name" '%s' "${_kv_arg#*=}"
    else
      REST_ARGS+=("$_kv_arg")
    fi
  done
}

#######################################
# Copy the short aliases onto their canonical variables.
# Globals:   reads CFG, TIMESTEP, WER, ASR, LANG_ALIAS, WER_SAMPLES;
#            writes TRAIN_CONFIG, TIMESTEPS, WER_ASR_MODEL, WER_LANGUAGE,
#            WER_MAX_SAMPLES
# Returns:   0
# CFG and TIMESTEP override their targets when non-empty. When WER is set to
# anything other than "" or "0", the WER_* settings are filled in (only where
# still empty) from ASR / LANG_ALIAS / WER_SAMPLES or the built-in defaults.
#######################################
resolve_aliases() {
  if [[ -n "$CFG" ]]; then
    TRAIN_CONFIG=$CFG
  fi
  if [[ -n "$TIMESTEP" ]]; then
    TIMESTEPS=$TIMESTEP
  fi
  if [[ -n "$WER" && "$WER" != "0" ]]; then
    : "${WER_ASR_MODEL:=${ASR:-openai/whisper-large-v3}}"
    : "${WER_LANGUAGE:=${LANG_ALIAS:-english}}"
    : "${WER_MAX_SAMPLES:=${WER_SAMPLES:-64}}"
  fi
}

REST_ARGS=()
apply_cli_args "$@"
resolve_aliases
# Arguments always passed to t2s_infer.py.
ARGS=(
  --train_config "$TRAIN_CONFIG"
  --ckpt_root "$CKPT_ROOT"
)
if [[ -n "$INFER_CONFIG" ]]; then
  ARGS+=(--infer_config "$INFER_CONFIG")
fi

#######################################
# Append one `--checkpoint ITEM` pair to ARGS per comma-separated item.
# Globals:   ARGS (appended)
# Arguments: $1 - comma-separated list of checkpoints (may be empty)
# Returns:   0
# Items keep their order. Empty items (from "a,,b" or a leading comma) are
# skipped rather than forwarded as `--checkpoint ""`; an empty list appends
# nothing.
#######################################
append_checkpoint_args() {
  local -a _ck_items=()
  local _ck_item
  [[ -n "${1:-}" ]] || return 0
  IFS=',' read -r -a _ck_items <<< "$1"
  for _ck_item in "${_ck_items[@]}"; do
    if [[ -n "$_ck_item" ]]; then
      ARGS+=(--checkpoint "$_ck_item")
    fi
  done
}
append_checkpoint_args "$CHECKPOINTS"
# Optional flags. Each "VARIABLE:--flag" spec forwards `--flag VALUE` to
# t2s_infer.py only when VARIABLE is non-empty; flags are emitted in the order
# listed. ${!name} reads the variable whose name is stored in _flag_var.
for _flag_spec in \
  MODE:--mode \
  GUIDANCE_SCALE:--guidance_scale \
  TEMPERATURE:--temperature \
  TIMESTEPS:--timesteps \
  SEQ_LEN:--seq_len \
  NOISE_SCHEDULE:--noise_schedule \
  NOISE_TYPE:--noise_type \
  BATCH_SIZE:--batch_size \
  OUTPUT_DIR:--output_dir \
  WER_ASR_MODEL:--wer_asr_model \
  WER_LANGUAGE:--wer_language \
  WER_MAX_SAMPLES:--wer_max_samples \
  TEXT_NORM:--text_norm \
  BLOCK_LENGTH:--block_length \
  MAX_NEW_TOKENS:--max_new_tokens \
  AUDIO_CODEBOOK_SIZE:--audio_codebook_size \
  SUBSET:--subset \
  SPLIT:--split \
  LIMIT:--limit; do
  _flag_var=${_flag_spec%%:*}
  if [[ -n "${!_flag_var}" ]]; then
    ARGS+=("${_flag_spec#*:}" "${!_flag_var}")
  fi
done
# Drop the loop helpers so they don't leak into the rest of the script.
unset _flag_spec _flag_var
# Run inference with the assembled flags plus any pass-through arguments.
# REST_ARGS is expanded through the ${arr[@]+"${arr[@]}"} guard because
# bash < 4.4 (e.g. macOS's bash 3.2) treats "${arr[@]}" of an empty array as
# an unbound variable under `set -u`. Without the guard the script would abort
# whenever no pass-through arguments are given.
python -u MMaDA/inference/t2s_infer.py "${ARGS[@]}" ${REST_ARGS[@]+"${REST_ARGS[@]}"}
# Examples. These must stay commented out: as live commands at the end of the
# script they would start another inference run every time this one finishes,
# and, when this file is run_t2s.sh itself (presumably), recurse indefinitely.
#
# bash MMaDA/inference/run_t2s.sh MODE=free GUIDANCE_SCALE=1.5 TIMESTEPS=60 SEQ_LEN=786 BATCH_SIZE=1 OUTPUT_DIR=inference/outputs/t2s_cli TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml CKPT_ROOT=ckpts/omada/omada-training-stage1_3rd/checkpoint-315000/unwrapped_model
#
# With WER scoring. The block size knob is BLOCK_LENGTH (-> --block_length);
# a BLOCKSIZE=... pair is not read by this script and would be ignored.
# bash MMaDA/inference/run_t2s.sh \
#   WER=1 ASR=openai/whisper-large-v3 LANG=english TEXT_NORM=basic WER_SAMPLES=256 \
#   TIMESTEPS=256 BLOCK_LENGTH=256 MODE=mmu GUIDANCE_SCALE=3.0 SEQ_LEN=512 BATCH_SIZE=4 \
#   OUTPUT_DIR=inference/outputs/t2s_cli \
#   TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml \
#   CKPT_ROOT=ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model