import torch
from thop import profile, clever_format

from src.ARC_ViT import ARCViT
from src.ARC_LoopViT import LoopARCViT


def build_varc_vit(
    image_size: int = 64,
    num_colors: int = 12,
    embed_dim: int = 512,
    depth: int = 10,
    num_heads: int = 8,
    mlp_dim: int = 1024,
    patch_size: int = 2,
    num_tasks: int = 400,
    num_task_tokens: int = 1,
    *,
    dropout: float = 0.1,
) -> torch.nn.Module:
    """Construct the baseline VARC ``ARCViT`` with the given hyper-parameters.

    Every argument is forwarded unchanged to ``ARCViT``; see that class for
    the exact meaning of each one.

    Args:
        image_size: Side length of the (square) input canvas.
        num_colors: Number of distinct pixel values / colour tokens.
        embed_dim: Transformer embedding width.
        depth: Number of transformer layers.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the transformer MLP.
        patch_size: Side length of each image patch.
        num_tasks: Number of task ids the model is built for.
        num_task_tokens: Number of task tokens per task.
        dropout: Dropout probability passed to ``ARCViT`` (keyword-only;
            defaults to the previously hard-coded 0.1).

    Returns:
        A newly constructed ``ARCViT`` instance.
    """
    return ARCViT(
        num_tasks=num_tasks,
        image_size=image_size,
        num_colors=num_colors,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
        num_task_tokens=num_task_tokens,
        patch_size=patch_size,
    )


def build_loop_vit(
    image_size: int = 64,
    num_colors: int = 12,
    embed_dim: int = 512,
    loop_core_depth: int = 2,
    max_loop_steps: int = 6,
    min_loop_steps: int = 6,
    num_heads: int = 8,
    mlp_dim: int = 1024,
    patch_size: int = 2,
    num_tasks: int = 400,
    num_task_tokens: int = 1,
    use_exit_gate: bool = False,
    *,
    dropout: float = 0.1,
    gate_threshold: float = 0.6,
    add_step_embeddings: bool = True,
) -> torch.nn.Module:
    """Construct a ``LoopARCViT`` with the given hyper-parameters.

    Every argument is forwarded unchanged to ``LoopARCViT``; see that class
    for the exact semantics of the loop-related options.

    Args:
        image_size: Side length of the (square) input canvas.
        num_colors: Number of distinct pixel values / colour tokens.
        embed_dim: Transformer embedding width.
        loop_core_depth: Depth of the looped core block.
        max_loop_steps: Upper bound on loop iterations.
        min_loop_steps: Lower bound on loop iterations.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the transformer MLP.
        patch_size: Side length of each image patch.
        num_tasks: Number of task ids the model is built for.
        num_task_tokens: Number of task tokens per task.
        use_exit_gate: Whether ``LoopARCViT`` is built with its exit gate.
        dropout: Dropout probability (keyword-only; previously hard-coded
            to 0.1).
        gate_threshold: Exit-gate threshold (keyword-only; previously
            hard-coded to 0.6). Presumably only relevant when
            ``use_exit_gate`` is True -- TODO confirm against LoopARCViT.
        add_step_embeddings: Forwarded as-is (keyword-only; previously
            hard-coded to True).

    Returns:
        A newly constructed ``LoopARCViT`` instance.
    """
    return LoopARCViT(
        num_tasks=num_tasks,
        image_size=image_size,
        num_colors=num_colors,
        embed_dim=embed_dim,
        loop_core_depth=loop_core_depth,
        max_loop_steps=max_loop_steps,
        min_loop_steps=min_loop_steps,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
        num_task_tokens=num_task_tokens,
        patch_size=patch_size,
        use_exit_gate=use_exit_gate,
        gate_threshold=gate_threshold,
        add_step_embeddings=add_step_embeddings,
    )


class LoopVitWrapper(torch.nn.Module):
    """Expose LoopARCViT as a fixed-step, logits-only module for profiling.

    ``thop.profile`` calls ``model(*inputs)`` with positional arguments, so
    this wrapper forwards them to the wrapped model while pinning
    ``dynamic_exit=False`` and ``gate_threshold=None``. The wrapped model
    returns a 2-tuple; only its first element (the logits) is returned.
    """

    def __init__(self, core: torch.nn.Module) -> None:
        super().__init__()
        self.core = core

    def forward(
        self,
        pixel_values: torch.Tensor,
        task_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        logits, _ = self.core(
            pixel_values,
            task_ids,
            attention_mask=attention_mask,
            dynamic_exit=False,  # disable dynamic early exit
            # Use the model's default threshold; it has no effect when the
            # model was built with use_exit_gate=False.
            gate_threshold=None,
        )
        return logits


def _profile_and_report(title: str, model: torch.nn.Module, inputs: tuple) -> None:
    """Profile one forward pass of ``model`` with thop and print a summary.

    Args:
        title: Heading printed above the numbers.
        model: Module to profile; thop calls it as ``model(*inputs)``.
        inputs: Positional arguments for a single forward pass.
    """
    model.eval()
    with torch.no_grad():
        macs, thop_params = profile(model, inputs=inputs, verbose=False)
    flops = 2 * macs  # convention: one multiply-accumulate ~= 2 FLOPs
    # thop only credits parameters of modules it has counting rules for, so
    # nn.Parameters registered directly on custom modules can be missed.
    # Report the exact total (shared parameters counted once) alongside.
    total_params = sum(p.numel() for p in model.parameters())
    macs_fmt, flops_fmt, thop_params_fmt, total_params_fmt = clever_format(
        [macs, flops, thop_params, total_params], "%.3f"
    )
    print(f"=== {title} ===")
    print(f"MACs:   {macs_fmt}")
    print(f"FLOPs:  {flops_fmt}")
    print(f"Params: {thop_params_fmt} (thop), {total_params_fmt} (all parameters)")
    print()


def main() -> None:
    """Compare thop-measured cost of VARC-ViT and LoopARCViT variants.

    Builds each model with the shared configuration below and profiles a
    single forward pass on one random ``(1, image_size, image_size)`` grid.
    For each model it prints MACs, FLOPs (= 2 * MACs) and parameter counts.

    NOTE: thop counts ops through hooks on ``nn.Module`` types it knows.
    Functional ops called inside ``forward`` (e.g. a raw attention matmul)
    may therefore not be counted, so treat the MAC figures as lower bounds.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Configuration matching VARC training.
    image_size = 64
    num_colors = 12
    shared_config = {
        "image_size": image_size,
        "num_colors": num_colors,
        "embed_dim": 512,
        "num_heads": 8,
        "mlp_dim": 1024,
        "patch_size": 2,
        "num_tasks": 400,
        "num_task_tokens": 1,
    }

    # One sample shaped like the ARC loader output (B, H, W),
    # values in 0..num_colors-1.
    batch_size = 1
    dummy_pixels = torch.randint(
        low=0,
        high=num_colors,
        size=(batch_size, image_size, image_size),
        dtype=torch.long,
        device=device,
    )
    # Arbitrary task id (0).
    dummy_task_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
    # All-ones attention mask (the full canvas).
    dummy_attention = torch.ones(
        batch_size, image_size, image_size, dtype=torch.long, device=device
    )
    inputs = (dummy_pixels, dummy_task_ids, dummy_attention)

    # ========= VARC-ViT =========
    # NOTE(review): ARCViT receives the mask as its third *positional*
    # argument; confirm that is its attention_mask parameter.
    vit = build_varc_vit(depth=10, **shared_config).to(device)
    _profile_and_report("VARC-ViT (baseline)", vit, inputs)

    # ========= LoopARCViT (fixed step counts, no gate) =========
    for steps in (6, 1):
        loop_vit = build_loop_vit(
            loop_core_depth=2,
            max_loop_steps=steps,
            min_loop_steps=steps,
            use_exit_gate=False,  # gate disabled: fixed number of steps only
            **shared_config,
        ).to(device)
        _profile_and_report(
            f"LoopARCViT (core_depth=2, steps={steps}, no gate)",
            LoopVitWrapper(loop_vit),
            inputs,
        )


if __name__ == "__main__":
    main()