# visualize_attention_layers_grid.py
import argparse
from pathlib import Path

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F


def parse_args():
    """Parse the command-line options of the attention-grid visualizer.

    Returns:
        argparse.Namespace with the attributes ``dump_path`` (required),
        ``head``, ``out_path``, ``alpha`` and ``only_heatmap``.
    """
    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps the same layout.
    option_specs = (
        ("--dump-path", {"type": str, "required": True,
                         "help": "analyze_attention.py 保存的 *.pt 文件路径"}),
        ("--head", {"type": int, "default": 0,
                    "help": "使用第几个 head（0-based）"}),
        ("--out-path", {"type": str, "default": "attn_layers_grid.png",
                        "help": "输出 PNG 路径"}),
        ("--alpha", {"type": float, "default": 0.6,
                     "help": "heatmap 覆盖原图的透明度"}),
        ("--only-heatmap", {"action": "store_true",
                            "help": "只画 heatmap，不叠原图"}),
    )

    parser = argparse.ArgumentParser(
        prog="Visualize per-pixel attention across layers (grid)"
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def get_valid_range(mask):
    """
    Return the bounding box of the valid (non-zero) region of a 2-D mask.

    Args:
        mask: 2-D ``torch.Tensor`` or array-like of shape (H, W). Non-zero /
            True entries mark valid cells; zero / False entries mark padding.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` as plain Python ints with exclusive upper
        bounds, so ``img[rmin:rmax, cmin:cmax]`` crops to the valid region.
        If the mask has no non-zero entry, the full extent ``(0, H, 0, W)``
        is returned instead.
    """
    if isinstance(mask, torch.Tensor):
        # detach() so a mask that still carries autograd state can be converted.
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    # Indices of the rows / columns that contain at least one valid cell.
    valid_rows = np.flatnonzero(mask.any(axis=1))
    valid_cols = np.flatnonzero(mask.any(axis=0))

    if valid_rows.size == 0 or valid_cols.size == 0:
        # Degenerate all-padding mask: fall back to the whole image.
        return 0, mask.shape[0], 0, mask.shape[1]

    # Convert numpy integers to built-in ints so the bounds behave like
    # ordinary slice indices wherever they are used or printed.
    return (
        int(valid_rows[0]),
        int(valid_rows[-1]) + 1,
        int(valid_cols[0]),
        int(valid_cols[-1]) + 1,
    )


def _query_marker(q_r, q_c):
    """Build an unfilled 1x1 cyan rectangle outlining the pixel at (q_r, q_c)."""
    # With imshow's default extent, pixel (r, c) is centred at data
    # coordinates (c, r), so its cell starts half a unit before the centre.
    # A new artist is created per call so each Axes gets its own patch.
    return patches.Rectangle(
        (q_c - 0.5, q_r - 0.5),
        1, 1,
        linewidth=1.5,
        edgecolor="cyan",
        facecolor="none",
    )


def main():
    """Render per-layer attention heatmaps for one query pixel and save them.

    Loads the dump given by ``--dump-path``. For every layer it takes the
    attention row of the query pixel's patch token, averages it over heads,
    bilinearly upsamples it to image resolution, crops it to the valid
    (unpadded) region and draws it on a 2-row grid saved to ``--out-path``.
    A second figure showing the cropped input and target is saved in the same
    directory as ``task_input_target.png``.

    NOTE: ``--head`` is parsed but not applied here; attention is always
    averaged over all heads.

    Raises:
        ValueError: if the query pixel does not map to a patch inside the
            patch grid.
        RuntimeError: if a layer has no heads, or its number of pixel tokens
            does not equal the patch-grid size.
    """
    args = parse_args()

    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load dump files that come from a trusted source.
    dump = torch.load(args.dump_path, map_location="cpu", weights_only=False)
    all_attn = dump["all_attn"]          # expected: List[Tensor[B, H, N, N]]
    inputs = dump["inputs"]              # expected: (1, H, W)
    targets = dump.get("targets")
    attention_mask = dump.get("attention_mask")
    meta = dump["meta"]

    image_size = meta["image_size"]
    patch_size = meta["patch_size"]
    num_task_tokens = meta["num_task_tokens"]
    row = meta["row"]
    col = meta["col"]

    # Bounding box of the unpadded region; without a mask use the full image.
    if attention_mask is not None:
        rmin, rmax, cmin, cmax = get_valid_range(attention_mask[0])
    else:
        rmin, rmax, cmin, cmax = 0, image_size, 0, image_size
    print(f"有效区域: rows [{rmin}:{rmax}], cols [{cmin}:{cmax}]")

    # Query pixel position relative to the cropped images.
    q_r = row - rmin
    q_c = col - cmin

    # Map the query pixel to its patch token in the attention sequence
    # (task tokens come first, then the row-major patch grid).
    ph = pw = image_size // patch_size
    patch_row = row // patch_size
    patch_col = col // patch_size
    if not (0 <= patch_row < ph and 0 <= patch_col < pw):
        # Without this check an out-of-grid pixel would silently index a
        # different token (or fail later with an opaque IndexError).
        raise ValueError(
            f"像素 (row={row}, col={col}) 超出 patch 网格范围 "
            f"({ph}x{pw}, patch_size={patch_size})"
        )
    seq_query_index = num_task_tokens + patch_row * pw + patch_col

    print(
        f"像素 (row={row}, col={col}) -> patch ({patch_row}, {patch_col}) -> "
        f"query token index={seq_query_index}"
    )

    # Crop input / target to the valid region.
    input_img = inputs[0].numpy()
    input_img_cropped = input_img[rmin:rmax, cmin:cmax]
    if targets is not None:
        target_img = targets[0].numpy()
        target_img_cropped = target_img[rmin:rmax, cmin:cmax]
    else:
        target_img = target_img_cropped = None

    # Colorize the input with a categorical colormap; it is used both as the
    # underlay of the heatmaps and in the input/target figure.
    cmap_img = plt.get_cmap("tab20")
    base_rgb = cmap_img(input_img_cropped / max(input_img.max(), 1))[:, :, :3]

    # Grid layout: 2 rows, as many columns as needed (2 x 5 for depth 10).
    num_layers = len(all_attn)
    n_rows = 2
    n_cols = (num_layers + n_rows - 1) // n_rows

    # squeeze=False always yields a 2-D array of Axes.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    flat_axes = axes.ravel()  # row-major: layer i -> (i // n_cols, i % n_cols)

    for layer_id, attn in enumerate(all_attn):
        if attn.size(1) == 0:
            raise RuntimeError(f"[Layer {layer_id}] 没有 head，形状: {attn.shape}")

        # Attention row of the query token for the first batch element,
        # averaged over all heads -> (N,)
        row_attn = attn[0][:, seq_query_index, :].mean(dim=0)

        # Drop the task tokens; what remains must be exactly the patch grid.
        pix_attn = row_attn[num_task_tokens:]
        if pix_attn.numel() != ph * pw:
            raise RuntimeError(
                f"[Layer {layer_id}] 像素 token 数不匹配: pix_attn={pix_attn.numel()}, ph*pw={ph*pw}"
            )

        # Upsample the (ph, pw) patch map to pixel resolution.
        patch_heatmap = pix_attn.view(ph, pw).unsqueeze(0).unsqueeze(0)  # (1,1,ph,pw)
        heatmap = F.interpolate(
            patch_heatmap,
            size=(image_size, image_size),
            mode="bilinear",
            align_corners=False,
        )[0, 0]

        # Crop to the valid region, then normalize to [0, 1] within the crop.
        heatmap = heatmap[rmin:rmax, cmin:cmax]
        heatmap = heatmap.clamp(min=0)
        heatmap = heatmap / (heatmap.max() + 1e-9)
        heatmap = torch.sqrt(heatmap)  # stretch small and mid-range values
        heatmap_np = heatmap.numpy()

        ax = flat_axes[layer_id]
        if not args.only_heatmap:
            ax.imshow(base_rgb, interpolation="nearest")
            ax.imshow(
                heatmap_np,
                cmap="rainbow",
                alpha=args.alpha,
                interpolation="nearest",
            )
        else:
            ax.imshow(
                heatmap_np,
                cmap="rainbow",
                interpolation="nearest",
            )

        ax.add_patch(_query_marker(q_r, q_c))
        ax.set_title(f"Layer {layer_id}")
        ax.axis("off")

    # Hide the leftover subplots when num_layers < n_rows * n_cols.
    for ax in flat_axes[num_layers:]:
        ax.axis("off")

    task_name = meta.get("task_name", "")
    ex_idx = meta.get("example_index", 0)
    fig.suptitle(
        f"Task {task_name}, ex {ex_idx}, heads-avg, pixel ({row},{col})",
        fontsize=12,
    )

    out_path = Path(args.out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"保存层间对比热力图到 {out_path}")

    # Second figure: cropped input vs. target, each with the query pixel marked.
    fig2, axes2 = plt.subplots(1, 2, figsize=(6, 3))

    axes2[0].imshow(base_rgb, interpolation="nearest")
    axes2[0].set_title("Input (cropped)")
    axes2[0].axis("off")
    axes2[0].add_patch(_query_marker(q_r, q_c))

    if target_img_cropped is not None:
        img_norm_tgt = target_img_cropped / max(target_img.max(), 1)
        axes2[1].imshow(cmap_img(img_norm_tgt)[:, :, :3], interpolation="nearest")
        axes2[1].set_title("Target (cropped)")
        axes2[1].axis("off")
        axes2[1].add_patch(_query_marker(q_r, q_c))
    else:
        axes2[1].axis("off")

    fig2.suptitle(
        f"Task {task_name}, ex {ex_idx}, pixel ({row},{col})",
        fontsize=11,
    )

    # Saved in the same directory as out_path, which was created above.
    task_vis_path = out_path.with_name("task_input_target.png")
    fig2.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig2.savefig(task_vis_path, dpi=200)
    plt.close(fig2)
    print(f"保存 Task 可视化到 {task_vis_path}")


# Run the visualization only when executed as a script, not when imported.
if __name__ == "__main__":
    main()