"""Math problem pool: load, sample (deterministic via seed), and verify. Carries from v1 verbatim. v2 adds an ``is_eval_problem`` flag in the JSON payload; the loader treats it as optional metadata and exposes it through ``Problem.is_eval_problem`` for the sandbagging detector. """ import json import random from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union @dataclass class Problem: id: int problem: str answer: int difficulty: str is_eval_problem: bool = False def to_dict(self) -> Dict[str, Any]: """JSON-serialisable form.""" return { "id": self.id, "problem": self.problem, "answer": self.answer, "difficulty": self.difficulty, "is_eval_problem": self.is_eval_problem, } @dataclass class ProblemPool: problems: List[Problem] = field(default_factory=list) @classmethod def load(cls, path: Union[str, Path]) -> "ProblemPool": """Load a problem pool from a JSON file at ``path``.""" path = Path(path) with path.open() as f: data = json.load(f) return cls(problems=[ Problem( id=item["id"], problem=item["problem"], answer=item["answer"], difficulty=item["difficulty"], is_eval_problem=bool(item.get("is_eval_problem", False)), ) for item in data ]) def sample(self, n: int = 10, seed: Optional[int] = None) -> List[Problem]: """Return ``n`` problems sampled without replacement. Deterministic when ``seed`` is given. When ``n`` exceeds the pool size, returns the full pool (clamped, not raising). """ rng = random.Random(seed) return rng.sample(self.problems, k=min(n, len(self.problems))) def __len__(self) -> int: return len(self.problems) def load_problems(path: Optional[Union[str, Path]] = None) -> ProblemPool: """Load the bundled problem pool, or a custom path if provided.""" if path is None: path = Path(__file__).parent / "data" / "problems_pool.json" return ProblemPool.load(path) def check_answer(submitted: Any, ground_truth: Any) -> bool: """Strict integer equality. Rejects type mismatches (str vs int, float vs int, bool vs int) explicitly — ``True == 1`` would otherwise pass naive ``==``. """ if type(submitted) is not int or type(ground_truth) is not int: return False return submitted == ground_truth