# visualize_attention_logits.py
import argparse
from pathlib import Path

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F


def parse_args():
    """Parse the command-line options of the attention-logits visualizer.

    Returns:
        argparse.Namespace with ``dump_path``, ``layer``, ``head``,
        ``out_path``, ``alpha`` and ``only_heatmap``.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in the usage line), not the description, so pass the
    # description by keyword.
    p = argparse.ArgumentParser(
        description="Visualize pre-softmax attention logits heatmaps"
    )
    p.add_argument(
        "--dump-path",
        type=str,
        required=True,
        help="file loaded with torch.load, holding 'all_scores', 'inputs' and 'meta'",
    )
    p.add_argument("--layer", type=int, default=0, help="layer index to plot")
    p.add_argument(
        "--head",
        type=int,
        default=0,
        help="head index (currently ignored: scores are averaged over all heads)",
    )
    p.add_argument(
        "--out-path",
        type=str,
        default="attn_logits_heatmap.png",
        help="path of the output image",
    )
    p.add_argument(
        "--alpha",
        type=float,
        default=0.6,
        help="heatmap opacity when drawn over the input",
    )
    p.add_argument(
        "--only-heatmap",
        action="store_true",
        help="draw the heatmap alone, without the input underneath",
    )
    return p.parse_args()


def _pixel_to_query_token(row, col, image_size, patch_size, num_task_tokens):
    """Map a pixel coordinate to its patch and to its query token index.

    Patch tokens are laid out row-major over a square grid of
    ``image_size // patch_size`` patches per side, and they follow
    ``num_task_tokens`` leading tokens in the sequence.

    Returns:
        ``(patch_row, patch_col, seq_query_index)``.

    Raises:
        IndexError: if the pixel falls outside the patch grid. Without this
            check an out-of-range column would silently wrap onto the next
            patch row, and a negative coordinate would index from the end.
    """
    grid = image_size // patch_size  # patches per side
    patch_row = row // patch_size
    patch_col = col // patch_size
    if row < 0 or col < 0 or patch_row >= grid or patch_col >= grid:
        raise IndexError(
            f"像素 (row={row}, col={col}) 越界，patch 网格只有 {grid}x{grid}"
        )
    return patch_row, patch_col, num_task_tokens + patch_row * grid + patch_col


def _logits_to_heatmap(pix_scores, grid, image_size):
    """Turn per-patch logits into a min-max normalized pixel-resolution heatmap.

    The logits are standardized (mean-centered, divided by their std),
    squashed with ``tanh`` to tame outliers, bilinearly upsampled to
    ``image_size`` x ``image_size`` and min-max normalized to ``[0, 1]``.

    Args:
        pix_scores: 1-D tensor of ``grid * grid`` logits in row-major patch order.
        grid: number of patches per side.
        image_size: side length of the output in pixels.

    Returns:
        float32 tensor of shape ``(image_size, image_size)``.
    """
    patch_map = pix_scores.detach().float().reshape(grid, grid)
    # Center before scaling: tanh(x / std) without centering saturates to ~1
    # everywhere when the logits share a large offset, erasing all contrast.
    centered = patch_map - patch_map.mean()
    # Population std: 0 (not NaN, as with torch.std) for a single patch.
    scale = centered.pow(2).mean().sqrt()
    patch_map = torch.tanh(centered / (scale + 1e-6))

    heatmap = F.interpolate(
        patch_map[None, None],  # (1, 1, grid, grid): interpolate expects NCHW
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )[0, 0]

    heatmap = heatmap - heatmap.min()
    return heatmap / (heatmap.max() + 1e-9)


def main():
    """Plot the head-averaged pre-softmax attention logits of one query pixel.

    Loads the dump given by ``--dump-path`` (a dict with ``all_scores``,
    ``inputs`` and ``meta``) and selects layer ``--layer`` of batch item 0.
    Takes the logits row of the query token for pixel
    ``(meta["row"], meta["col"])``, averages it over all heads and keeps only
    the pixel (patch) tokens. The result is saved as a heatmap, optionally
    drawn over ``inputs[0]``, to ``--out-path``.

    Raises:
        IndexError: if ``--layer`` or the query pixel is out of range.
        RuntimeError: if the layer has no heads, or its pixel-token count
            does not match the patch grid.
    """
    args = parse_args()
    # NOTE: torch.load unpickles the file -- only load dumps you trust.
    dump = torch.load(args.dump_path, map_location="cpu")

    all_scores = dump["all_scores"]  # List[Tensor[B, H, N, N]], one per layer
    inputs = dump["inputs"]
    meta = dump["meta"]

    image_size = meta["image_size"]
    patch_size = meta["patch_size"]
    num_task_tokens = meta["num_task_tokens"]
    row = meta["row"]
    col = meta["col"]

    layer_id = args.layer
    # NOTE(review): --head is parsed but not used; scores are always averaged
    # over every head (as the plot title says). Wire it up if per-head plots
    # are needed.

    if layer_id < 0 or layer_id >= len(all_scores):
        raise IndexError(f"layer={layer_id} 越界，只有 {len(all_scores)} 层")

    scores = all_scores[layer_id]  # (B, H, N, N)
    if scores.size(1) == 0:
        raise RuntimeError(f"该层没有 head，形状: {scores.shape}")

    # Pixel -> patch -> query token index.
    grid = image_size // patch_size
    patch_row, patch_col, seq_query_index = _pixel_to_query_token(
        row, col, image_size, patch_size, num_task_tokens
    )
    print(
        f"像素 (row={row}, col={col}) -> patch ({patch_row}, {patch_col}) -> "
        f"query token index={seq_query_index}"
    )

    # Check the token layout before indexing, so a mismatch produces this
    # descriptive error instead of an opaque indexing failure.
    num_pix_tokens = scores.size(-1) - num_task_tokens
    if num_pix_tokens != grid * grid:
        raise RuntimeError(
            f"像素 token 数不匹配: pix_scores={num_pix_tokens}, ph*pw={grid*grid}"
        )

    # Query row of batch item 0, averaged over heads. Cast to float32 first:
    # half-precision means lose accuracy and .numpy() rejects bfloat16.
    row_scores = scores[0, :, seq_query_index, :].float().mean(dim=0)  # (N,)
    pix_scores = row_scores[num_task_tokens:]  # pixel tokens only, (grid*grid,)

    heatmap_np = _logits_to_heatmap(pix_scores, grid, image_size).numpy()

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        if args.only_heatmap:
            ax.imshow(heatmap_np, cmap="hot", interpolation="nearest")
        else:
            # Color the input with tab20 and overlay the heatmap on top.
            # NOTE(review): values are scaled by this input's own max, so the
            # same value may get different colors in different inputs --
            # confirm that is intended.
            input_img = inputs[0].numpy()
            img_norm = input_img / max(input_img.max(), 1)
            base_rgb = plt.get_cmap("tab20")(img_norm)[:, :, :3]
            ax.imshow(base_rgb, interpolation="nearest")
            ax.imshow(
                heatmap_np,
                cmap="hot",
                alpha=args.alpha,
                interpolation="nearest",
            )

        task_name = meta.get("task_name", "")
        ex_idx = meta.get("example_index", 0)
        ax.set_title(
            f"Logits attention (heads-avg) | Task {task_name}, ex {ex_idx}, "
            f"layer {layer_id}, pixel ({row},{col})"
        )
        ax.axis("off")

        out_path = Path(args.out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=200)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    print(f"保存 logits 注意力图到 {out_path}")


if __name__ == "__main__":
    # Run the visualizer only when executed as a script, not when imported.
    main()