# visualize_attention_layerwise_grid.py
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F


def parse_args():
    """Parse the command-line options of the layer-wise attention grid script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` with:

    * ``dump_path`` (str, required): path of the ``*.pt`` dump to visualize
      (per the help text, the file saved by ``analyze_attention.py``).
    * ``out_path`` (str): output PNG path, default ``attn_layerwise_grid.png``.
    * ``alpha`` (float): opacity of the heatmap drawn over the input image,
      default ``0.6``.
    * ``only_heatmap`` (bool): draw only the heatmap, without the input image
      underneath.
    """
    # The title must be given as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would replace the script
    # name in the "usage:" line with this sentence.
    parser = argparse.ArgumentParser(
        description="Visualize layer-wise averaged attention maps (grid)",
    )
    parser.add_argument(
        "--dump-path",
        type=str,
        required=True,
        help="analyze_attention.py 保存的 *.pt 文件路径",
    )
    parser.add_argument(
        "--out-path",
        type=str,
        default="attn_layerwise_grid.png",
        help="输出 PNG 路径",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.6,
        help="heatmap 覆盖原图的透明度",
    )
    parser.add_argument(
        "--only-heatmap",
        action="store_true",
        help="只画 heatmap，不叠原图",
    )
    return parser.parse_args()

def get_valid_range(mask):
    """Return the bounding box of the non-zero region of a 2-D mask.

    Non-zero (True) entries are treated as valid, zero (False) entries as
    padding.

    Args:
        mask: 2-D ``torch.Tensor`` or array-like, indexed as (row, column).

    Returns:
        ``(rmin, rmax, cmin, cmax)`` as plain ints. The bounds are half-open,
        so ``x[rmin:rmax, cmin:cmax]`` selects the smallest rectangle that
        contains every non-zero entry. If the mask has no non-zero entry the
        full extent ``(0, H, 0, W)`` is returned instead.
    """
    if isinstance(mask, torch.Tensor):
        # detach() so a tensor that tracks gradients can still be converted.
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)

    # Indices of the rows / columns that contain at least one non-zero entry.
    row_hits = np.flatnonzero(np.any(mask, axis=1))
    col_hits = np.flatnonzero(np.any(mask, axis=0))

    if row_hits.size == 0 or col_hits.size == 0:
        # All-zero mask (unexpected): fall back to the whole image.
        return 0, mask.shape[0], 0, mask.shape[1]

    # Convert to built-in ints so callers get ordinary slice bounds.
    return (
        int(row_hits[0]),
        int(row_hits[-1]) + 1,
        int(col_hits[0]),
        int(col_hits[-1]) + 1,
    )


def _layer_heatmap(attn, num_task_tokens, grid_hw, image_size, crop, layer_id):
    """Turn one layer's attention tensor into a cropped, normalized heatmap.

    Args:
        attn: 4-D attention tensor; the last two axes are (query, key) tokens
            and the first ``num_task_tokens`` tokens on each are skipped.
        num_task_tokens: number of leading non-pixel tokens to drop.
        grid_hw: ``(ph, pw)`` patch grid; ``ph * pw`` pixel tokens are expected.
        image_size: side length the patch map is upsampled to.
        crop: ``(rmin, rmax, cmin, cmax)`` half-open bounds applied after
            upsampling.
        layer_id: layer index, used only in the error message.

    Returns:
        2-D ``numpy`` array with values in ``[0, 1]``.

    Raises:
        RuntimeError: if the pixel-token sub-matrix is not
            ``(ph * pw, ph * pw)``.
    """
    ph, pw = grid_hw
    n_pix = ph * pw
    rmin, rmax, cmin, cmax = crop

    # Keep only the pixel-token <-> pixel-token sub-matrix.
    attn_pix = attn[:, :, num_task_tokens:, num_task_tokens:]
    if attn_pix.size(-1) != n_pix or attn_pix.size(-2) != n_pix:
        raise RuntimeError(
            f"[Layer {layer_id}] 像素 token 子矩阵尺寸不匹配: "
            f"got {(attn_pix.size(-2), attn_pix.size(-1))}, expected {(n_pix, n_pix)}"
        )

    # Half / bfloat16 dumps cannot go through .numpy() (bfloat16) and may not
    # be supported by CPU interpolation, so promote them to float32 first.
    if attn_pix.dtype not in (torch.float32, torch.float64):
        attn_pix = attn_pix.float()

    # Average over (batch, head, query): mean attention received by each key
    # pixel token, renormalized to sum to 1.
    key_weight = attn_pix.mean(dim=(0, 1, 2))
    key_weight = key_weight / (key_weight.sum() + 1e-9)

    # Upsample the (ph, pw) patch map to image resolution, then crop.
    patch_map = key_weight.reshape(1, 1, ph, pw)
    heatmap = F.interpolate(
        patch_map,
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    heatmap = heatmap[rmin:rmax, cmin:cmax]

    # Rescale to [0, 1]; the square root lifts small values for display.
    heatmap = heatmap.clamp(min=0)
    heatmap = heatmap / (heatmap.max() + 1e-9)
    return torch.sqrt(heatmap).detach().numpy()


def main():
    """Render a grid of per-layer averaged attention heatmaps to a PNG.

    Loads the dump given by ``--dump-path``, which must contain the keys
    ``all_attn`` (one 4-D attention tensor per layer), ``inputs``, ``meta``
    (with ``image_size``, ``patch_size``, ``num_task_tokens``) and optionally
    ``attention_mask``. Each layer's attention is averaged over batch, heads
    and queries, upsampled to image size, cropped to the valid region of the
    mask, and drawn in a 2-row grid, optionally over the input image.

    Raises:
        RuntimeError: if the dump has no layers, or a layer's pixel-token
            sub-matrix does not match the patch grid.
    """
    args = parse_args()
    # NOTE(security): torch.load unpickles the file; only open trusted dumps.
    dump = torch.load(args.dump_path, map_location="cpu")

    all_attn = dump["all_attn"]          # one (B, heads, N, N) tensor per layer
    inputs = dump["inputs"]              # inputs[0] is used as a 2-D image
    attention_mask = dump.get("attention_mask")
    meta = dump["meta"]

    image_size = meta["image_size"]
    patch_size = meta["patch_size"]
    num_task_tokens = meta["num_task_tokens"]

    num_layers = len(all_attn)
    if num_layers == 0:
        # Otherwise the grid would have zero columns and matplotlib would fail
        # with an unrelated-looking error.
        raise RuntimeError("dump['all_attn'] is empty: no layers to visualize")

    # Valid (non-padding) region; attention_mask[0] is the 2-D mask.
    if attention_mask is not None:
        rmin, rmax, cmin, cmax = get_valid_range(attention_mask[0])
    else:
        # No mask stored: use the whole image.
        rmin, rmax, cmin, cmax = 0, image_size, 0, image_size

    print(f"有效区域: rows [{rmin}:{rmax}], cols [{cmin}:{cmax}]")

    ph = image_size // patch_size
    pw = image_size // patch_size

    # Input image -> RGB through a categorical colormap, cropped to the valid
    # region. Normalization uses the max of the uncropped image.
    input_img = inputs[0].numpy()
    input_img_cropped = input_img[rmin:rmax, cmin:cmax]
    cmap_img = plt.get_cmap("tab20")
    img_norm = input_img_cropped / max(input_img.max(), 1)
    base_rgb = cmap_img(img_norm)[:, :, :3]

    # Grid layout: 2 rows, as many columns as needed (2 x 5 for depth 10).
    n_rows = 2
    n_cols = (num_layers + n_rows - 1) // n_rows

    # squeeze=False always yields a 2-D (n_rows, n_cols) array of axes.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    flat_axes = axes.ravel()  # row-major: index == layer_id

    # With an underlay the heatmap is drawn semi-transparent; alone it is opaque.
    heat_alpha = None if args.only_heatmap else args.alpha

    for layer_id, attn in enumerate(all_attn):
        heatmap_np = _layer_heatmap(
            attn,
            num_task_tokens,
            (ph, pw),
            image_size,
            (rmin, rmax, cmin, cmax),
            layer_id,
        )

        ax = flat_axes[layer_id]
        if not args.only_heatmap:
            ax.imshow(base_rgb, interpolation="nearest")
        ax.imshow(
            heatmap_np,
            cmap="hot",
            alpha=heat_alpha,
            interpolation="nearest",
        )
        ax.set_title(f"Layer {layer_id}")
        ax.axis("off")

    # Hide the unused subplots when num_layers does not fill the grid.
    for ax in flat_axes[num_layers:]:
        ax.axis("off")

    task_name = meta.get("task_name", "")
    ex_idx = meta.get("example_index", 0)
    fig.suptitle(
        f"Layer-wise avg attention (heads-avg) | Task {task_name}, ex {ex_idx}",
        fontsize=12,
    )

    out_path = Path(args.out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"保存 layer-wise 平均注意力 grid 图到 {out_path}")


# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()