#!/usr/bin/env bash
set -o errexit

# GPUs made visible to every process started by this script.
# Adjust per machine (for example 0,1 on the first machine).
# NOTE(review): the torchrun calls later in this script pass
# --nproc_per_node; keep the device count here consistent with it.
CUDA_VISIBLE_DEVICES='1,2,3,4'
export CUDA_VISIBLE_DEVICES

# ---- Shared hyperparameters: data ----
DATA_ROOT='raw_data/ARC-AGI'
TRAIN_SPLIT='training'
IMAGE_SIZE=64
PATCH_SIZE=2
NUM_COLORS=12

# ---- Shared hyperparameters: model ----
EMBED_DIM=512
NUM_HEADS=8

# ---- Shared hyperparameters: optimisation ----
EPOCHS=100
BATCH_SIZE=32
LR=3e-4
WEIGHT_DECAY=0
LR_SCHED='cosine'

# ---- Shared hyperparameters: logging ----
VIS_EVERY=50

# Parameters that are swept.
# LOOP_STEPS_LIST feeds --max-loop-steps.
LOOP_STEPS_LIST=(
  1
  2
  3
  6
)
# BLOCK_LIST feeds --depth (baseline) or --loop-core-depth (loop model).
BLOCK_LIST=(
  8
)

# Switch that restricts which block counts are actually run; entries of
# BLOCK_LIST that are not listed here are skipped
# (for example, run only 2 and 10 on one machine).
RUN_BLOCK_LIST=(
  8
)

# wandb flag forwarded to the training scripts; leave empty to pass no flag.
USE_WANDB='--use-wandb'

# Root directory under which each experiment gets its own sub-directory.
SAVE_ROOT='saves/sweep_loop_vs_vit'

#---------------------------------------
# Helper: check whether a value occurs in a list of values.
# Arguments:
#   $1   - value to look for
#   $2.. - candidate values (typically an expanded array)
# Returns:
#   0 if some candidate is string-equal to $1,
#   1 otherwise (including when no candidates are given).
#---------------------------------------
in_list () {
  # Nothing to search for and nothing to search in.
  (( $# > 0 )) || return 1

  local needle="$1"
  shift

  # Declared local so a call does not overwrite a same-named variable
  # in the caller's scope.
  local candidate
  for candidate in "$@"; do
    # The right-hand side is quoted, so this is a literal comparison,
    # not a glob match.
    if [[ "$candidate" == "$needle" ]]; then
      return 0
    fi
  done
  return 1
}

#---------------------------------------
# Main loop: one training run per (LOOP_STEPS, BLOCKS) combination.
#
# For every pair taken from LOOP_STEPS_LIST x BLOCK_LIST whose block count
# is also in RUN_BLOCK_LIST:
#   * LOOP_STEPS == 1 and BLOCKS == 10 -> baseline ViT run
#     (offline_train_ARC.py)
#   * anything else                    -> loop model run
#     (offline_train_loop_ARC.py)
# Checkpoints go to ${SAVE_ROOT}/<experiment name>/.
#
# NOTE(review): the baseline only triggers for a depth of 10; make sure
# BLOCK_LIST and RUN_BLOCK_LIST contain 10 if a baseline run is wanted.
#---------------------------------------

# Launch one worker per entry in CUDA_VISIBLE_DEVICES so the process count
# follows the GPU selection. Fall back to 4 (the previously hard-coded
# value) when the variable is unset or empty.
IFS=',' read -r -a GPU_IDS <<< "${CUDA_VISIBLE_DEVICES:-}"
NPROC_PER_NODE=${#GPU_IDS[@]}
if (( NPROC_PER_NODE == 0 )); then
  NPROC_PER_NODE=4
fi

# Split USE_WANDB on whitespace into separate arguments; an empty value
# contributes no argument at all.
read -r -a WANDB_ARGS <<< "${USE_WANDB:-}"

# Lower bound on loop steps requested for the loop model. It is clamped to
# the swept maximum below so that --min-loop-steps never exceeds
# --max-loop-steps (the sweep includes a maximum of 1).
DEFAULT_MIN_LOOP_STEPS=2

# Arguments shared by the baseline and the loop model.
COMMON_ARGS=(
  --epochs "$EPOCHS"
  --batch-size "$BATCH_SIZE"
  --image-size "$IMAGE_SIZE"
  --patch-size "$PATCH_SIZE"
  --learning-rate "$LR"
  --weight-decay "$WEIGHT_DECAY"
  --embed-dim "$EMBED_DIM"
  --num-heads "$NUM_HEADS"
  --include-rearc
  --num-colors "$NUM_COLORS"
  --data-root "$DATA_ROOT"
  --train-split "$TRAIN_SPLIT"
  --wandb-project "VisionARC"
  --lr-scheduler "$LR_SCHED"
  --vis-every "$VIS_EVERY"
  --distributed
)

for LOOP_STEPS in "${LOOP_STEPS_LIST[@]}"; do
  for BLOCKS in "${BLOCK_LIST[@]}"; do

    # Skip block counts that are not enabled via RUN_BLOCK_LIST.
    if ! in_list "$BLOCKS" "${RUN_BLOCK_LIST[@]}"; then
      continue
    fi

    if (( LOOP_STEPS == 1 && BLOCKS == 10 )); then
      # Baseline: plain ViT of depth 10, no looping.
      echo "=== Running BASELINE: depth=10, no loop ==="

      EXP_NAME="baseline_depth${BLOCKS}_loop${LOOP_STEPS}"
      TRAIN_SCRIPT="offline_train_ARC.py"
      # The baseline is launched with an explicit rendezvous endpoint.
      LAUNCH_ARGS=(
        --rdzv_backend=c10d
        --rdzv_endpoint=127.0.0.1:29642
      )
      MODEL_ARGS=(
        --depth "$BLOCKS"
        --architecture "vit"
      )
    else
      echo "=== Running LOOP model: loop-core-depth=${BLOCKS}, max-loop-steps=${LOOP_STEPS} ==="

      EXP_NAME="loop_depth${BLOCKS}_steps${LOOP_STEPS}"
      TRAIN_SCRIPT="offline_train_loop_ARC.py"
      # The loop model is launched without explicit rendezvous options.
      LAUNCH_ARGS=()

      # Clamp the minimum to the maximum for small sweeps.
      MIN_LOOP_STEPS=$DEFAULT_MIN_LOOP_STEPS
      if (( LOOP_STEPS < MIN_LOOP_STEPS )); then
        MIN_LOOP_STEPS=$LOOP_STEPS
      fi

      MODEL_ARGS=(
        --loop-core-depth "$BLOCKS"
        --max-loop-steps "$LOOP_STEPS"
        --min-loop-steps "$MIN_LOOP_STEPS"
        --gate-entropy-weight 5e-4
        --loop-penalty-weight 1e-3
        --train-dynamic-exit
        --eval-dynamic-exit
      )
    fi

    SAVE_DIR="${SAVE_ROOT}/${EXP_NAME}"
    CKPT_FINAL="${SAVE_DIR}/checkpoint_final.pt"
    CKPT_BEST="${SAVE_DIR}/checkpoint_best.pt"

    mkdir -p "$SAVE_DIR"

    torchrun --nproc_per_node="$NPROC_PER_NODE" \
      "${LAUNCH_ARGS[@]}" \
      "$TRAIN_SCRIPT" \
        "${COMMON_ARGS[@]}" \
        "${MODEL_ARGS[@]}" \
        --wandb-run-name "$EXP_NAME" \
        --save-path "$CKPT_FINAL" \
        --best-save-path "$CKPT_BEST" \
        "${WANDB_ARGS[@]}"

  done
done