# visualize_attention.py
import argparse
from pathlib import Path

import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the attention heatmap visualizer.

    Options are read from ``sys.argv``.

    Returns:
        Namespace with ``dump_path`` (required), ``layer``, ``head``,
        ``out_path``, ``alpha`` and ``only_heatmap``.
    """
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # must be passed as ``description`` or it replaces the program name in
    # the usage line.
    parser = argparse.ArgumentParser(
        description="Visualize VARC-ViT attention heatmaps",
    )
    parser.add_argument("--dump-path", type=str, required=True,
                        help="analyze_attention.py 保存的 *.pt 文件路径")
    parser.add_argument("--layer", type=int, default=0,
                        help="可视化第几层 attention（0-based）")
    parser.add_argument("--head", type=int, default=0,
                        help="可视化第几个 head（0-based）")
    parser.add_argument("--out-path", type=str, default="attn_heatmap.png",
                        help="输出 PNG 路径")
    parser.add_argument("--alpha", type=float, default=0.6,
                        help="heatmap 覆盖的透明度")
    parser.add_argument("--only-heatmap", action="store_true",
                        help="只画 attention heatmap，不叠加原图，方便检查是否铺满")
    return parser.parse_args()


def _select_head(layer_attn, head_id):
    """Pick the attention map of batch item 0 for one head.

    Args:
        layer_attn: attention tensor of one layer; expected to be
            (B, num_heads, N, N) according to the dump layout noted in
            ``main`` (not verified here beyond the head dimension).
        head_id: 0-based head index. A negative value averages over all
            heads instead of selecting one.

    Returns:
        Tuple ``(attn, label)``: the float32 map and the head label used in
        the plot title (``"mean"`` when averaging).

    Raises:
        RuntimeError: the layer has no heads.
        IndexError: ``head_id`` is not smaller than the number of heads.
    """
    num_heads = layer_attn.size(1)
    if num_heads == 0:
        raise RuntimeError(f"该层没有 head，形状: {layer_attn.shape}")
    if head_id >= num_heads:
        raise IndexError(f"head={head_id} 越界，该层只有 {num_heads} 个 head")

    # Detach and cast once: later steps call .numpy() and bilinear
    # interpolation, which can fail on grad-tracking or half-precision tensors.
    per_head = layer_attn[0].detach().float()
    if head_id < 0:
        return per_head.mean(dim=0), "mean"
    return per_head[head_id], str(head_id)


def _locate_query_token(row, col, image_size, patch_size, num_task_tokens):
    """Map a pixel coordinate to its patch and to its sequence token index.

    Returns:
        Tuple ``(patch_row, patch_col, seq_index)`` where ``seq_index`` is
        offset by ``num_task_tokens``.

    Raises:
        IndexError: the pixel falls outside the patch grid. Without this
            check a negative coordinate would wrap around via negative
            indexing and silently select a different token.
    """
    grid = image_size // patch_size
    patch_row = row // patch_size
    patch_col = col // patch_size
    if not (0 <= patch_row < grid and 0 <= patch_col < grid):
        raise IndexError(
            f"像素 (row={row}, col={col}) 超出 {grid}x{grid} 的 patch 网格范围"
        )
    return patch_row, patch_col, num_task_tokens + patch_row * grid + patch_col


def main():
    """Render the attention of one query pixel as a heatmap PNG.

    Loads the dump given by ``--dump-path``, takes the attention of the
    requested layer and head for batch item 0, extracts the row belonging to
    the query pixel stored in the dump metadata, upsamples the per-patch
    weights to image resolution and saves the figure to ``--out-path``.

    A negative ``--head`` averages over all heads of the layer.

    Raises:
        IndexError: layer, head, query pixel or query token out of range.
        RuntimeError: the layer has no heads, or the number of pixel tokens
            does not match the patch grid.
    """
    args = parse_args()
    # NOTE(security): torch.load unpickles the file; only open dumps from a
    # trusted source. Consider weights_only=True if the dump contains nothing
    # but tensors and plain Python containers.
    dump = torch.load(args.dump_path, map_location="cpu")

    all_attn = dump["all_attn"]          # List[Tensor[B, H, N, N]]
    inputs = dump["inputs"]              # (1, H, W)
    meta = dump["meta"]

    image_size = meta["image_size"]
    patch_size = meta["patch_size"]
    num_task_tokens = meta["num_task_tokens"]
    row = meta["row"]
    col = meta["col"]

    layer_id = args.layer
    head_id = args.head

    if layer_id < 0 or layer_id >= len(all_attn):
        raise IndexError(f"layer={layer_id} 越界，all_attn 只有 {len(all_attn)} 层")
    # Batch item 0 only, restricted to the requested head -> (N, N).
    attn, head_label = _select_head(all_attn[layer_id], head_id)

    # Patch grid; the image is treated as square.
    ph = pw = image_size // patch_size
    print(f"ph={ph}, pw={pw}, ph*pw={ph*pw}, seq_len(with task)={attn.shape[-1]}")

    patch_row, patch_col, seq_query_index = _locate_query_token(
        row, col, image_size, patch_size, num_task_tokens
    )
    if seq_query_index >= attn.shape[0]:
        raise IndexError(
            f"query token index={seq_query_index} 越界，序列长度为 {attn.shape[0]}"
        )

    print(
        f"像素 (row={row}, col={col}) -> patch ({patch_row}, {patch_col}) -> "
        f"query token index={seq_query_index}"
    )

    # Weights of this query over all keys, renormalised to sum to 1.
    row_attn = attn[seq_query_index]  # (N,)
    row_attn = row_attn / (row_attn.sum() + 1e-9)
    print(f"row_attn shape={row_attn.shape}, min={row_attn.min().item():.6g}, "
          f"max={row_attn.max().item():.6g}")

    # Drop the leading task tokens; only pixel-patch tokens remain.
    pix_attn = row_attn[num_task_tokens:]  # (N_pix,)
    if pix_attn.numel() != ph * pw:
        raise RuntimeError(
            f"像素 token 数不匹配: pix_attn={pix_attn.numel()}, ph*pw={ph*pw}"
        )

    # Lay the weights out on the patch grid as (1, 1, ph, pw) for interpolate.
    patch_heatmap = pix_attn.reshape(1, 1, ph, pw)
    print(f"patch_heatmap shape={patch_heatmap.shape}, "
          f"min={patch_heatmap.min().item():.6g}, "
          f"max={patch_heatmap.max().item():.6g}")

    # Upsample to the image resolution.
    heatmap = F.interpolate(
        patch_heatmap,
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )[0, 0]  # (H, W)
    heatmap = heatmap.clamp(min=0)
    heatmap = heatmap / (heatmap.max() + 1e-9)
    # Square-root stretch so that small and medium weights stay visible.
    heatmap = torch.sqrt(heatmap)
    heatmap_np = heatmap.numpy()

    print(f"Heatmap values: min={heatmap_np.min()}, max={heatmap_np.max()}")

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        if not args.only_heatmap:
            # Turn the discrete colour indices of the input into RGB through
            # the tab20 colormap, then overlay the translucent heatmap.
            input_img = inputs[0].detach().cpu().numpy()  # (H, W)
            img_norm = input_img / max(input_img.max(), 1)
            base_rgb = plt.get_cmap("tab20")(img_norm)[:, :, :3]  # (H, W, 3)
            ax.imshow(base_rgb, interpolation="nearest")
            ax.imshow(
                heatmap_np,
                cmap="hot",
                alpha=args.alpha,
                interpolation="nearest",
            )
        else:
            # Heatmap alone; it is already (H, W) and fills the whole axes.
            ax.imshow(
                heatmap_np,
                cmap="hot",
                interpolation="nearest",
            )
        ax.set_title(
            f"Task {meta.get('task_name', '')}, ex {meta.get('example_index', 0)}, "
            f"layer {layer_id}, head {head_label}, pixel ({row},{col})"
        )
        ax.axis("off")

        out_path = Path(args.out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=200)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"保存热力图到 {out_path}")

# Run the visualizer only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()