# analyze_attention.py
import argparse
from pathlib import Path

import torch
import torch.nn.functional as F

from src.ARC_loader import build_dataloaders
from src.attn_hook import load_offline_vit_with_attn


def parse_args(argv=None):
    """Parse command-line options for the attention-analysis script.

    The architecture options (``--embed-dim``, ``--mlp-dim``, ``--depth``,
    ``--num-heads``, ``--patch-size``, ``--num-task-tokens``) are forwarded to
    the model loader and must match the checkpoint given by ``--ckpt``. The
    remaining options select the task, example, output pixel, layer/head to
    report, and where to save the dump.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    # Pass the text as `description`: the first positional argument of
    # ArgumentParser is `prog`, which would otherwise replace the program
    # name in the usage line.
    p = argparse.ArgumentParser(description="Analyze VARC-ViT attention maps")
    p.add_argument(
        "--ckpt",
        type=str,
        required=True,
        help="Checkpoint 路径，比如 saves/offline_train_ViT/checkpoint_final.pt",
    )
    p.add_argument("--data-root", type=str, default="raw_data/ARC-AGI")
    p.add_argument("--image-size", type=int, default=64)
    p.add_argument("--num-colors", type=int, default=12)
    p.add_argument("--embed-dim", type=int, default=512)
    # NOTE: the MLP dimension used by offline_train_ViT is 512.
    p.add_argument("--mlp-dim", type=int, default=512)
    p.add_argument("--depth", type=int, default=10)
    p.add_argument("--num-heads", type=int, default=8)
    p.add_argument("--patch-size", type=int, default=2)
    # Previously written as `type.int`, an attribute lookup on the builtin
    # `type` that raised AttributeError before any argument was parsed.
    p.add_argument("--num-task-tokens", type=int, default=1)

    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--task-name", type=str, required=True, help="例如 1da012fc")
    p.add_argument(
        "--example-index",
        type=int,
        default=0,
        help="该任务下第几个示例（0-based）",
    )
    p.add_argument(
        "--row",
        type=int,
        default=3,
        help="关注的输出像素行号（0-based）",
    )
    p.add_argument(
        "--col",
        type=int,
        default=4,
        help="关注的输出像素列号（0-based）",
    )
    p.add_argument(
        "--output-path",
        type=str,
        default="attn_dump.pt",
        help="保存 attention 的文件路径",
    )
    p.add_argument("--layer", type=int, default=0)
    p.add_argument("--head", type=int, default=0)
    return p.parse_args(argv)


def _make_data_args(args):
    """Build the config object passed to ``build_dataloaders``.

    Both ``train_split`` and ``eval_split`` point at
    ``eval_color_permute_ttt_9/<task_name>``. Every other field is fixed for
    this analysis: batch size 1, no workers, RE-ARC disabled, translation and
    resolution augmentation left enabled, and ``fix_scale_factor=2``.

    A new namespace is returned on every call, so the two loader calls in
    ``main`` never share a config object.
    """
    split = f"eval_color_permute_ttt_9/{args.task_name}"
    return argparse.Namespace(
        data_root=args.data_root,
        train_split=split,
        eval_split=split,
        eval_subset="test",
        image_size=args.image_size,
        num_colors=args.num_colors,
        batch_size=1,
        num_workers=0,
        include_rearc=False,
        rearc_path="raw_data/re_arc",
        rearc_limit=-1,
        patch_size=args.patch_size,
        num_task_tokens=args.num_task_tokens,
        fix_scale_factor=2,
        disable_translation=False,
        disable_resolution_augmentation=False,
    )


def _resolve_task_id(task_lookup, task_name):
    """Return the integer id that ``task_lookup`` assigns to ``task_name``.

    Tries an exact key match first. If that fails, it falls back to a key that
    ends with ``task_name``, which handles lookups keyed by a prefixed path.

    Raises:
        ValueError: if there is no exact match and the suffix match is not
            unique (zero or several candidates).
    """
    if task_name in task_lookup:
        task_id = task_lookup[task_name]
        print(f"在训练集 task_lookup 中找到 {task_name} -> task_id={task_id}")
        return task_id

    # Suffix fallback: some loaders key tasks by a path ending in the name.
    candidates = [k for k in task_lookup.keys() if k.endswith(task_name)]
    print(f"后缀匹配候选项: {candidates}")
    if len(candidates) == 1:
        (k,) = candidates
        task_id = task_lookup[k]
        print(f"通过后缀匹配 {task_name} -> key {k} -> task_id={task_id}")
        return task_id

    # Special hint for the example task name used in the README.
    if task_name == "1da012fc":
        raise ValueError(
            f"任务名称 {task_name} 是 README 中的示例任务，但未在 task_lookup 中找到。"
            "请检查数据集是否正确，或确认 README 中的任务名称是否正确。"
        )
    raise ValueError(
        f"在训练集 task_lookup 中找不到 {task_name}，"
        f"且通过后缀匹配无法唯一确定（候选: {candidates}）。"
    )


def main():
    """Capture one forward pass's attention maps for an ARC example and dump them.

    Steps:
      1. Load the TTT split of ``--task-name`` and resolve the task's integer
         id from the dataset's ``task_lookup``.
      2. Load the split again and pick example ``--example-index``.
      3. Load the checkpoint with ``load_offline_vit_with_attn``.
      4. Run the model with ``return_attn=True``.
      5. Print statistics for the query token covering output pixel
         (``--row``, ``--col``) at layer ``--layer``, both head-averaged and
         for head ``--head``.
      6. Save scores, attention, inputs, logits and metadata to
         ``--output-path`` with ``torch.save``.

    Raises:
        ValueError: on an out-of-range pixel, an unresolvable task name, or a
            task id outside the checkpoint's task-embedding range.
        RuntimeError: if the dataset has no ``task_lookup`` or the model
            returns no attention.
        IndexError: on an out-of-range example index, layer, or head.
    """
    args = parse_args()
    # Falls back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Reject pixels outside the canvas up front. An out-of-range or negative
    # row/col would otherwise map silently onto a different patch (wrapping to
    # the next patch row, or via negative indexing) instead of failing.
    for flag, value in (("row", args.row), ("col", args.col)):
        if not 0 <= value < args.image_size:
            raise ValueError(
                f"--{flag}={value} 超出范围，应满足 0 <= {flag} < {args.image_size}。"
            )

    # ------------- 1. Load the split and resolve the task id -------------
    print("Start loading training data (for num_tasks)...")
    train_dataset, *_ = build_dataloaders(
        _make_data_args(args), distributed=False, rank=0, world_size=1
    )
    print("Finish loading training data.")
    num_tasks = train_dataset.num_tasks

    task_lookup = getattr(train_dataset, "task_lookup", None)
    if task_lookup is None:
        raise RuntimeError("train_dataset 没有 task_lookup，无法推断 task_id。")

    # Print the lookup contents to help debug task-name mismatches.
    print(f"task_lookup keys: {list(task_lookup.keys())}")
    task_id = _resolve_task_id(task_lookup, args.task_name)

    # ------------- 2. Load the eval TTT split for the selected task -------------
    print(f"Start loading eval TTT data for task {args.task_name} ...")
    eval_train_dataset, *_ = build_dataloaders(
        _make_data_args(args), distributed=False, rank=0, world_size=1
    )
    print("Finish loading eval TTT data.")
    if args.example_index >= len(eval_train_dataset):
        raise IndexError(
            f"该任务的数据集只有 {len(eval_train_dataset)} 个样本，"
            f"example_index={args.example_index} 越界。"
        )

    sample = eval_train_dataset[args.example_index]

    # ------------- 3. Load the model -------------
    # The checkpoint was trained with 400 task embeddings (determined from an
    # earlier size-mismatch error). That differs from the task count of the
    # split loaded above, so it is hard-coded here.
    num_tasks_in_ckpt = 400
    print(f"num_tasks: dataset={num_tasks}, checkpoint={num_tasks_in_ckpt}")
    # NOTE(review): task_id comes from the TTT split's task_lookup, not from
    # the offline training set the checkpoint was trained on. Confirm the two
    # id spaces agree; otherwise an unrelated task embedding is being used.
    if not 0 <= task_id < num_tasks_in_ckpt:
        raise ValueError(
            f"task_id={task_id} 超出检查点的任务数量范围 [0, {num_tasks_in_ckpt})。"
        )

    model = load_offline_vit_with_attn(
        ckpt_path=args.ckpt,
        num_tasks=num_tasks_in_ckpt,
        image_size=args.image_size,
        num_colors=args.num_colors,
        embed_dim=args.embed_dim,
        depth=args.depth,
        num_heads=args.num_heads,
        mlp_dim=args.mlp_dim,
        num_task_tokens=args.num_task_tokens,
        patch_size=args.patch_size,
        # Pass the full device string (e.g. "cuda:1") rather than only
        # `device.type`, so the model ends up on the same device the inputs
        # are moved to below.
        device=str(device),
    )

    # ------------- 4. Prepare inputs and run a forward pass with attention -------------
    inputs = sample["inputs"].unsqueeze(0).to(device)  # presumably (1, H, W)
    attention_mask = sample["attention_mask"].unsqueeze(0).to(device)  # presumably (1, H, W)
    # Every example in this split belongs to the same task.
    task_ids = torch.full(
        (1,),
        fill_value=task_id,
        dtype=torch.long,
        device=device,
    )  # (1,)

    with torch.no_grad():
        logits, (all_scores, all_attn) = model(
            inputs,
            task_ids,
            attention_mask=attention_mask,
            return_attn=True,
        )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if all_attn is None:
        raise RuntimeError("模型没有返回 attention（all_attn 为 None）。")

    # ------------- 5. Report statistics for the selected pixel's query row -------------
    pw = args.image_size // args.patch_size  # patches per canvas row
    patch_row = args.row // args.patch_size
    patch_col = args.col // args.patch_size
    query_index = patch_row * pw + patch_col
    # Task tokens precede the patch tokens in the sequence.
    seq_query_index = args.num_task_tokens + query_index

    print(
        f"选中像素 (row={args.row}, col={args.col}) -> "
        f"patch ({patch_row}, {patch_col}) -> "
        f"query token index={seq_query_index} (含 task token 偏移)"
    )

    layer_id = args.layer
    head_id = args.head
    # Negative indices are still accepted (Python-style, e.g. -1 = last layer).
    num_layers = len(all_attn)
    if not -num_layers <= layer_id < num_layers:
        raise IndexError(f"--layer={layer_id} 越界，模型共有 {num_layers} 层。")

    # all_attn[layer_id]: (B, H, N, N)
    attn_layer_all_heads = all_attn[layer_id][0]  # (H, N, N)
    num_heads = attn_layer_all_heads.shape[0]
    if not -num_heads <= head_id < num_heads:
        raise IndexError(f"--head={head_id} 越界，该层共有 {num_heads} 个 head。")

    row_attn_all_heads = attn_layer_all_heads[:, seq_query_index, :]  # (H, N)
    row_attn = row_attn_all_heads.mean(dim=0)  # (N,)
    # Renormalize so the row sums to 1; the epsilon guards an all-zero row.
    row_attn = row_attn / (row_attn.sum() + 1e-9)

    max_val, max_idx = torch.max(row_attn, dim=0)
    print(
        f"Layer {layer_id}, heads-avg, 该 query 的最大权重={max_val.item():.6f}, "
        f"对应 token index={max_idx.item()}"
    )

    # Same statistic for the head chosen with --head.
    head_attn = row_attn_all_heads[head_id]
    head_attn = head_attn / (head_attn.sum() + 1e-9)
    head_max_val, head_max_idx = torch.max(head_attn, dim=0)
    print(
        f"Layer {layer_id}, head {head_id}, 该 query 的最大权重={head_max_val.item():.6f}, "
        f"对应 token index={head_max_idx.item()}"
    )

    self_weight = row_attn[seq_query_index].item()
    print(f"该 query 对自身的 normalized 权重={self_weight:.6f}")
    print(f"Row attention values (heads-avg): {row_attn}")

    if max_idx.item() >= args.num_task_tokens:
        # Map the patch token back to the top-left pixel of its patch.
        pix_idx = max_idx.item() - args.num_task_tokens
        pix_row = (pix_idx // pw) * args.patch_size
        pix_col = (pix_idx % pw) * args.patch_size
        print(f"最大权重对应像素大致在 (row≈{pix_row}, col≈{pix_col})")
    else:
        print(f"最大权重落在 task token (index={max_idx.item()}) 上")

    # ------------- 6. Save the dump for visualize_attention -------------
    dump = {
        "all_scores": [s.cpu() for s in all_scores],  # pre-softmax logits
        "all_attn": [a.cpu() for a in all_attn],  # List[Tensor[B, H, N, N]]
        "inputs": inputs.cpu(),
        "task_ids": task_ids.cpu(),
        "attention_mask": attention_mask.cpu(),
        "logits": logits.cpu(),
        "meta": {
            "task_name": args.task_name,
            "example_index": args.example_index,
            "row": args.row,
            "col": args.col,
            "image_size": args.image_size,
            "patch_size": args.patch_size,
            "num_task_tokens": args.num_task_tokens,
            "layer": args.layer,
            "head": args.head,
        },
    }
    out_path = Path(args.output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(dump, out_path)
    print(f"Attention 已保存到 {out_path}")


# Script entry point: run the attention analysis only when this file is
# executed directly (e.g. `python analyze_attention.py --ckpt ... --task-name ...`),
# not when it is imported as a module.
if __name__ == "__main__":
    main()