import argparse
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from offline_train_loop_ARC import (
    _compute_gate_regularizers,
    _format_eta,
    evaluate,
    set_seed,
)
from src.ARC_LoopViT import LoopARCViT
from src.ARC_loader import IGNORE_INDEX, build_dataloaders
from utils.args import (
    add_resume_checkpoints,
    add_speed_optimizer_args,
    add_wandb_args,
)
from utils.distribution import init_distributed_mode
from utils.load_model import count_parameters
from utils.lr_scheduler import get_cosine_schedule_with_warmup

try:
    import wandb
except ImportError:  # pragma: no cover - optional dependency
    wandb = None


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the gate-only fine-tuning stage.

    The project helpers ``add_resume_checkpoints``, ``add_wandb_args`` and
    ``add_speed_optimizer_args`` register their own options on the parser
    first; the stage-specific options are added afterwards.  Both
    ``train_dynamic_exit`` and ``eval_dynamic_exit`` default to ``True`` and
    can be switched off with their ``--no-...`` counterparts.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="Second-stage gate-only fine-tuning for Loop-ARC ViT")
    for register_shared in (add_resume_checkpoints, add_wandb_args, add_speed_optimizer_args):
        register_shared(cli)
    add = cli.add_argument

    # Initial weights, data location, backbone shape and optimisation settings.
    add("--teacher-checkpoint", type=str, required=True, help="Path to stage-one checkpoint used for initialization.")
    add("--data-root", type=str, default="raw_data/ARC-AGI")
    add("--train-split", type=str, default="training")
    add("--eval-split", type=str, default="training")
    add("--eval-subset", type=str, choices=("train", "test"), default="test")
    add("--image-size", type=int, default=64)
    add("--num-colors", type=int, default=12)
    add("--embed-dim", type=int, default=512)
    add("--mlp-dim", type=int, default=1024)
    add("--num-heads", type=int, default=8)
    add("--dropout", type=float, default=0.1)
    add("--batch-size", type=int, default=32)
    add("--epochs", type=int, default=15)
    add("--learning-rate", type=float, default=1e-4)
    add("--weight-decay", type=float, default=0.0)
    add("--max-grad-norm", type=float, default=1.0)
    add("--num-workers", type=int, default=4)
    add("--seed", type=int, default=1337)
    add("--save-path", type=str, default=None)
    add("--best-save-path", type=str, default=None)
    # NOTE: train() in this file keeps a constant learning rate and does not
    # read this option.
    add("--lr-scheduler", type=str, choices=("none", "cosine"), default="cosine")
    add("--vis-every", type=int, default=10)

    # Looping / exit-gate behaviour.
    add("--loop-core-depth", type=int, default=2)
    add("--max-loop-steps", type=int, default=8)
    add("--min-loop-steps", type=int, default=2)
    add("--exit-gate-threshold", type=float, default=0.6)
    add("--gate-entropy-weight", type=float, default=5e-4)
    add("--loop-penalty-weight", type=float, default=1e-3)
    add("--train-dynamic-exit", dest="train_dynamic_exit", action="store_true", help="Enable dynamic exit decisions during gate tuning.")
    add("--no-train-dynamic-exit", dest="train_dynamic_exit", action="store_false", help="Disable dynamic exit during gate tuning.")
    add("--eval-dynamic-exit", dest="eval_dynamic_exit", action="store_true", help="Use dynamic exit when evaluating.")
    add("--no-eval-dynamic-exit", dest="eval_dynamic_exit", action="store_false", help="Force fixed max steps when evaluating.")
    add("--disable-exit-gate", action="store_true", help="Turn off the exit gate head (not compatible with gate fine-tuning).")
    add("--no-step-embedding", action="store_true")
    add("--tune-step-embeddings", action="store_true", help="Allow step embeddings to update during gate fine-tuning.")

    # Extra data sources, distribution and augmentation switches.
    add("--include-rearc", action="store_true")
    add("--rearc-path", type=str, default="raw_data/re_arc")
    add("--rearc-limit", type=int, default=-1)
    add("--distributed", action="store_true")
    add("--patch-size", type=int, default=2)
    add("--num-task-tokens", type=int, default=1)
    add("--fix-scale-factor", type=int, default=2)
    add("--disable-translation", action="store_true")
    add("--disable-resolution-augmentation", action="store_true")

    # The paired store_true/store_false flags share a dest, so the effective
    # default is set explicitly here.
    cli.set_defaults(train_dynamic_exit=True, eval_dynamic_exit=True)
    return cli


def _instantiate_loop_model(args, train_dataset) -> LoopARCViT:
    """Construct a ``LoopARCViT`` from the parsed CLI options.

    Args:
        args: Parsed namespace supplying the architecture hyper-parameters.
        train_dataset: Training dataset; only its ``num_tasks`` attribute is
            read here.

    Returns:
        A freshly constructed model.  The exit gate is always requested
        (``use_exit_gate=True``) because this stage exists to train it.
    """
    model_kwargs = {
        "num_tasks": train_dataset.num_tasks,
        "image_size": args.image_size,
        "num_colors": args.num_colors,
        "embed_dim": args.embed_dim,
        "loop_core_depth": args.loop_core_depth,
        "max_loop_steps": args.max_loop_steps,
        "min_loop_steps": args.min_loop_steps,
        "num_heads": args.num_heads,
        "mlp_dim": args.mlp_dim,
        "dropout": args.dropout,
        "num_task_tokens": args.num_task_tokens,
        "patch_size": args.patch_size,
        "use_exit_gate": True,
        "gate_threshold": args.exit_gate_threshold,
        # The CLI flag is negative ("--no-step-embedding"), the model kwarg positive.
        "add_step_embeddings": not args.no_step_embedding,
    }
    return LoopARCViT(**model_kwargs)


def _load_state_dict(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], args) -> None:
    """Copy checkpoint weights into ``model`` non-strictly.

    The first ``"_orig_mod."`` occurrence in every key is removed (the prefix
    that ``torch.compile`` wrappers add to state-dict keys), and
    ``task_token_embed.weight`` is dropped when
    ``args.resume_skip_task_token`` is truthy.  Keys that are missing from or
    unexpected in the checkpoint are reported with a printed warning instead
    of raising.

    Args:
        model: Module that receives the weights.
        state_dict: Mapping from parameter name to tensor.
        args: Namespace providing ``resume_skip_task_token``.
    """
    skip_task_token = bool(args.resume_skip_task_token)
    filtered: Dict[str, torch.Tensor] = {}
    for raw_key, tensor in state_dict.items():
        key = raw_key.replace("_orig_mod.", "", 1)
        if skip_task_token and key == "task_token_embed.weight":
            continue
        filtered[key] = tensor

    missing, unexpected = model.load_state_dict(filtered, strict=False)
    for key_list, template in (
        (missing, "Warning: missing keys when loading checkpoint: {}"),
        (unexpected, "Warning: unexpected keys ignored from checkpoint: {}"),
    ):
        if key_list:
            print(template.format(sorted(key_list)))


def _load_teacher_weights(model: torch.nn.Module, args, device: torch.device) -> Dict[str, Any]:
    """Initialise ``model`` from the checkpoint at ``args.teacher_checkpoint``.

    The weights are taken from the ``"model_state"`` entry when the checkpoint
    has one; otherwise the checkpoint itself is treated as the state dict.

    Args:
        model: Module to initialise.
        args: Namespace providing ``teacher_checkpoint`` (plus whatever
            ``_load_state_dict`` reads).
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded checkpoint object.

    Raises:
        FileNotFoundError: If the checkpoint path does not exist.
    """
    teacher_path = Path(args.teacher_checkpoint)
    if teacher_path.exists():
        payload = torch.load(teacher_path, map_location=device)
        weights = payload.get("model_state", payload)
        _load_state_dict(model, weights, args)
        return payload
    raise FileNotFoundError(f"Teacher checkpoint not found: {teacher_path}")


def _maybe_resume_training(model: torch.nn.Module, args, device: torch.device) -> Optional[Dict[str, Any]]:
    """Overlay weights from ``args.resume_checkpoint`` onto ``model`` if one is given.

    Args:
        model: Module whose weights are overwritten.
        args: Namespace providing ``resume_checkpoint`` (plus whatever
            ``_load_state_dict`` reads).
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``None`` when ``args.resume_checkpoint`` is falsy, otherwise the
        loaded checkpoint object.

    Raises:
        FileNotFoundError: If a resume path was given but does not exist.
    """
    resume_from = args.resume_checkpoint
    if resume_from:
        resume_path = Path(resume_from)
        if resume_path.exists():
            payload = torch.load(resume_path, map_location=device)
            # Accept both wrapped checkpoints and bare state dicts.
            weights = payload.get("model_state", payload)
            _load_state_dict(model, weights, args)
            return payload
        raise FileNotFoundError(f"Resume checkpoint not found: {resume_path}")
    return None


def _configure_trainable_parameters(model: LoopARCViT, tune_step_embeddings: bool) -> List[torch.nn.Parameter]:
    """Freeze the whole model, then re-enable gradients for the gate stage.

    Args:
        model: Model to configure in place.
        tune_step_embeddings: When true and the model has a ``step_embed``
            module, its parameters are unfrozen as well.

    Returns:
        The unfrozen parameters: exit-gate parameters first, then (optionally)
        step-embedding parameters.

    Raises:
        RuntimeError: If ``model.exit_gate`` is ``None``.
    """
    # Freeze everything first; only selected sub-modules are thawed below.
    model.requires_grad_(False)

    gate = model.exit_gate
    if gate is None:
        raise RuntimeError("LoopARCViT was instantiated without an exit gate.")

    modules_to_tune = [gate]
    if tune_step_embeddings and model.step_embed is not None:
        modules_to_tune.append(model.step_embed)

    unfrozen: List[torch.nn.Parameter] = []
    for module in modules_to_tune:
        for weight in module.parameters():
            weight.requires_grad = True
            unfrozen.append(weight)
    return unfrozen


def _build_optimizer(model: torch.nn.Module, learning_rate: float, weight_decay: float) -> torch.optim.Optimizer:
    """Create an AdamW optimizer over the parameters that still require gradients.

    Args:
        model: Module whose ``requires_grad`` parameters are optimised.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        A ``torch.optim.AdamW`` instance with a single parameter group.

    Raises:
        RuntimeError: If no parameter of ``model`` requires gradients.
    """
    trainable = list(filter(lambda weight: weight.requires_grad, model.parameters()))
    if len(trainable) == 0:
        raise RuntimeError("No parameters left trainable for gate fine-tuning.")
    optimizer = torch.optim.AdamW(trainable, lr=learning_rate, weight_decay=weight_decay)
    return optimizer


def evaluate_gate(model, loader, device, args, *, distributed: bool = False):
    """Evaluate the gate-tuned model by delegating to the shared ``evaluate`` helper.

    All arguments are forwarded unchanged.  The return value is whatever
    ``evaluate`` returns; ``train`` unpacks it as
    ``(eval_loss, eval_acc, eval_steps, visualizations)``.
    """
    eval_result = evaluate(
        model,
        loader,
        device,
        args,
        distributed=distributed,
    )
    return eval_result


def _count_exact_matches(predictions: torch.Tensor, targets: torch.Tensor) -> int:
    """Count samples predicted correctly on every non-ignored target cell.

    A sample whose target consists only of ``IGNORE_INDEX`` never counts as an
    exact match.  The whole batch is compared at once so that only a single
    device-to-host transfer is needed instead of several per sample.

    Args:
        predictions: Predicted class indices, same shape as ``targets``
            (batch dimension first).
        targets: Ground-truth class indices with ``IGNORE_INDEX`` padding.

    Returns:
        Number of exactly matching samples in the batch.
    """
    valid = (targets != IGNORE_INDEX).flatten(1)
    agree = (predictions == targets).flatten(1)
    # Ignored cells are treated as matching; all-ignored samples are excluded.
    exact = (agree | ~valid).all(dim=1) & valid.any(dim=1)
    return int(exact.sum().item())


def train(args: argparse.Namespace) -> None:
    """Fine-tune only the exit gate (and optionally step embeddings) of a Loop-ARC ViT.

    The model is initialised from ``args.teacher_checkpoint``, optionally
    overlaid with ``args.resume_checkpoint``, and every parameter except the
    exit gate (plus the step embeddings when ``args.tune_step_embeddings`` is
    set) is frozen.  Training uses a fresh AdamW optimizer with a constant
    learning rate; optimizer and scheduler state are never restored from a
    checkpoint, only the epoch counter is.  After each epoch the model is
    evaluated when an eval loader exists; on the main process the best
    checkpoint is written to ``args.best_save_path`` and, once training has
    finished, the final weights to ``args.save_path``.

    Args:
        args: Parsed command-line namespace as produced by ``build_parser``.

    Raises:
        ValueError: If ``--disable-exit-gate`` was passed.
        RuntimeError: If ``--use-wandb`` was passed but ``wandb`` could not be
            imported.
    """
    if args.disable_exit_gate:
        raise ValueError("Gate-only fine-tuning requires an active exit gate; remove --disable-exit-gate.")
    distributed, rank, world_size, local_rank, device = init_distributed_mode(args)
    is_main_process = (not distributed) or rank == 0
    # Offset the seed per rank so that workers do not draw identical random streams.
    set_seed(args.seed + (rank if distributed else 0))

    train_dataset, train_loader, eval_dataset, eval_loader, train_sampler, eval_sampler = build_dataloaders(
        args,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if args.disable_translation:
        train_dataset.disable_translation()
        if eval_dataset is not None:
            eval_dataset.disable_translation()
    else:
        train_dataset.enable_translation()
        if eval_dataset is not None:
            eval_dataset.enable_translation()

    if args.disable_resolution_augmentation:
        train_dataset.disable_resolution_augmentation(fix_scale_factor=args.fix_scale_factor)
        if eval_dataset is not None:
            eval_dataset.disable_resolution_augmentation(fix_scale_factor=args.fix_scale_factor)
    else:
        train_dataset.enable_resolution_augmentation()
        if eval_dataset is not None:
            eval_dataset.enable_resolution_augmentation()

    if is_main_process:
        print(f"Total training examples: {len(train_dataset)}")

    base_model = _instantiate_loop_model(args, train_dataset)
    base_model.to(device)

    # Teacher weights first; a resume checkpoint (if any) overrides them.
    _load_teacher_weights(base_model, args, device)
    resume_payload = _maybe_resume_training(base_model, args, device)

    # Configure and log trainable parameters (gate + optional step embeddings).
    trainable_params = _configure_trainable_parameters(base_model, args.tune_step_embeddings)
    if is_main_process:
        total_trainable = sum(p.numel() for p in trainable_params)
        print(f"[DEBUG] gate stage trainable tensors: {len(trainable_params)}")
        print(f"[DEBUG] gate stage trainable param count: {total_trainable}")

    if not args.no_compile and hasattr(torch, "compile"):
        if is_main_process:
            print("Applying torch.compile for gate fine-tuning...")
        base_model = torch.compile(base_model, mode=args.compile_mode)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            base_model,
            device_ids=[local_rank] if device.type == "cuda" else None,
            output_device=local_rank if device.type == "cuda" else None,
        )
    else:
        model = base_model
    # Unwrapped (non-DDP) module, used for parameter counting and checkpointing.
    model_for_eval = model.module if distributed else model

    optimizer = _build_optimizer(model, args.learning_rate, args.weight_decay)
    if is_main_process:
        print(f"[DEBUG] initial lr from args: {args.learning_rate}")
        print(f"[DEBUG] optimizer param groups: {len(optimizer.param_groups)}")
        if optimizer.param_groups:
            print(f"[DEBUG] optimizer group[0].lr: {optimizer.param_groups[0]['lr']}")

    scaler = GradScaler(enabled=(device.type == "cuda" and not args.no_amp))
    # The gate-only stage does not use a scheduler; the learning rate stays constant.
    scheduler = None

    # Gate stage: do NOT restore optimizer or scheduler state from checkpoints.
    # We only optionally resume the epoch counter so training can continue,
    # but learning rate schedule always starts fresh for gate fine-tuning.
    start_epoch = 1
    if resume_payload is not None and (not args.resume_reset_epoch) and "epoch" in resume_payload:
        start_epoch = resume_payload["epoch"] + 1

    if is_main_process:
        print(
            f"Trainable parameters (gate stage): {count_parameters(model_for_eval) / 1_000_000:.4f}M"
        )
        if scaler.is_enabled():
            print("Using AMP training")
        else:
            print("AMP disabled")

    wandb_run = None
    if args.use_wandb and is_main_process:
        if wandb is None:
            raise RuntimeError("Weights & Biases is not installed but --use-wandb was set.")
        wandb_config = dict(vars(args))
        wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_run_name or None, config=wandb_config)
        wandb.watch(model_for_eval, log=None)

    best_eval_acc = float("-inf")
    global_start = time.time()
    previous_total_steps = 0
    # Only the epochs that will actually run in this invocation count towards
    # the progress bar / ETA; otherwise a resumed run never reaches 100%.
    epochs_to_run = max(args.epochs - start_epoch + 1, 0)
    # autocast only accepts a known device type; fall back to "cuda" otherwise.
    autocast_device_type = device.type if device.type in {"cuda", "cpu", "mps"} else "cuda"

    try:
        for epoch in range(start_epoch, args.epochs + 1):
            if train_sampler is not None:
                train_sampler.set_epoch(epoch)
            model.train()
            running_loss = 0.0
            sample_count = 0
            train_exact = 0
            train_examples = 0
            avg_steps_accumulator = 0.0
            epoch_start = time.time()
            total_batches = len(train_loader)
            all_steps = total_batches * epochs_to_run

            for step, batch in enumerate(train_loader, 1):
                inputs = batch["inputs"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                targets = batch["targets"].to(device)
                task_ids = batch["task_ids"].to(device)

                optimizer.zero_grad(set_to_none=True)

                with autocast(device_type=autocast_device_type, enabled=scaler.is_enabled()):
                    logits, metadata = model(
                        inputs,
                        task_ids,
                        attention_mask=attention_mask,
                        dynamic_exit=args.train_dynamic_exit,
                        gate_threshold=args.exit_gate_threshold,
                    )
                    # Move the class dimension last and flatten for cross-entropy.
                    num_colors = logits.size(1)
                    logits_flat = logits.permute(0, 2, 3, 1).reshape(-1, num_colors)
                    loss = F.cross_entropy(
                        logits_flat,
                        targets.view(-1),
                        ignore_index=IGNORE_INDEX,
                    )

                    gate_entropy_loss, loop_penalty = _compute_gate_regularizers(metadata, args)
                    if args.gate_entropy_weight > 0:
                        loss = loss + args.gate_entropy_weight * gate_entropy_loss
                    if args.loop_penalty_weight > 0:
                        loss = loss + args.loop_penalty_weight * loop_penalty

                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured on true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item() * inputs.size(0)
                sample_count += inputs.size(0)
                avg_steps_accumulator += metadata.exit_steps.float().sum().item()

                predictions = logits.argmax(dim=1)
                train_exact += _count_exact_matches(predictions, targets)
                train_examples += inputs.size(0)

                if total_batches > 0 and is_main_process and step % 10 == 0:
                    steps_completed = previous_total_steps + step
                    elapsed_global = time.time() - global_start
                    avg_time_per_step_global = elapsed_global / max(steps_completed, 1)
                    remaining_steps = all_steps - steps_completed
                    eta = remaining_steps * avg_time_per_step_global
                    bar_length = 30
                    progress_ratio = steps_completed / all_steps if all_steps else 0
                    filled = int(bar_length * progress_ratio)
                    bar = "#" * filled + "-" * (bar_length - filled)
                    progress = 100.0 * progress_ratio
                    sys.stdout.write(f"\rEpoch {epoch} [{bar}] {progress:5.1f}% ETA {_format_eta(eta)}")
                    sys.stdout.flush()

            previous_total_steps += total_batches
            if is_main_process:
                sys.stdout.write("\n")

            epoch_duration = time.time() - epoch_start if total_batches > 0 else 0.0
            if distributed and dist.is_initialized():
                # Sum the per-rank counters so the logged metrics cover all ranks.
                train_totals = torch.tensor(
                    [running_loss, sample_count, train_exact, train_examples, avg_steps_accumulator],
                    dtype=torch.float64,
                    device=device,
                )
                dist.all_reduce(train_totals, op=dist.ReduceOp.SUM)
                running_loss, sample_count, train_exact, train_examples, avg_steps_accumulator = train_totals.tolist()

            avg_train_loss = running_loss / max(sample_count, 1)
            train_acc = train_exact / max(train_examples, 1)
            avg_train_steps = avg_steps_accumulator / max(train_examples, 1)

            log_parts = [
                f"epoch={epoch}",
                f"train_loss={avg_train_loss:.4f}",
                f"train_acc={train_acc:.4f}",
                f"avg_train_steps={avg_train_steps:.2f}",
                f"epoch_time={epoch_duration:.1f}s",
            ]

            current_lr = optimizer.param_groups[0]["lr"] if optimizer.param_groups else args.learning_rate
            # Use scientific notation so very small learning rates are visible.
            log_parts.append(f"lr={current_lr:.6e}")

            eval_loss = None
            eval_acc = None
            eval_steps = None
            visualizations: Dict[int, Any] = {}
            if eval_loader is not None:
                eval_loss, eval_acc, eval_steps, visualizations = evaluate_gate(
                    model,
                    eval_loader,
                    device,
                    args,
                    distributed=distributed,
                )
                if is_main_process:
                    log_parts.append(f"eval_loss={eval_loss:.4f}")
                    log_parts.append(f"eval_acc={eval_acc:.4f}")
                    log_parts.append(f"eval_steps={eval_steps:.2f}")
                if eval_acc is not None and eval_acc > best_eval_acc and args.best_save_path and is_main_process:
                    best_eval_acc = eval_acc
                    best_payload: Dict[str, Any] = {
                        "model_state": model_for_eval.state_dict(),
                        "config": vars(args),
                        "epoch": epoch,
                        "best_eval_accuracy": best_eval_acc,
                        "optimizer_state": optimizer.state_dict(),
                        "scaler_state": scaler.state_dict() if scaler.is_enabled() else None,
                    }
                    if scheduler is not None:
                        best_payload["scheduler_state"] = scheduler.state_dict()
                    best_path = Path(args.best_save_path)
                    best_path.parent.mkdir(parents=True, exist_ok=True)
                    torch.save(best_payload, best_path)

            if is_main_process:
                print(" | ".join(log_parts))

            if wandb_run is not None and is_main_process:
                metrics = {
                    "epoch": epoch,
                    "train/loss": avg_train_loss,
                    "train/accuracy": train_acc,
                    "train/avg_steps": avg_train_steps,
                    "train/lr": current_lr,
                }
                if eval_loss is not None:
                    metrics.update(
                        {
                            "eval/loss": eval_loss,
                            "eval/accuracy": eval_acc,
                            "eval/avg_steps": eval_steps,
                        }
                    )
                wandb.log(metrics, step=epoch)

            if scheduler is not None:
                scheduler.step()

    finally:
        if wandb_run is not None:
            wandb_run.finish()
        if distributed and dist.is_initialized():
            dist.barrier()

    if args.save_path and is_main_process:
        save_path = Path(args.save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "model_state": model_for_eval.state_dict(),
            "config": vars(args),
            "scaler_state": scaler.state_dict() if scaler.is_enabled() else None,
            "optimizer_state": optimizer.state_dict(),
        }
        torch.save(payload, save_path)

    if distributed and dist.is_initialized():
        dist.destroy_process_group()


def main():
    """Parse the command line and launch gate-only fine-tuning."""
    cli_args = build_parser().parse_args()
    train(cli_args)


# Run gate fine-tuning only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
