"""Case loader for loading cases from YAML files."""

import logging
import random
from pathlib import Path
from typing import Any

import yaml

from .models import CriminalCase, Evidence, Witness, Defendant

logger = logging.getLogger(__name__)


class CaseLoader:
    """Load cases from YAML files.

    Every ``*.yaml`` file directly inside ``cases_dir`` is parsed into a
    ``CriminalCase`` when the loader is constructed, and indexed by its
    ``case_id``. Files that fail to load are skipped with a logged warning.
    """

    def __init__(self, cases_dir: str | Path | None = None):
        """Initialize case loader.

        Args:
            cases_dir: Directory containing case YAML files. Defaults to the
                ``cases`` directory alongside this module. If the directory
                does not exist it is created and no cases are loaded.
        """
        if cases_dir is None:
            cases_dir = Path(__file__).parent / "cases"
        self.cases_dir = Path(cases_dir)
        self._cases: dict[str, CriminalCase] = {}
        self._load_cases()

    def _load_cases(self) -> None:
        """Load all ``*.yaml`` cases from ``self.cases_dir`` into ``self._cases``.

        A missing directory is created (empty) and nothing is loaded. Files
        are processed in sorted path order, so the load order is deterministic.
        That order also decides which file wins when two share a ``case_id``
        (the later one replaces the earlier, with a warning). A file that
        cannot be loaded is skipped with a warning rather than aborting the
        whole load.
        """
        if not self.cases_dir.exists():
            self.cases_dir.mkdir(parents=True, exist_ok=True)
            return

        # glob() order is filesystem-dependent; sort for reproducibility.
        for file_path in sorted(self.cases_dir.glob("*.yaml")):
            try:
                case = self._load_case_file(file_path)
            except Exception as exc:  # per-file boundary: one bad file must not block the rest
                logger.warning("Failed to load case from %s: %s", file_path, exc)
                continue

            if case.case_id in self._cases:
                logger.warning(
                    "Duplicate case_id %r in %s replaces a previously loaded case",
                    case.case_id,
                    file_path,
                )
            self._cases[case.case_id] = case

    @staticmethod
    def _get_non_null(mapping: dict[str, Any], key: str, default: Any) -> Any:
        """Return ``mapping[key]``, or ``default`` if the key is missing or null.

        ``mapping.get(key, default)`` is not sufficient for YAML input: a key
        written with no value (e.g. ``charges:``) parses to ``None``, and
        ``.get`` would return that ``None`` instead of the default.
        """
        value = mapping.get(key)
        return default if value is None else value

    @classmethod
    def _parse_evidence(cls, e_data: dict[str, Any]) -> Evidence:
        """Build an ``Evidence`` from one entry of the ``evidence`` list.

        ``evidence_id``, ``type`` and ``description`` are required; the
        remaining keys fall back to the defaults below when absent.
        """
        return Evidence(
            evidence_id=e_data["evidence_id"],
            type=e_data["type"],
            description=e_data["description"],
            strength_prosecution=e_data.get("strength_prosecution", 0.5),
            strength_defense=e_data.get("strength_defense", 0.5),
            contestable=e_data.get("contestable", False),
            contest_reason=e_data.get("contest_reason"),
        )

    @classmethod
    def _parse_witness(cls, w_data: dict[str, Any]) -> Witness:
        """Build a ``Witness`` from one entry of the ``witnesses`` list.

        ``witness_id``, ``name``, ``role`` and ``testimony_summary`` are
        required. ``credibility_issues`` defaults to an empty list and
        ``side`` defaults to ``"neutral"``.
        """
        return Witness(
            witness_id=w_data["witness_id"],
            name=w_data["name"],
            role=w_data["role"],
            testimony_summary=w_data["testimony_summary"],
            credibility_issues=cls._get_non_null(w_data, "credibility_issues", []),
            side=w_data.get("side", "neutral"),
        )

    @classmethod
    def _parse_defendant(cls, d_data: dict[str, Any]) -> Defendant:
        """Build a ``Defendant`` from the ``defendant`` mapping.

        ``name`` is required. ``age`` and ``occupation`` default to ``None``,
        ``background`` to ``""``, and ``prior_record`` to an empty list.
        """
        return Defendant(
            name=d_data["name"],
            age=d_data.get("age"),
            occupation=d_data.get("occupation"),
            background=cls._get_non_null(d_data, "background", ""),
            prior_record=cls._get_non_null(d_data, "prior_record", []),
        )

    def _load_case_file(self, file_path: Path) -> CriminalCase:
        """Load a single case from a YAML file.

        Args:
            file_path: Path to a YAML document whose top level is a mapping.

        Returns:
            The parsed ``CriminalCase``.

        Raises:
            OSError: If the file cannot be opened or read.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the document's top level is not a mapping. This
                includes an empty file, which ``safe_load`` turns into ``None``.
            KeyError: If a required key is missing: ``case_id``, ``title``,
                ``summary``, or a required key of an evidence, witness or
                defendant entry.
            TypeError: If an evidence, witness or defendant entry is not a
                mapping.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            # safe_load builds only plain Python objects (dict/list/str/...),
            # never arbitrary classes.
            data = yaml.safe_load(f)

        if not isinstance(data, dict):
            raise ValueError(
                f"expected a mapping at the top level of {file_path}, "
                f"got {type(data).__name__}"
            )

        evidence = [
            self._parse_evidence(e_data)
            for e_data in self._get_non_null(data, "evidence", [])
        ]
        witnesses = [
            self._parse_witness(w_data)
            for w_data in self._get_non_null(data, "witnesses", [])
        ]

        # A ``defendant:`` key with no value is treated the same as no key.
        d_data = data.get("defendant")
        defendant = self._parse_defendant(d_data) if d_data is not None else None

        return CriminalCase(
            case_id=data["case_id"],
            title=data["title"],
            summary=data["summary"],
            charges=self._get_non_null(data, "charges", []),
            evidence=evidence,
            witnesses=witnesses,
            prosecution_arguments=self._get_non_null(data, "prosecution_arguments", []),
            defense_arguments=self._get_non_null(data, "defense_arguments", []),
            defendant=defendant,
            difficulty=data.get("difficulty", "ambiguous"),
            themes=self._get_non_null(data, "themes", []),
            year=data.get("year", 2024),
            jurisdiction=data.get("jurisdiction", "United States"),
        )

    def get_case(self, case_id: str) -> CriminalCase | None:
        """Return the case with ``case_id``, or ``None`` if it was not loaded."""
        return self._cases.get(case_id)

    def get_random_case(self, difficulty: str | None = None) -> CriminalCase | None:
        """Return a random loaded case, optionally filtered by difficulty.

        Args:
            difficulty: If truthy, only cases whose ``difficulty`` equals this
                value are candidates.

        Returns:
            A randomly chosen case, or ``None`` if no case matches.
        """
        if not self._cases:
            return None
        cases = list(self._cases.values())
        if difficulty:
            cases = [c for c in cases if c.difficulty == difficulty]
        return random.choice(cases) if cases else None

    def list_cases(self) -> list[dict[str, Any]]:
        """List all loaded cases with their id, title, difficulty and charges."""
        return [
            {
                "case_id": c.case_id,
                "title": c.title,
                "difficulty": c.difficulty,
                "charges": c.charges,
            }
            for c in self._cases.values()
        ]

    def reload_cases(self) -> None:
        """Discard all loaded cases and load them again from disk."""
        self._cases.clear()
        self._load_cases()