# visualize_attention_logits_grid.py
import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def parse_args(argv=None):
    """Parse the command-line options of the layer-grid visualization.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with ``dump_path``, ``head``, ``out_path``,
        ``alpha`` and ``only_heatmap``.
    """
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # below is a description, so pass it by keyword.
    p = argparse.ArgumentParser(
        description="Visualize per-pixel logits across layers (grid)"
    )
    p.add_argument(
        "--dump-path", type=str, required=True,
        help="file loaded with torch.load containing the attention dump",
    )
    p.add_argument(
        "--head", type=int, default=0,
        help="head index shown in the figure title (heads are averaged)",
    )
    p.add_argument(
        "--out-path", type=str, default="attn_logits_layers_grid.png",
        help="where to save the output figure",
    )
    p.add_argument(
        "--alpha", type=float, default=0.6,
        help="heatmap opacity when overlaid on the input image",
    )
    p.add_argument(
        "--only-heatmap", action="store_true",
        help="draw only the heatmap, without the input image underneath",
    )
    return p.parse_args(argv)

def get_valid_range(mask):
    """Return the bounding box of the valid region of a 2-D attention mask.

    Args:
        mask: (H, W) tensor or array-like. Non-zero / True marks valid pixels;
            zero / False marks padding.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` as plain ints, with exclusive upper
        bounds, so ``x[rmin:rmax, cmin:cmax]`` selects the valid region.
        If the mask contains no valid pixel, the full extent
        ``(0, H, 0, W)`` is returned.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    # Which rows / columns contain at least one valid pixel.
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if not rows.any() or not cols.any():
        # Degenerate all-padding mask: fall back to the whole image.
        return 0, mask.shape[0], 0, mask.shape[1]

    row_idx = np.flatnonzero(rows)
    col_idx = np.flatnonzero(cols)
    return (
        int(row_idx[0]),
        int(row_idx[-1]) + 1,
        int(col_idx[0]),
        int(col_idx[-1]) + 1,
    )


def _layer_heatmap(scores, seq_query_index, num_task_tokens, ph, pw,
                   image_size, crop, layer_id):
    """Convert one layer's attention logits into a cropped heatmap in [0, 1].

    Args:
        scores: attention logits of this layer, indexed as (B, H, N, N); only
            batch item 0 is used.
        seq_query_index: sequence index of the query token.
        num_task_tokens: number of leading tokens dropped before the pixel tokens.
        ph, pw: patch-grid height and width; the remaining tokens must number
            exactly ``ph * pw``.
        image_size: side length the patch map is bilinearly upsampled to.
        crop: ``(rmin, rmax, cmin, cmax)`` region kept after upsampling.
        layer_id: used only in the error message.

    Returns:
        float32 numpy array of the cropped region, min-max scaled to [0, 1].

    Raises:
        RuntimeError: if the number of pixel tokens differs from ``ph * pw``.
    """
    rmin, rmax, cmin, cmax = crop

    # Query row for every head, averaged over heads: (H, N) -> (N,).
    # Promote to float32 so half/bfloat16 dumps work with interpolate and .numpy().
    row_scores = scores[0][:, seq_query_index, :].detach().float().mean(dim=0)

    pix_scores = row_scores[num_task_tokens:]
    if pix_scores.numel() != ph * pw:
        raise RuntimeError(
            f"[Layer {layer_id}] 像素 token 数不匹配: "
            f"pix_scores={pix_scores.numel()}, ph*pw={ph*pw}"
        )

    patch_map = pix_scores.reshape(1, 1, ph, pw)
    # tanh of the std-scaled map compresses extreme logits (presumably so a few
    # outliers do not wash out the rest of the map).
    patch_map = torch.tanh(patch_map / (patch_map.std() + 1e-6))

    heatmap = F.interpolate(
        patch_map,
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    heatmap = heatmap[rmin:rmax, cmin:cmax]

    heatmap = heatmap - heatmap.min()
    heatmap = heatmap / (heatmap.max() + 1e-9)
    return heatmap.numpy()


def main():
    """Render a grid of per-layer attention heatmaps for one query pixel.

    Loads the dump given by ``--dump-path``. The dump must have the keys
    ``all_scores``, ``inputs`` and ``meta``, and may have ``attention_mask``.

    For each layer, the function:
      * averages the attention logits over heads for the query token holding
        pixel ``(meta["row"], meta["col"])``;
      * upsamples them to the image resolution;
      * crops them to the valid (unpadded) region.

    The maps are drawn in a 2-row subplot grid and saved to ``--out-path``.

    Raises:
        RuntimeError: if ``all_scores`` is empty or a layer's token count
            does not match the patch grid.
        ValueError: if the query pixel lies outside the patch grid.
    """
    args = parse_args()
    dump = torch.load(args.dump_path, map_location="cpu")

    all_scores = dump["all_scores"]  # List[Tensor[B, H, N, N]], one per layer
    inputs = dump["inputs"]
    attention_mask = dump.get("attention_mask")
    meta = dump["meta"]

    image_size = meta["image_size"]
    patch_size = meta["patch_size"]
    num_task_tokens = meta["num_task_tokens"]
    row = meta["row"]
    col = meta["col"]

    # Valid (non-padding) region; attention_mask is (1, H, W) -> use [0] as (H, W).
    if attention_mask is not None:
        rmin, rmax, cmin, cmax = get_valid_range(attention_mask[0])
    else:
        # No mask: treat the whole image as valid.
        rmin, rmax, cmin, cmax = 0, image_size, 0, image_size
    print(f"有效区域: rows [{rmin}:{rmax}], cols [{cmin}:{cmax}]")

    num_layers = len(all_scores)
    if num_layers == 0:
        raise RuntimeError("dump['all_scores'] is empty; nothing to plot")
    # Only shown in the title; every layer is averaged over all heads.
    head_id = args.head

    ph = image_size // patch_size
    pw = image_size // patch_size
    patch_row = row // patch_size
    patch_col = col // patch_size
    if not (0 <= patch_row < ph and 0 <= patch_col < pw):
        # Out-of-range (especially negative) indices would silently wrap
        # around in torch indexing and select the wrong token.
        raise ValueError(
            f"pixel (row={row}, col={col}) maps to patch ({patch_row}, {patch_col}), "
            f"outside the {ph}x{pw} patch grid"
        )
    query_index = patch_row * pw + patch_col
    seq_query_index = num_task_tokens + query_index  # pixel tokens follow task tokens

    print(
        f"像素 (row={row}, col={col}) -> patch ({patch_row}, {patch_col}) -> "
        f"query token index={seq_query_index}"
    )

    # Background: input image cropped to the valid region and colored with
    # tab20 (values presumably small integer ids -- TODO confirm).
    input_img = inputs[0].numpy()
    input_img_cropped = input_img[rmin:rmax, cmin:cmax]
    img_norm = input_img_cropped / max(input_img.max(), 1)
    base_rgb = plt.get_cmap("tab20")(img_norm)[:, :, :3]

    n_rows = 2
    n_cols = (num_layers + n_rows - 1) // n_rows
    # squeeze=False always yields a 2-D (n_rows, n_cols) array of Axes.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    crop = (rmin, rmax, cmin, cmax)

    try:
        for layer_id, scores in enumerate(all_scores):
            heatmap_np = _layer_heatmap(
                scores, seq_query_index, num_task_tokens, ph, pw,
                image_size, crop, layer_id,
            )

            r, c = divmod(layer_id, n_cols)
            ax = axes[r, c]
            if args.only_heatmap:
                ax.imshow(heatmap_np, cmap="rainbow", interpolation="nearest")
            else:
                ax.imshow(base_rgb, interpolation="nearest")
                ax.imshow(
                    heatmap_np,
                    cmap="rainbow",
                    alpha=args.alpha,
                    interpolation="nearest",
                )
            ax.set_title(f"Layer {layer_id}")
            ax.axis("off")

        # Hide the unused trailing cells of the grid.
        for ax in axes.ravel()[num_layers:]:
            ax.axis("off")

        task_name = meta.get("task_name", "")
        ex_idx = meta.get("example_index", 0)
        fig.suptitle(
            f"Logits across layers (heads-avg) | Task {task_name}, ex {ex_idx}, head {head_id}, pixel ({row},{col})",
            fontsize=12,
        )

        out_path = Path(args.out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure even if a layer fails validation.
        plt.close(fig)
    print(f"保存 logits 层间对比图到 {out_path}")

if __name__ == "__main__":
    # Run the visualization only when executed as a script, not on import.
    main()