AIDAS-Omni-Modal-Diffusion / MMaDA /script /train_omada_instruction.sh
jaeikkim
Reinit Space without binary assets
7bfbdc3
#!/usr/bin/env bash
# Example manual launches retained for reference:
# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 0 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml
# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 1 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml
# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 2 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml
# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 3 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml
set -euo pipefail

# Whitespace-separated hosts, one per machine rank (the first host is rank 0).
# A value inherited from the environment wins over the built-in default. The
# `-` form (not `:-`) means an explicitly empty value is kept and rejected by
# the check below instead of being silently replaced.
export AIDAS_TRAIN_HOSTS="${AIDAS_TRAIN_HOSTS-main1 sub1 sub2 sub3}"

PROJECT_ROOT="/home/work/AIDAS"
CONFIG_FILE="${PROJECT_ROOT}/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml"
TRAIN_SCRIPT="${PROJECT_ROOT}/MMaDA/training/train_omada_inst.py"
EXPERIMENT_CFG="${PROJECT_ROOT}/MMaDA/configs/omada_instruction_tuning.yaml"
LOG_DIR="${PROJECT_ROOT}/logs"
MAIN_PORT="${MAIN_PORT:-8888}"
# Shell snippet prepended to every rank's command line. Unset -> default
# conda activation; REMOTE_SETUP="" explicitly disables it (with `:-` an empty
# value would have been replaced by the default, making "no setup" impossible).
REMOTE_SETUP="${REMOTE_SETUP-source ~/.bashrc && conda activate mmada}"
NCCL_DEBUG_LEVEL="${NCCL_DEBUG_LEVEL:-INFO}"

if [[ -z "${AIDAS_TRAIN_HOSTS}" ]]; then
  echo "Set AIDAS_TRAIN_HOSTS=\"host0 host1 host2 host3\" before running this script." >&2
  exit 1
fi
read -r -a HOSTS <<< "${AIDAS_TRAIN_HOSTS}"
NUM_MACHINES=${#HOSTS[@]}
# Catches a whitespace-only AIDAS_TRAIN_HOSTS, which passes the -z test above.
if (( NUM_MACHINES == 0 )); then
  echo "AIDAS_TRAIN_HOSTS is empty." >&2
  exit 1
fi
# Per-launch state: parallel rank-indexed arrays of background PIDs and host
# labels, plus one timestamp shared by every rank's log file name.
declare -a PIDS=() HOST_LABELS=()
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"
mkdir -p -- "$LOG_DIR"
#######################################
# Prefix every line read from stdin with a local "YYYY-MM-DD HH:MM:SS" stamp.
# Inputs:  lines on stdin (a final line without a trailing newline is kept)
# Outputs: "<timestamp> <line>" on stdout, one per input line
#######################################
timestamp_lines() {
  local line
  # `|| [[ -n "$line" ]]` keeps a last line that lacks a trailing newline,
  # which a plain `read` loop would silently drop.
  while IFS= read -r line || [[ -n "$line" ]]; do
    # Builtin strftime (bash >= 4.2; -1 = "now") avoids forking `date` once
    # per log line, which matters for chatty NCCL INFO output.
    printf '%(%Y-%m-%d %H:%M:%S)T %s\n' -1 "$line"
  done
}
#######################################
# Send SIGTERM to every still-running background PID recorded in PIDS.
# Globals: PIDS (read)
# Outputs: a one-line notice on stdout when PIDS is non-empty
# Returns: 0
#######################################
stop_all() {
  (( ${#PIDS[@]} > 0 )) || return 0
  printf '%s\n' "Stopping launched processes..."
  local target
  for target in "${PIDS[@]}"; do
    # Skip empty slots and processes that are already gone.
    [[ -n "${target:-}" ]] || continue
    kill -0 "$target" 2>/dev/null || continue
    # Best effort: the process may exit between the probe and the kill.
    kill "$target" >/dev/null 2>&1 || true
  done
}
#######################################
# SIGINT/SIGTERM handler: announce the interruption, tear down every tracked
# rank via stop_all, then terminate the launcher with status 1.
#######################################
on_signal() {
  printf '%s\n' "Signal received, terminating all ranks."
  stop_all
  exit 1
}
trap 'on_signal' TERM INT
#######################################
# Start one training rank in the background and record it in PIDS/HOST_LABELS.
# Globals:
#   PROJECT_ROOT CONFIG_FILE TRAIN_SCRIPT EXPERIMENT_CFG NUM_MACHINES
#   MAIN_PORT NCCL_DEBUG_LEVEL REMOTE_SETUP SSH_USER USER (read)
#   PIDS HOST_LABELS (written at index $rank)
# Arguments:
#   $1 - host; "localhost" or this machine's short/FQDN name runs locally,
#        anything else is reached over ssh
#   $2 - machine rank passed to `accelerate launch --machine_rank`
#   $3 - log file receiving the rank's timestamped stdout+stderr
#   $4 - label stored in HOST_LABELS for status messages
#######################################
launch_rank() {
  local host="$1"
  local rank="$2"
  local log_file="$3"
  local host_label="$4"

  # Build the launcher argv as an array and render it with %q so paths that
  # contain spaces or shell metacharacters survive being re-parsed by
  # `bash -lc` (and, for remote hosts, by the login shell on the far side).
  local -a launch_argv=(
    env
    "NCCL_DEBUG=${NCCL_DEBUG_LEVEL}"
    NCCL_SHM_DISABLE=1
    NCCL_ASYNC_ERROR_HANDLING=1
    accelerate launch
    --config_file "${CONFIG_FILE}"
    --num_machines "${NUM_MACHINES}"
    --machine_rank "${rank}"
    --main_process_port "${MAIN_PORT}"
    "${TRAIN_SCRIPT}"
    "config=${EXPERIMENT_CFG}"
  )
  local launch_cmd
  printf -v launch_cmd '%q ' "${launch_argv[@]}"
  launch_cmd="${launch_cmd% }"

  local base_cmd
  printf -v base_cmd 'cd %q && %s' "${PROJECT_ROOT}" "${launch_cmd}"
  # REMOTE_SETUP is a shell snippet (it may contain `&&`, `~`, ...), so it is
  # prepended verbatim rather than quoted.
  if [[ -n "${REMOTE_SETUP}" ]]; then
    base_cmd="${REMOTE_SETUP} && ${base_cmd}"
  fi

  # `hostname -f` can fail on machines without a resolvable FQDN; silence it
  # (an empty result simply never matches).
  if [[ "$host" == "localhost" || "$host" == "$(hostname)" || "$host" == "$(hostname -f 2>/dev/null)" ]]; then
    echo "[rank ${rank}] running locally (${host_label}), logging to ${log_file}"
    stdbuf -oL -eL bash -lc "$base_cmd" 2>&1 | timestamp_lines >"$log_file" &
  else
    local dest escaped_cmd
    # Fall back to `id -un` so an unset USER does not abort under `set -u`.
    dest="${SSH_USER:-${USER:-$(id -un)}}@${host}"
    # Quote once more: ssh joins its arguments and hands them to the remote
    # login shell, which strips this layer before `bash -lc` sees base_cmd.
    escaped_cmd=$(printf '%q' "$base_cmd")
    echo "[rank ${rank}] ssh ${dest}, logging to ${log_file}"
    ssh "$dest" "bash -lc $escaped_cmd" 2>&1 | timestamp_lines >"$log_file" &
  fi
  # $! is the PID of the background pipeline's last stage (the timestamp
  # filter), so stop_all signals the log filter rather than the launcher.
  PIDS[$rank]=$!
  HOST_LABELS[$rank]="$host_label"
}
#######################################
# Wait for each launched rank in rank order. On the first non-zero exit,
# report it on stderr, tear the remaining ranks down, and exit the script
# with that rank's status.
# Globals: PIDS HOST_LABELS (read)
#######################################
wait_for_ranks() {
  local rank pid status
  for rank in "${!PIDS[@]}"; do
    pid="${PIDS[$rank]}"
    [[ -n "${pid:-}" ]] || continue
    # Capture the real exit code. Inside `if ! wait ...; then`, $? is the
    # status of the negated test (0), so a failed rank used to exit 0.
    status=0
    wait "$pid" || status=$?
    if (( status != 0 )); then
      echo "[rank ${rank}] (${HOST_LABELS[$rank]}) exited with status ${status}" >&2
      stop_all
      exit "$status"
    fi
  done
}

# Launch every rank in the background; rank N is the N-th entry of HOSTS.
for idx in "${!HOSTS[@]}"; do
  host="${HOSTS[$idx]}"
  # Reduce the host name to filename-safe characters for the log path/label.
  safe_host=${host//[^A-Za-z0-9_.-]/_}
  log_file="${LOG_DIR}/train_inst_${TIMESTAMP}_rank${idx}_${safe_host}.log"
  launch_rank "$host" "$idx" "$log_file" "$safe_host"
done
echo "All nodes launched. Tail logs under ${LOG_DIR}."

wait_for_ranks