# visualize_attention_layerwise.py
import argparse
from pathlib import Path

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def parse_args():
    """Parse command-line options for the layer-wise attention visualizer.

    Returns:
        argparse.Namespace with ``dump_path`` (str), ``layer`` (int, 0-based),
        ``out_path`` (str), ``alpha`` (float) and ``only_heatmap`` (bool).
    """
    # Pass the text as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put this sentence in the usage
    # line in place of the program name.
    p = argparse.ArgumentParser(
        description="Visualize layer-wise averaged attention maps",
    )
    p.add_argument("--dump-path", type=str, required=True,
                   help="analyze_attention.py 保存的 *.pt 文件路径")
    p.add_argument("--layer", type=int, default=0,
                   help="可视化第几层（0-based）")
    p.add_argument("--out-path", type=str, default="attn_layerwise.png",
                   help="输出 PNG 路径")
    p.add_argument("--alpha", type=float, default=0.6,
                   help="heatmap 覆盖原图的透明度")
    p.add_argument("--only-heatmap", action="store_true",
                   help="只画 heatmap，不叠原图")
    return p.parse_args()


def _compute_layer_heatmap(attn, image_size, patch_size, num_task_tokens, layer_id):
    """Collapse one layer's attention tensor into an image-resolution heatmap.

    Only the pixel-token -> pixel-token sub-matrix is used. It is averaged
    over batch, heads and queries, which gives the mean attention each key
    pixel receives. That patch grid is renormalised to sum to 1, bilinearly
    upsampled to ``(image_size, image_size)``, scaled so its maximum is 1 and
    square-rooted so that weak responses stay visible.

    Args:
        attn: attention weights of shape ``(B, H, N, N)`` whose first
            ``num_task_tokens`` tokens are not pixel tokens.
        image_size: side length of the square image, in pixels.
        patch_size: side length of one square patch, in pixels.
        num_task_tokens: number of leading non-pixel tokens to skip.
        layer_id: layer index; only used in error messages.

    Returns:
        A float32 numpy array of shape ``(image_size, image_size)`` with
        values in ``[0, 1]``.

    Raises:
        ValueError: if ``attn`` is not 4-D.
        RuntimeError: if the pixel sub-matrix is not ``(ph * pw, ph * pw)``
            with ``ph = pw = image_size // patch_size``.
    """
    if attn.dim() != 4:
        raise ValueError(
            f"[Layer {layer_id}] expected attention of shape (B, H, N, N), "
            f"got {tuple(attn.shape)}"
        )
    # Upcast half-precision dumps: averaging many small probabilities in
    # fp16/bf16 loses precision, and Tensor.numpy() rejects bfloat16.
    # detach() keeps .numpy() working if the tensor was saved with grad.
    attn = attn.detach().float()

    ph = pw = image_size // patch_size
    n_pix = ph * pw

    # Keep only pixel-token <-> pixel-token attention.
    attn_pix = attn[:, :, num_task_tokens:, num_task_tokens:]  # (B, H, N_pix, N_pix)
    if attn_pix.size(-1) != n_pix or attn_pix.size(-2) != n_pix:
        raise RuntimeError(
            f"[Layer {layer_id}] 像素 token 子矩阵尺寸不匹配: "
            f"got {(attn_pix.size(-2), attn_pix.size(-1))}, expected {(n_pix, n_pix)}"
        )

    # Each softmax row is a distribution over all keys. Once the task-token
    # columns are dropped, the rows may no longer sum to 1, so the averaged
    # map is renormalised into a distribution over pixel keys.
    layer_map_1d = attn_pix.mean(dim=(0, 1, 2))  # (N_pix,)
    layer_map_1d = layer_map_1d / (layer_map_1d.sum() + 1e-9)

    patch_map = layer_map_1d.view(1, 1, ph, pw)
    heatmap = F.interpolate(
        patch_map,
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    heatmap = heatmap.clamp(min=0)
    heatmap = heatmap / (heatmap.max() + 1e-9)
    # sqrt compresses the dynamic range so weak-but-nonzero attention shows up.
    return torch.sqrt(heatmap).numpy()


def _input_to_rgb(input_img):
    """Colour an input grid with matplotlib's ``tab20`` colormap, dropping alpha.

    The grid is divided by ``max(input_img.max(), 1)`` before the colormap
    lookup, so the colour assigned to a value depends on the grid's maximum.
    Assumes a 2-D ``(H, W)`` array, matching the ``(1, H, W)`` dump layout;
    the result then has shape ``(H, W, 3)``.
    """
    img_norm = input_img / max(input_img.max(), 1)
    return plt.get_cmap("tab20")(img_norm)[:, :, :3]


def main():
    """Plot the averaged attention of one layer and save it as a PNG.

    Loads the ``*.pt`` dump produced by analyze_attention.py. The dump is a
    dict with ``all_attn``, ``inputs`` and ``meta`` entries. The function
    builds the heatmap for ``--layer`` and overlays it on the colour-mapped
    input grid, or draws it alone when ``--only-heatmap`` is given. The
    figure is written to ``--out-path``, creating parent directories as
    needed.

    Raises:
        IndexError: if ``--layer`` is outside ``[0, len(all_attn))``.
    """
    args = parse_args()
    # torch.load deserialises via pickle: only open dumps from a trusted source.
    dump = torch.load(args.dump_path, map_location="cpu")

    all_attn = dump["all_attn"]          # List[Tensor[B, H, N, N]]
    inputs = dump["inputs"]              # (1, H, W)
    meta = dump["meta"]

    image_size = meta["image_size"]
    patch_size = meta["patch_size"]
    num_task_tokens = meta["num_task_tokens"]

    layer_id = args.layer
    if layer_id < 0 or layer_id >= len(all_attn):
        raise IndexError(f"layer={layer_id} 越界，all_attn 只有 {len(all_attn)} 层")

    heatmap_np = _compute_layer_heatmap(
        all_attn[layer_id], image_size, patch_size, num_task_tokens, layer_id
    )

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        if args.only_heatmap:
            ax.imshow(heatmap_np, cmap="hot", interpolation="nearest")
        else:
            ax.imshow(_input_to_rgb(inputs[0].numpy()), interpolation="nearest")
            ax.imshow(
                heatmap_np,
                cmap="hot",
                alpha=args.alpha,
                interpolation="nearest",
            )

        task_name = meta.get("task_name", "")
        ex_idx = meta.get("example_index", 0)
        ax.set_title(
            f"Layer-wise avg attention (heads-avg) | Task {task_name}, ex {ex_idx}, layer {layer_id}"
        )
        ax.axis("off")

        out_path = Path(args.out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=200)
    finally:
        # Release the figure even when drawing or saving raises.
        plt.close(fig)
    print(f"保存 layer-wise 平均注意力热力图到 {out_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing does not trigger it.
    main()