# 1. 创建环境
conda create -n visarc python==3.10 -y
conda activate visarc

# 2. 安装依赖
pip install -r requirements.txt


# 3. 配置 wandb（推荐，api是wenjie的）
export WANDB_API_KEY=2bbad674b7a265781cc8f48ab93dad9d0a894ac6
wandb login  # 首次使用时


cd /VARC  # 仓库根目录

# 构建 augmented 数据（ARC-1 + RE-ARC）
python augment_data.py

# 下面是实验的脚本路径
script/run_varc_loop_sweep.sh

# 下面是脚本需要改的地方：
#!/usr/bin/env bash
set -e

# 1. 指定使用哪些 GPU（例如一台机器上用 4 张卡，推荐8卡）
export CUDA_VISIBLE_DEVICES=0,1,2,3

# 2. 公共超参（一般不用改）
DATA_ROOT="raw_data/ARC-AGI"
TRAIN_SPLIT="training"
EPOCHS=100
BATCH_SIZE=32
IMAGE_SIZE=64
PATCH_SIZE=2
LR=3e-4
WEIGHT_DECAY=0
EMBED_DIM=512
NUM_HEADS=8
NUM_COLORS=12
LR_SCHED="cosine"
VIS_EVERY=50

# 3. Sweep 的参数
#   - LOOP_STEPS_LIST: max-loop-steps ∈ {1, 2, 3, 6}
#   - BLOCK_LIST: loop-core-depth / depth ∈ {8}（只跑 block=8 的线）
LOOP_STEPS_LIST=(1 2 3 6)
BLOCK_LIST=(8)

# 4. 控制只跑哪些 block（这里只跑 block=8，本机专门负责这一条线）
RUN_BLOCK_LIST=(8)

# 5. 日志和 checkpoint 根目录
SAVE_ROOT="saves/sweep_loop_vs_vit"

# 后面是主循环和 torchrun 调用逻辑，已经写好，不需要改。