# batch_attention_maps_loop.py
import argparse
from pathlib import Path
from typing import List, Optional, Tuple

import torch

from src.ARC_loader import build_dataloaders
from src.attn_hook_loop import load_loop_vit_with_attn

def parse_args():
    """Parse command-line options for batch attention-map generation.

    Returns:
        argparse.Namespace holding the checkpoint path, data/task-list/output
        locations, the LoopARCViT hyper-parameters used to rebuild the model,
        the device string and the per-task example limit.
    """
    parser = argparse.ArgumentParser("Batch generate attention maps for LoopARCViT")

    # (flag, type, extra keyword arguments) in the order they appear in --help.
    option_specs = [
        ("--ckpt", str, {"required": True}),
        ("--data-root", str, {"default": "raw_data/ARC-AGI"}),
        ("--task-list", str, {"required": True}),
        ("--out-root", str, {"default": "attention_map_loop"}),
        ("--image-size", int, {"default": 64}),
        ("--num-colors", int, {"default": 12}),
        ("--embed-dim", int, {"default": 512}),
        ("--mlp-dim", int, {"default": 512}),
        ("--loop-core-depth", int, {"default": 2}),
        ("--max-loop-steps", int, {"default": 6}),
        ("--min-loop-steps", int, {"default": 1}),
        ("--num-heads", int, {"default": 8}),
        ("--patch-size", int, {"default": 2}),
        ("--num-task-tokens", int, {"default": 1}),
        ("--device", str, {"default": "cuda"}),
        ("--max-examples-per-task", int, {"default": 1}),
    ]
    for flag, value_type, extra in option_specs:
        parser.add_argument(flag, type=value_type, **extra)

    return parser.parse_args()


def load_task_names(path: str) -> List[str]:
    """Read task names from a text file, one name per line.

    Surrounding whitespace is stripped and blank (or whitespace-only) lines are
    skipped.

    Args:
        path: path to the task-list file.

    Returns:
        Task names in file order.
    """
    # Explicit UTF-8 so the result does not depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def build_train_dataset(args):
    """Load the ARC training split so the model's task vocabulary can be rebuilt.

    ``build_dataloaders`` expects a training-script style args object, so a
    throwaway attribute container is populated with loader settings: the
    "training" split for both train and eval, batch size 1, no worker
    processes, no RE-ARC data, ``fix_scale_factor=2``, and translation /
    resolution augmentation not disabled.

    Args:
        args: parsed CLI namespace from ``parse_args``.

    Returns:
        ``(train_dataset, num_tasks, task_lookup)``; ``task_lookup`` is ``None``
        when the dataset has no ``task_lookup`` attribute.
    """

    class DummyArgs:
        """Plain attribute bag standing in for the training script's args."""

    loader_settings = {
        "data_root": args.data_root,
        "train_split": "training",
        "eval_split": "training",
        "eval_subset": "test",
        "image_size": args.image_size,
        "num_colors": args.num_colors,
        "batch_size": 1,
        "num_workers": 0,
        "include_rearc": False,
        "rearc_path": "raw_data/re_arc",
        "rearc_limit": -1,
        "patch_size": args.patch_size,
        "num_task_tokens": args.num_task_tokens,
        "fix_scale_factor": 2,
        "disable_translation": False,
        "disable_resolution_augmentation": False,
    }
    loader_args = DummyArgs()
    for setting_name, setting_value in loader_settings.items():
        setattr(loader_args, setting_name, setting_value)

    print("Loading training dataset to get num_tasks & task_lookup ...")
    # Only the first returned object (the training dataset) is needed here.
    dataset, *_unused = build_dataloaders(
        loader_args, distributed=False, rank=0, world_size=1
    )
    print("Done.")
    return dataset, dataset.num_tasks, getattr(dataset, "task_lookup", None)


def build_eval_dataset(args, task_name: str):
    """Load the per-task TTT dataset under ``eval_color_permute_ttt_9/<task_name>``.

    The same split path is used for both ``train_split`` and ``eval_split``;
    only the first object returned by ``build_dataloaders`` is kept.

    Args:
        args: parsed CLI namespace from ``parse_args``.
        task_name: task directory name under ``eval_color_permute_ttt_9``.

    Returns:
        The dataset object returned first by ``build_dataloaders``.
    """

    class DummyArgs:
        """Plain attribute bag standing in for the training script's args."""

    split_path = f"eval_color_permute_ttt_9/{task_name}"
    # NOTE(review): translation / resolution augmentation are not disabled here.
    # If build_dataloaders applies them randomly, the grids (and thus the selected
    # pixels) may differ between runs — confirm whether that is intended.
    loader_settings = {
        "data_root": args.data_root,
        "train_split": split_path,
        "eval_split": split_path,
        "eval_subset": "test",
        "image_size": args.image_size,
        "num_colors": args.num_colors,
        "batch_size": 1,
        "num_workers": 0,
        "include_rearc": False,
        "rearc_path": "raw_data/re_arc",
        "rearc_limit": -1,
        "patch_size": args.patch_size,
        "num_task_tokens": args.num_task_tokens,
        "fix_scale_factor": 2,
        "disable_translation": False,
        "disable_resolution_augmentation": False,
    }
    loader_args = DummyArgs()
    for setting_name, setting_value in loader_settings.items():
        setattr(loader_args, setting_name, setting_value)

    print(f"Loading eval TTT dataset for task {task_name} ...")
    dataset, *_unused = build_dataloaders(
        loader_args, distributed=False, rank=0, world_size=1
    )
    print(f"Done. #examples = {len(dataset)}")
    return dataset


def find_task_id(
    task_lookup: Optional[dict], task_name: str, fallback_id: int
) -> int:
    """Map a task name to the integer task id used by the model.

    Lookup order:
    1. exact key match in ``task_lookup``;
    2. the single string key that ends with ``task_name`` (e.g. a key that
       carries a split/path prefix);
    3. ``fallback_id``, announced with an informational message. This also
       applies when several keys end with ``task_name`` (ambiguous match).

    Args:
        task_lookup: name -> id mapping from the training dataset, or ``None``.
        task_name: task name to resolve.
        fallback_id: id returned when the name cannot be resolved uniquely.

    Returns:
        The resolved task id, or ``fallback_id``.
    """
    if task_lookup is not None:
        if task_name in task_lookup:
            return task_lookup[task_name]
        # Non-string keys cannot be suffix-matched; skip them instead of crashing.
        candidates = [
            k for k in task_lookup if isinstance(k, str) and k.endswith(task_name)
        ]
        if len(candidates) == 1:
            return task_lookup[candidates[0]]
        if len(candidates) > 1:
            print(
                f"[Info] Task {task_name} matches multiple task_lookup keys "
                f"{candidates}; using fallback task_id={fallback_id}."
            )
            return fallback_id
    print(
        f"[Info] Task {task_name} 不在训练集 task_lookup 中，使用 fallback task_id={fallback_id}。"
    )
    return fallback_id


def select_representative_pixels_from_gt(
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    max_pixels: int = 3,
    background_color: int = 0,
    num_grid: int = 3,
) -> List[Tuple[int, int]]:
    """Pick up to ``max_pixels`` representative (row, col) positions from the ground truth.

    The choice depends only on ``inputs`` / ``outputs``, never on model
    predictions. Given the same example, every model therefore gets the same
    query pixels.

    Strategy:
    1. Foreground GT pixels (``outputs != background_color``): split the grid
       into ``num_grid x num_grid`` coarse cells. Visit the cells in row-major
       order and take the middle foreground pixel (row-major) of each non-empty
       cell until ``max_pixels`` are collected. Because of the row-major visit,
       picks favour the upper cells when many cells contain foreground.
    2. If that yields fewer than ``max_pixels`` (e.g. an empty GT), top up with
       evenly strided non-background *input* pixels.
    3. If nothing was found at all, fall back to the grid centre.

    Args:
        inputs: input grid, either batched ``(B, H, W)`` (only item 0 is used)
            or a single ``(H, W)`` grid.
        outputs: ground-truth grid with the same layout as ``inputs``.
        max_pixels: maximum number of positions to return; ``<= 0`` returns ``[]``.
        background_color: colour id treated as background.
        num_grid: number of coarse cells per axis used in step 1 (min 1).

    Returns:
        Unique (row, col) tuples, at most ``max_pixels`` of them.
    """
    if max_pixels <= 0:
        return []

    inp = inputs[0] if inputs.dim() == 3 else inputs    # (H, W)
    tgt = outputs[0] if outputs.dim() == 3 else outputs  # (H, W)
    H, W = inp.shape
    num_grid = max(int(num_grid), 1)

    selected: List[Tuple[int, int]] = []

    # 1. Foreground GT pixels, at most one per coarse grid cell.
    fg_indices = torch.nonzero(tgt != background_color, as_tuple=False)  # (N, 2), row-major
    if fg_indices.shape[0] > 0:
        h_step = max(H // num_grid, 1)
        w_step = max(W // num_grid, 1)
        fg_rows, fg_cols = fg_indices[:, 0], fg_indices[:, 1]
        for gy in range(num_grid):
            r_start = gy * h_step
            # The last row/column of cells absorbs the remainder of the division.
            r_end = H if gy == num_grid - 1 else (gy + 1) * h_step
            in_rows = (fg_rows >= r_start) & (fg_rows < r_end)
            for gx in range(num_grid):
                if len(selected) >= max_pixels:
                    break
                c_start = gx * w_step
                c_end = W if gx == num_grid - 1 else (gx + 1) * w_step
                in_cell = fg_indices[in_rows & (fg_cols >= c_start) & (fg_cols < c_end)]
                if in_cell.shape[0] > 0:
                    mid = in_cell[in_cell.shape[0] // 2]
                    selected.append((int(mid[0]), int(mid[1])))
            if len(selected) >= max_pixels:
                break

    # 2. Top up from non-background input pixels (GT had too little foreground).
    if len(selected) < max_pixels:
        obj_indices = torch.nonzero(inp != background_color, as_tuple=False)
        n_obj = obj_indices.shape[0]
        if n_obj > 0:
            step = max(n_obj // (max_pixels - len(selected)), 1)
            for i in range(0, n_obj, step):
                if len(selected) >= max_pixels:
                    break
                rc = (int(obj_indices[i, 0]), int(obj_indices[i, 1]))
                if rc not in selected:
                    selected.append(rc)

    # 3. Last resort: the grid centre.
    if not selected:
        selected.append((H // 2, W // 2))

    # De-duplicate while keeping first-seen order, then cap the count.
    return list(dict.fromkeys(selected))[:max_pixels]


def save_dump(
    out_path: Path,
    task_name: str,
    ex_idx: int,
    row: int,
    col: int,
    args,
    all_scores,
    all_attn,
    inputs,
    targets,
    task_ids,
    attention_mask,
    logits,
):
    """Serialize one forward pass plus the queried pixel to ``out_path`` via ``torch.save``.

    Every tensor (and every element of ``all_scores`` / ``all_attn``) is moved
    to CPU first. A ``meta`` entry records the task name, example index, queried
    (row, col), and ``image_size`` / ``patch_size`` / ``num_task_tokens`` from
    ``args``. Missing parent directories are created.
    """
    payload = {
        "all_scores": [score.cpu() for score in all_scores],
        "all_attn": [attn.cpu() for attn in all_attn],
    }
    single_tensors = {
        "inputs": inputs,
        "targets": targets,
        "task_ids": task_ids,
        "attention_mask": attention_mask,
        "logits": logits,
    }
    payload.update({key: tensor.cpu() for key, tensor in single_tensors.items()})
    payload["meta"] = {
        "task_name": task_name,
        "example_index": ex_idx,
        "row": row,
        "col": col,
        "image_size": args.image_size,
        "patch_size": args.patch_size,
        "num_task_tokens": args.num_task_tokens,
    }

    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, out_path)
    print(f"Saved dump to {out_path}")


def run_visualizations_for_dump(dump_path: Path, out_dir: Path):
    """Render the three attention figures for one dump by running the helper scripts.

    Each ``visualize_attention_*.py`` script is launched as a child process with
    ``--dump-path`` / ``--out-path``. The script names are relative, so they
    are resolved against the current working directory. A non-zero exit raises
    ``subprocess.CalledProcessError`` (``check=True``).

    Args:
        dump_path: ``.pt`` file written by ``save_dump``.
        out_dir: directory that receives the PNG outputs.
    """
    import subprocess
    import sys

    # Use the interpreter running this script, not whatever "python" resolves to
    # on PATH (it may be missing, or belong to a different environment).
    python_exe = sys.executable or "python"

    # (progress label, visualization script, output file name)
    jobs = (
        (
            "Drawing per-pixel layers grid (softmax)",
            "visualize_attention_layers_grid.py",
            "pixel_layers_grid_AvgHead.png",
        ),
        (
            "Drawing layer-wise avg grid (softmax)",
            "visualize_attention_layerwise_grid.py",
            "layerwise_grid.png",
        ),
        (
            "Drawing logits layers grid",
            "visualize_attention_logits_grid.py",
            "attn_logits_layers_grid_AvgHead.png",
        ),
    )
    for label, script, out_name in jobs:
        print(f"{label} for {dump_path} ...")
        subprocess.run(
            [
                python_exe,
                script,
                "--dump-path",
                str(dump_path),
                "--out-path",
                str(out_dir / out_name),
            ],
            check=True,
        )


def _model_device(model, default: torch.device) -> torch.device:
    """Return the device of the model's first parameter, or ``default`` if none is found."""
    params = getattr(model, "parameters", None)
    if params is None:
        return default
    first = next(iter(params()), None)
    return default if first is None else first.device


def _forward_with_attn(model, sample, task_id: int, device: torch.device):
    """Run one dataset sample through the model with attention capture enabled.

    Returns:
        ``(inputs, attention_mask, task_ids, logits, all_scores, all_attn)``.
        ``inputs``, ``attention_mask`` and ``task_ids`` carry a leading batch
        dimension of 1 and live on ``device``.

    Raises:
        RuntimeError: if the model returns ``all_attn=None``.
    """
    inputs = sample["inputs"].unsqueeze(0).to(device)
    attention_mask = sample["attention_mask"].unsqueeze(0).to(device)
    task_ids = torch.full((1,), fill_value=task_id, dtype=torch.long, device=device)

    with torch.no_grad():
        logits, (all_scores, all_attn, _metadata) = model(
            inputs,
            task_ids,
            attention_mask=attention_mask,
            dynamic_exit=False,
            return_attn=True,
        )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if all_attn is None:
        raise RuntimeError(
            "model returned no attention maps (all_attn is None) despite return_attn=True"
        )
    return inputs, attention_mask, task_ids, logits, all_scores, all_attn


def main():
    """Generate attention dumps and figures for every task listed in ``--task-list``.

    For each task:
    1. resolve its training task id;
    2. load the per-task eval TTT dataset;
    3. run the model on the first ``--max-examples-per-task`` examples with
       attention capture;
    4. pick up to 3 ground-truth pixels;
    5. for each pixel, save a dump and render its figures.
    """
    args = parse_args()
    requested_device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Only num_tasks / task_lookup are needed; the training dataset itself is unused.
    _, num_tasks, task_lookup = build_train_dataset(args)

    # Task id used for eval tasks that cannot be resolved in the training task_lookup.
    fallback_task_id = 0
    model = load_loop_vit_with_attn(
        ckpt_path=args.ckpt,
        num_tasks=num_tasks,
        image_size=args.image_size,
        num_colors=args.num_colors,
        embed_dim=args.embed_dim,
        loop_core_depth=args.loop_core_depth,
        max_loop_steps=args.max_loop_steps,
        min_loop_steps=args.min_loop_steps,
        num_heads=args.num_heads,
        mlp_dim=args.mlp_dim,
        dropout=0.1,
        num_task_tokens=args.num_task_tokens,
        patch_size=args.patch_size,
        device=requested_device.type,
    )
    model.eval()
    # The loader only receives the device *type* (e.g. "cuda" for "cuda:1"), so send
    # inputs to wherever the weights actually ended up to avoid a device mismatch.
    device = _model_device(model, requested_device)

    task_names = load_task_names(args.task_list)
    out_root = Path(args.out_root)

    for task_name in task_names:
        print(f"\n===== Processing task {task_name} (Loop) =====")
        task_id = find_task_id(task_lookup, task_name, fallback_task_id)

        eval_train_dataset = build_eval_dataset(args, task_name)
        if len(eval_train_dataset) == 0:
            print(f"[Skip task {task_name}]: no examples in eval split")
            continue

        task_out_dir = out_root / task_name
        task_out_dir.mkdir(parents=True, exist_ok=True)

        num_ex = min(args.max_examples_per_task, len(eval_train_dataset))
        for ex_idx in range(num_ex):
            sample = eval_train_dataset[ex_idx]
            inputs, attention_mask, task_ids, logits, all_scores, all_attn = (
                _forward_with_attn(model, sample, task_id, device)
            )
            targets = sample["targets"].unsqueeze(0).to(device)

            # Pixels come from the GT only, so every model is probed at the same places.
            selected_pixels = select_representative_pixels_from_gt(
                inputs.cpu(), targets.cpu(), max_pixels=3, background_color=0
            )
            print(
                f"Task {task_name}, ex {ex_idx}: selected {len(selected_pixels)} pixels: "
                f"{selected_pixels}"
            )

            for row, col in selected_pixels:
                dump_path = task_out_dir / f"attn_ex{ex_idx}_r{row}_c{col}.pt"
                save_dump(
                    dump_path,
                    task_name,
                    ex_idx,
                    row,
                    col,
                    args,
                    all_scores,
                    all_attn,
                    inputs,
                    targets,
                    task_ids,
                    attention_mask,
                    logits,
                )

                ex_out_dir = task_out_dir / f"ex{ex_idx}_r{row}_c{col}"
                ex_out_dir.mkdir(parents=True, exist_ok=True)
                run_visualizations_for_dump(dump_path, ex_out_dir)

    print("All Loop tasks finished.")


if __name__ == "__main__":
    # Run the batch job only when executed as a script, not when imported.
    main()