# compare_loop_vs_baseline.py
import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Optional

from arc1_tasks import tasks_arc_agi  # 新增这行
from utils.eval_utils import get_majority_vote


def load_task_list(task_type: str) -> List[str]:
    """Return the task names to evaluate for ``task_type``.

    Only "ARC-AGI" is supported. For it, the ``tasks_arc_agi`` list imported
    from ``arc1_tasks`` is returned as-is. This is the same ARC-AGI task list
    that analysis.py uses.

    Raises:
        ValueError: if ``task_type`` is not "ARC-AGI".
    """
    if task_type != "ARC-AGI":
        raise ValueError(f"Unsupported task_type={task_type}")
    # Reuse the ARC-AGI task list shared with analysis.py.
    return tasks_arc_agi


def load_predictions_for_roots(task_names: List[str], roots: List[str]) -> Dict[str, Dict]:
    """Merge per-task prediction files from several output roots (attempt directories).

    For every task, each existing ``<root>/<task_name>_predictions.json`` is
    read. Its lists are concatenated key by key, in the order of ``roots``::

        answer_set[task_name][example_id] = [pred_1, pred_2, ...]

    The structure is similar to the processed ``all_data`` in analysis.py.

    Roots are whitespace-stripped. Missing files are skipped. A task with no
    file under any root is left out of the result. Every key in a file is merged
    the same way, so an ``"answer"`` key, if present, is also concatenated
    across roots.

    Args:
        task_names: tasks to look up.
        roots: attempt output directories to read from.

    Returns:
        Mapping ``task_name -> {key: merged_list}``.
    """
    answer_set: Dict[str, Dict[str, List]] = {}

    for name in task_names:
        merged: Dict[str, List] = {}
        has_data = False

        for root in roots:
            pred_path = os.path.join(root.strip(), f"{name}_predictions.json")
            if not os.path.exists(pred_path):
                continue

            with open(pred_path, "r", encoding="utf-8") as f:
                # Presumably {example_id: [...], "answer": [...]}; confirm against the writer.
                data = json.load(f)

            # Attempts may cover different keys, so create each list on demand
            # instead of assuming the first file found has every key.
            for key, values in data.items():
                merged.setdefault(key, []).extend(values)
            has_data = True

        if has_data:
            answer_set[name] = merged

    return answer_set


def load_ground_truths(
    task_names: List[str],
    task_type: str,
    data_root: Path = Path("raw_data"),
    split: str = "evaluation",
) -> Dict[str, Dict[str, List]]:
    """Load the expected test outputs for each task.

    Reads ``<data_root>/<task_type>/data/<split>/<task_name>.json``. Each entry
    of its ``"test"`` list is mapped, by its index as a string, to the entry's
    ``"output"``. Tasks whose file does not exist are skipped silently.

    Args:
        task_names: tasks to load.
        task_type: dataset directory under ``data_root`` (e.g. "ARC-AGI").
        data_root: root of the raw dataset tree. Defaults to ``raw_data``,
            resolved relative to the current working directory.
        split: split subdirectory. Defaults to "evaluation".

    Returns:
        ``gt[task_name][example_id] = expected_output`` with ``example_id`` as
        ``str(index)``. String ids are what ``compare_models`` uses to look up
        predictions.
    """
    gt: Dict[str, Dict[str, List]] = {}
    base_dir = Path(data_root) / task_type / "data" / split
    for task_name in task_names:
        json_path = base_dir / f"{task_name}.json"
        # is_file(): skip anything that is not a regular file (missing or a directory).
        if not json_path.is_file():
            continue
        with json_path.open("r", encoding="utf-8") as f:
            cur_data = json.load(f)
        gt[task_name] = {
            str(i): item["output"] for i, item in enumerate(cur_data["test"])
        }
    return gt


def build_majority_pred_map(answer_set: Dict[str, Dict], task_names: List[str]) -> Dict[str, Dict[str, List]]:
    """Pick one representative prediction per (task, example).

    For every task in ``task_names`` that is present in ``answer_set``,
    ``get_majority_vote`` is run over each example's prediction list. The
    ``"prediction"`` of the first entry it returns is kept as that model's
    representative prediction. The ``"answer"`` key is ignored. Examples for
    which the vote returns an empty result are omitted.

    Returns:
        ``pred_map[task_name][example_id] = prediction``. Every task present
        in ``answer_set`` gets an entry, even if that entry ends up empty.
    """
    pred_map: Dict[str, Dict[str, List]] = {}
    for task_name in (t for t in task_names if t in answer_set):
        per_example: Dict[str, List] = {}
        for ex_id, preds in answer_set[task_name].items():
            if ex_id != "answer":
                ranked = get_majority_vote(preds)
                if len(ranked) > 0:
                    # The first entry of the vote is the top-ranked candidate.
                    per_example[ex_id] = ranked[0]["prediction"]
        pred_map[task_name] = per_example
    return pred_map


def _collect_loop_wins(
    tasks: List[str],
    gt: Dict[str, Dict[str, List]],
    baseline_pred: Dict[str, Dict[str, List]],
    loop_pred: Dict[str, Dict[str, List]],
) -> List[Dict]:
    """Return one record per example where Loop is correct and Baseline is wrong.

    Only examples that have a ground truth and a prediction from both models
    are considered. Correctness is plain ``==`` against the ground truth.
    """
    wins: List[Dict] = []
    for task_name in tasks:
        if task_name not in gt:
            continue
        baseline_task = baseline_pred.get(task_name, {})
        loop_task = loop_pred.get(task_name, {})
        for ex_id, expected in gt[task_name].items():
            b_guess = baseline_task.get(ex_id)
            l_guess = loop_task.get(ex_id)
            if b_guess is None or l_guess is None:
                continue
            baseline_ok = b_guess == expected
            loop_ok = l_guess == expected
            if not baseline_ok and loop_ok:
                wins.append(
                    {
                        "task": task_name,
                        "example_id": ex_id,
                        "gt": expected,
                        "baseline_pred": b_guess,
                        "loop_pred": l_guess,
                    }
                )
    return wins


def compare_models(
    baseline_roots: List[str],
    loop_roots: List[str],
    task_type: str,
    out_path: Path,
) -> None:
    """Find examples where the Loop model beats the Baseline and save them as JSON.

    Steps:
    1. Load the task list and ground truths for ``task_type``.
    2. Merge the predictions from every attempt directory of each model.
    3. Reduce each example to its top-ranked majority-vote prediction.
    4. Keep the examples where the baseline's pick is wrong and the loop's
       pick is right.

    The records are written to ``out_path`` (parent directories are created)
    with ``indent=2``. The number of cases and the resolved output path are
    printed.
    """
    task_list = load_task_list(task_type)
    ground_truth = load_ground_truths(task_list, task_type)

    # Merge all attempts per model, then keep each example's top majority-vote pick.
    baseline_answers = load_predictions_for_roots(task_list, baseline_roots)
    loop_answers = load_predictions_for_roots(task_list, loop_roots)
    baseline_top = build_majority_pred_map(baseline_answers, task_list)
    loop_top = build_majority_pred_map(loop_answers, task_list)

    cases = _collect_loop_wins(task_list, ground_truth, baseline_top, loop_top)
    print(f"Found {len(cases)} cases where Loop is correct and Baseline is wrong.")

    destination = out_path
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w") as handle:
        json.dump(cases, handle, indent=2)
    print(f"Saved cases to {destination.resolve()}")


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and run ``compare_models``.

    Args:
        argv: argument list to parse, without the program name. ``None`` means
            argparse's default (``sys.argv[1:]``), so running the script behaves
            as before. Passing a list lets the CLI be driven from code.

    Exits with status 2 (via ``parser.error``) when either root list is empty
    after splitting on commas and dropping blank entries.
    """
    # ArgumentParser's first positional parameter is ``prog``. The title goes in
    # ``description`` so the usage line keeps showing the real program name.
    parser = argparse.ArgumentParser(description="Compare Loop vs Baseline on ARC-AGI")
    parser.add_argument(
        "--baseline-roots",
        type=str,
        required=True,
        help="Baseline prediction output_root(s), comma-separated, e.g. 'outputs/ARC_1_eval_ViT_attempt_0'",
    )
    parser.add_argument(
        "--loop-roots",
        type=str,
        required=True,
        help="Loop prediction output_root(s), comma-separated, e.g. 'outputs/ARC_1_eval_LoopViT_attempt_0,outputs/ARC_1_eval_LoopViT_attempt_1'",
    )
    parser.add_argument(
        "--task-type",
        type=str,
        default="ARC-AGI",
        help="Task type; only 'ARC-AGI' is currently supported.",
    )
    parser.add_argument(
        "--out-path",
        type=Path,
        default=Path("loop_better_than_baseline_cases.json"),
        help="Output JSON file for cases where Loop is correct and Baseline is wrong.",
    )
    args = parser.parse_args(argv)

    baseline_roots = [x.strip() for x in args.baseline_roots.split(",") if x.strip()]
    loop_roots = [x.strip() for x in args.loop_roots.split(",") if x.strip()]
    # An empty list would silently yield zero cases, so reject it up front.
    if not baseline_roots:
        parser.error("--baseline-roots must contain at least one non-empty path")
    if not loop_roots:
        parser.error("--loop-roots must contain at least one non-empty path")

    compare_models(
        baseline_roots=baseline_roots,
        loop_roots=loop_roots,
        task_type=args.task_type,
        out_path=args.out_path,
    )


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()