import argparse
from itertools import islice
from pathlib import Path
from typing import Dict, Any, List

import torch
import torch.nn.functional as F

from src.ARC_loader import IGNORE_INDEX, build_dataloaders
from src.ARC_LoopViT import LoopARCViT
from utils.args import add_resume_checkpoints
from utils.distribution import init_distributed_mode


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the loop-trajectory probing script.

    The shared resume options are registered first through
    ``add_resume_checkpoints``; the data, model, loop and probing options are
    then added in the order they appear in the option table below.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="Minimal loop trajectory visualization for LoopARCViT")
    # Reuse the shared resume arguments, mainly so a checkpoint path can be specified.
    add_resume_checkpoints(cli)

    # Each entry is (flag, keyword arguments for ``add_argument``); order is preserved.
    option_table = (
        # Data and model options.
        ("--data-root", dict(type=str, default="raw_data/ARC-AGI")),
        ("--eval-split", dict(type=str, default="evaluation")),
        (
            "--eval-subset",
            dict(
                type=str,
                choices=("train", "test"),
                default="test",
                help="Which part of ARC split to use inside each task.",
            ),
        ),
        ("--image-size", dict(type=int, default=64)),
        ("--num-colors", dict(type=int, default=12)),
        ("--embed-dim", dict(type=int, default=512)),
        ("--mlp-dim", dict(type=int, default=1024)),
        ("--num-heads", dict(type=int, default=8)),
        ("--dropout", dict(type=float, default=0.1)),
        ("--batch-size", dict(type=int, default=8)),
        ("--num-workers", dict(type=int, default=0)),
        ("--patch-size", dict(type=int, default=2)),
        ("--num-task-tokens", dict(type=int, default=1)),
        ("--fix-scale-factor", dict(type=int, default=2)),
        # Loop-specific options.
        ("--loop-core-depth", dict(type=int, default=2)),
        ("--max-loop-steps", dict(type=int, default=8)),
        ("--min-loop-steps", dict(type=int, default=2)),
        ("--disable-exit-gate", dict(action="store_true")),
        ("--exit-gate-threshold", dict(type=float, default=0.6)),
        ("--no-step-embedding", dict(action="store_true")),
        # Probing options.
        (
            "--steps-to-probe",
            dict(
                type=int,
                nargs="+",
                default=None,
                help="List of effective loop steps to probe, e.g. --steps-to-probe 1 2 4 8. "
                     "If None, will use [1, max_loop_steps].",
            ),
        ),
        (
            "--num-batches",
            dict(
                type=int,
                default=1,
                help="How many eval loader batches to analyze (small for quick run).",
            ),
        ),
        (
            "--checkpoint",
            dict(
                type=str,
                required=True,
                help="Path to a trained LoopARC checkpoint (the same format as offline_train_loop_ARC.py best-save-path).",
            ),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def build_model_from_checkpoint(args: argparse.Namespace, device: torch.device) -> LoopARCViT:
    """Rebuild a ``LoopARCViT`` from a saved checkpoint and return it in eval mode.

    Hyperparameters stored under the checkpoint's ``"config"`` entry take
    precedence over the command-line values in ``args``; the weights are read
    from the checkpoint's ``"model_state"`` entry and loaded non-strictly.

    Args:
        args: Parsed CLI namespace. ``args.checkpoint`` is the checkpoint path;
            the other hyperparameter attributes are fallbacks for values that
            are absent from the checkpoint config. An optional
            ``args.num_tasks`` attribute, when present, is the fallback for the
            number of tasks.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The constructed model, on ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``args.checkpoint`` does not exist.
        KeyError: If the checkpoint has no ``"model_state"`` entry.
    """
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Refuse to continue without weights: loading an empty state dict would
    # silently leave the model randomly initialised and make every reported
    # number meaningless.
    if "model_state" not in checkpoint:
        raise KeyError(
            f"Checkpoint {ckpt_path} has no 'model_state' entry; "
            f"available keys: {list(checkpoint.keys())}"
        )

    # ``or {}`` also covers a checkpoint that stores ``config`` as None.
    config: Dict[str, Any] = checkpoint.get("config") or {}

    def from_config(name: str) -> Any:
        """Return the checkpoint's value for *name*, else the command-line value."""
        return config.get(name, getattr(args, name))

    # The task count cannot come from a CLI default. Prefer the checkpoint
    # config, then an ``args.num_tasks`` attribute if the caller provided one.
    # NOTE(review): whether checkpoints store "num_tasks" in their config is
    # not visible from this file -- confirm against the training script.
    num_tasks = config.get("num_tasks")
    if num_tasks is None:
        num_tasks = getattr(args, "num_tasks", None)
    if num_tasks is None:
        print(
            "[Warning] num_tasks is neither in the checkpoint config nor set on args; "
            "falling back to a placeholder of 1, which may not match the checkpoint."
        )
        num_tasks = 1

    model = LoopARCViT(
        num_tasks=num_tasks,
        image_size=from_config("image_size"),
        num_colors=from_config("num_colors"),
        embed_dim=from_config("embed_dim"),
        loop_core_depth=from_config("loop_core_depth"),
        max_loop_steps=from_config("max_loop_steps"),
        min_loop_steps=from_config("min_loop_steps"),
        num_heads=from_config("num_heads"),
        mlp_dim=from_config("mlp_dim"),
        dropout=from_config("dropout"),
        num_task_tokens=from_config("num_task_tokens"),
        patch_size=from_config("patch_size"),
        # The stored flags are negative ("disable"/"no"), the constructor's are positive.
        use_exit_gate=not from_config("disable_exit_gate"),
        gate_threshold=from_config("exit_gate_threshold"),
        add_step_embeddings=not from_config("no_step_embedding"),
    )

    # Keys are loaded as-is: DDP / torch.compile prefixes are not stripped,
    # because this script is meant to run as a single, non-DDP process.
    # strict=False only tolerates missing/unexpected keys, which are reported below.
    missing, unexpected = model.load_state_dict(checkpoint["model_state"], strict=False)
    if missing:
        print(f"[Warning] Missing keys when loading state_dict: {len(missing)}")
    if unexpected:
        print(f"[Warning] Unexpected keys when loading state_dict: {len(unexpected)}")

    model.to(device)
    model.eval()
    return model


def collect_batch_examples(loader, num_batches: int) -> List[Any]:
    """Return the first ``num_batches`` items yielded by ``loader`` as a list.

    Iteration stops as soon as enough batches have been taken, so no extra
    batch is pulled from the loader.

    Args:
        loader: Any iterable of batches (for example an evaluation data loader).
        num_batches: Maximum number of batches to collect. Zero or a negative
            value yields an empty list.

    Returns:
        A list with at most ``num_batches`` batches, in iteration order.
    """
    # Clamp at zero: islice rejects negative stops, and asking for "no batches"
    # must not return the first batch anyway.
    return list(islice(loader, max(num_batches, 0)))


def _exact_match_flags(predictions: torch.Tensor, targets: torch.Tensor) -> List[bool]:
    """Return one flag per sample: True when all non-ignored target cells are predicted correctly."""
    valid = targets != IGNORE_INDEX
    # A sample with no valid cell at all is counted as not exact.
    has_valid = valid.flatten(start_dim=1).any(dim=1)
    has_error = ((predictions != targets) & valid).flatten(start_dim=1).any(dim=1)
    return (has_valid & ~has_error).tolist()


def evaluate_step_effect(
    model: LoopARCViT,
    batches: List[Dict[str, torch.Tensor]],
    device: torch.device,
    steps_to_probe: List[int],
):
    """Approximate a loop trajectory by chaining forward passes and print accuracy per step count.

    For every probed step count ``s`` the model is run ``s`` times in a row on
    each batch; after each pass the argmax of the logits replaces the input for
    the next pass. This is only a rough surrogate for ``s`` internal loop
    steps: the model's core layers are not truncated per step, so the numbers
    are indicative rather than exact.

    For each step count the exact-match accuracy over all batches is printed,
    followed by a histogram of the smallest probed step count at which each
    sample was first predicted exactly.

    Args:
        model: Model called as ``model(inputs, task_ids, attention_mask=...,
            dynamic_exit=False)`` and returning ``(logits, metadata)``. Its
            ``max_loop_steps`` attribute is used for the default step list.
        batches: Batches with the keys ``"inputs"``, ``"attention_mask"``,
            ``"targets"`` and ``"task_ids"``.
        device: Device the batch tensors are moved to.
        steps_to_probe: Positive step counts to probe; duplicates are removed
            and the values are visited in ascending order. ``None`` or an empty
            list means ``[1, model.max_loop_steps]``.

    Raises:
        ValueError: If any step count is smaller than 1.
    """
    # Default to probing a single pass and the model's configured maximum.
    if not steps_to_probe:
        steps_to_probe = [1, model.max_loop_steps]
    steps_to_probe = sorted(set(steps_to_probe))

    # With a step count below 1 the forward loop would not run at all, leaving
    # no logits (or stale ones from a previous batch) to score.
    invalid_steps = [s for s in steps_to_probe if s < 1]
    if invalid_steps:
        raise ValueError(f"steps_to_probe must contain only positive integers, got {invalid_steps}")

    # Sample index (position across all batches) -> smallest probed step count
    # at which that sample was exactly correct.
    per_sample_first_correct_step: Dict[int, int] = {}

    for s in steps_to_probe:
        print(f"\n=== Probing effective steps: {s} (approximate) ===")
        total_exact = 0
        total_examples = 0

        with torch.no_grad():
            for batch in batches:
                inputs = batch["inputs"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                targets = batch["targets"].to(device)
                task_ids = batch["task_ids"].to(device)

                # Approximate "s steps": run s forward passes, feeding each
                # pass's argmax prediction back in as the next input.
                # The clone keeps the stored batch tensor separate from what is
                # handed to the model.
                current_inputs = inputs.clone()
                for _ in range(s):
                    logits, _ = model(
                        current_inputs,
                        task_ids,
                        attention_mask=attention_mask,
                        dynamic_exit=False,
                    )
                    # Argmax over dim 1 turns the logits back into discrete values.
                    current_inputs = logits.argmax(dim=1)

                # After the loop ``current_inputs`` is the final pass's argmax prediction.
                exact_flags = _exact_match_flags(current_inputs, targets)
                for offset, is_exact in enumerate(exact_flags):
                    if is_exact:
                        # Steps are visited in ascending order, so the first
                        # value stored for a sample is its smallest one.
                        per_sample_first_correct_step.setdefault(total_examples + offset, s)

                total_exact += sum(exact_flags)
                total_examples += len(exact_flags)

        acc = total_exact / max(total_examples, 1)
        print(f"Approx accuracy at steps={s}: {acc:.4f}")

    # Summarise, per step count, how many samples first became correct there.
    print("\n=== Sample-wise first-correct-step summary (approximate) ===")
    if not per_sample_first_correct_step:
        print("No sample became correct at any probed step.")
        return

    hist: Dict[int, int] = {}
    for first_step in per_sample_first_correct_step.values():
        hist[first_step] = hist.get(first_step, 0) + 1
    for step in sorted(hist):
        print(f"First correct at step {step}: {hist[step]} samples")


def main():
    """Parse the CLI, load the evaluation data and checkpoint, and run the step probe.

    Raises:
        RuntimeError: If more than one distributed process is in use, if the
            evaluation dataset or loader is missing, or if no batch could be
            collected from the evaluation loader.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Simple single-machine / single-GPU setup; only the distributed flag, the
    # world size and the device are needed from the returned tuple.
    distributed, _, world_size, _, device = init_distributed_mode(args)
    if distributed and world_size > 1:
        raise RuntimeError("This minimal visualization script is intended for single process / single GPU only.")

    # Only the evaluation dataset and loader are used.
    _, _, eval_dataset, eval_loader, _, _ = build_dataloaders(
        args,
        distributed=False,
        rank=0,
        world_size=1,
    )
    if eval_dataset is None or eval_loader is None:
        raise RuntimeError("Evaluation dataset/loader is None. Please check your data-root / eval-split settings.")

    # Record the dataset's task count on args so that model construction code
    # consulting ``args.num_tasks`` has a real value available.
    # NOTE(review): assumes the dataset exposes ``num_tasks`` -- confirm; the
    # getattr keeps this a no-op if it does not.
    dataset_num_tasks = getattr(eval_dataset, "num_tasks", None)
    if dataset_num_tasks is not None and getattr(args, "num_tasks", None) is None:
        args.num_tasks = dataset_num_tasks

    # The model keeps the num_task_tokens it was constructed with. It is
    # deliberately not overwritten from the CLI afterwards, since the
    # checkpoint config may have supplied a different value than the CLI default.
    model = build_model_from_checkpoint(args, device)

    # Sample a small number of batches.
    batches = collect_batch_examples(eval_loader, num_batches=args.num_batches)
    if not batches:
        raise RuntimeError("No batches found in eval loader.")

    steps_to_probe = args.steps_to_probe
    evaluate_step_effect(model, batches, device, steps_to_probe)


# Run the probe only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()