From 6f48eb3a1eae1b755b7a06822e21a97ae93df445 Mon Sep 17 00:00:00 2001 From: AmeenP Date: Wed, 15 Oct 2025 14:58:00 -0700 Subject: [PATCH] [ENG-2059] Implement crash-safe checkpointing for vf-eval Add SimpleCheckpoint system with automatic resume capability: Core Implementation: - verifiers/utils/checkpoint.py: SimpleCheckpoint class with immediate fsync writes - verifiers/scripts/eval.py: Integrated checkpointing with simplified CLI - Simplified to 3 parameters: --output-dir, --checkpoint-every, --seed Key Features: - Crash-safe: Immediate append + fsync for both successes and failures - Auto-resume: Signature-based validation, skips completed work - Auto-retry: Failed items automatically retried on resume - Always skip-on-error: Failures don't crash evaluation Exit Codes: - 0: All items completed successfully - 1: Some items failed (check failures.jsonl) - 2: Partial completion (interrupted, can resume) Files Created: - results.jsonl: All successful completions (append-only) - failures.jsonl: Current failures (snapshot at checkpoint) - manifest.json: Run config and counters (atomic writes) --- verifiers/scripts/eval.py | 380 +++++++++++++++++++++++++--------- verifiers/utils/checkpoint.py | 221 ++++++++++++++++++++ 2 files changed, 507 insertions(+), 94 deletions(-) create mode 100644 verifiers/utils/checkpoint.py diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index 628a690b5..328baba30 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -1,8 +1,11 @@ import argparse +import asyncio +import contextlib import importlib import importlib.util import json import logging +import sys import time import uuid from datetime import datetime @@ -11,10 +14,12 @@ import numpy as np from datasets import Dataset +from openai import AsyncOpenAI import verifiers as vf from verifiers import setup_logging from verifiers.types import Endpoints +from verifiers.utils.checkpoint import SimpleCheckpoint, RunConfig, _json_sha256 from verifiers.utils.client_utils import setup_client from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls @@ -22,6 +27,76 @@ logger = logging.getLogger("verifiers.scripts.eval") +def dataset_fingerprint(dataset: Dataset) -> str: + """Generate a fingerprint for a dataset.""" + # Use dataset info if available (HuggingFace datasets) + if hasattr(dataset, "info") and hasattr(dataset.info, "config_name"): + return f"{dataset.info.builder_name}@{dataset.info.config_name}" + # Fall back to hash of first few rows + sample = dataset[:min(5, len(dataset))] + return _json_sha256(sample) + + +def resolve_indices(dataset: Dataset, num_examples: int, seed: int | None) -> list[int]: + """Resolve dataset indices deterministically.""" + n = len(dataset) + if num_examples > 0: + n = min(num_examples, n) + indices = list(range(n)) + if seed is not None: + import random + + rng = random.Random(seed) + rng.shuffle(indices) + return indices + + +def aggregate_from_jsonl(path: Path) -> dict: + """Aggregate results from a JSONL file.""" + rewards = [] + metrics_lists: dict = {} + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + if row.get("status") == "ok": + rewards.append(row.get("metrics", {}).get("reward", 0.0)) + for k, v in row.get("metrics", {}).items(): + if k != "reward": + if k not in metrics_lists: + metrics_lists[k] = [] + metrics_lists[k].append(v) + except Exception as e: + logger.warning(f"Failed to parse line: {e}") + continue + result = {"rewards": rewards, "metrics": metrics_lists} + if rewards: + result["avg_reward"] = sum(rewards) / len(rewards) + result["std_reward"] = np.std(rewards) + return result + + +def default_run_dir(env_id: str, model_name: str) -> Path: + """Generate default run directory path.""" + module_name = env_id.replace("-", "_") + env_model_str = f"{env_id}--{model_name.replace('/', '--')}" + uuid_str = str(uuid.uuid4())[:8] + return Path("./outputs") / "evals" / env_model_str / uuid_str + + +def infer_stage(e: Exception) -> str: + """Infer which stage an error occurred in.""" + error_str = str(e).lower() + if "timeout" in error_str or "connection" in error_str: + return "inference" + if "parse" in error_str or "json" in error_str: + return "parse" + return "rubric" + + def eval_environment( env: str, env_args: dict, @@ -41,6 +116,10 @@ def eval_environment( save_to_hf_hub: bool, hf_hub_dataset_name: str, extra_headers: Dict[str, str], + # Simplified checkpoint parameters + output_dir: str | None, + checkpoint_every: int, + seed: int, ): setup_logging("DEBUG" if verbose else "INFO") try: @@ -97,6 +176,8 @@ def eval_environment( extra_headers=extra_headers, ) logger.debug(f"Initialized OpenAI client with base_url: {api_base_url}") + async_client = AsyncOpenAI(api_key=client.api_key, base_url=str(client.base_url)) + vf_env = vf.load_environment(env_id=env, **env_args) # Merge sampling args with precedence to JSON payload over explicit flags merged_sampling_args: dict = {} @@ -107,116 +188,201 @@ def eval_environment( if temperature is not None and "temperature" not in merged_sampling_args: merged_sampling_args["temperature"] = temperature - logger.info(f"Starting evaluation with model: {model}") - logger.info( - f"Configuration: num_examples={num_examples}, rollouts_per_example={rollouts_per_example}, max_concurrent={max_concurrent}" - ) - start_time = time.time() - results = vf_env.evaluate( - client=client, + # Get dataset and resolve indices + if vf_env.eval_dataset is None: + logger.info("eval_dataset is not set, falling back to train dataset") + dataset = vf_env.get_dataset(n=num_examples, seed=seed) + else: + dataset = vf_env.get_eval_dataset(n=num_examples, seed=seed) + + # Build deterministic work keys: "idx/roll" + indices = resolve_indices(dataset, num_examples, seed) + all_keys = [f"{i}/{r}" for i in indices for r in range(rollouts_per_example)] + + # Build run configuration for checkpointing + ds_fp = dataset_fingerprint(dataset) + idx_sha = _json_sha256(indices) + cfg = RunConfig( + env_id=env, + split="eval" if vf_env.eval_dataset is not None else "train", + env_args=env_args, + dataset_fingerprint=ds_fp, + indices_sha256=idx_sha, model=model, sampling_args=merged_sampling_args, - num_examples=num_examples, + num_examples=len(indices), rollouts_per_example=rollouts_per_example, + seed=seed, max_concurrent=max_concurrent, + verifiers_version=vf.__version__, + env_version=getattr(vf_env, "__version__", None), + ) + + # Setup output directory and checkpoint writer + run_dir = Path(output_dir) if output_dir else default_run_dir(env, model) + + # SimpleCheckpoint handles resume automatically based on manifest + cp = SimpleCheckpoint(run_dir, cfg, checkpoint_every=checkpoint_every) + + worklist = cp.pending_keys(all_keys) + if not worklist: + logger.info("All items already completed. Nothing to do.") + print("All items already completed. Nothing to do.") + return + + logger.info(f"Starting evaluation with model: {model}") + logger.info( + f"Configuration: num_examples={num_examples}, rollouts_per_example={rollouts_per_example}, max_concurrent={max_concurrent}" ) + logger.info(f"Pending items: {len(worklist)} / {len(all_keys)}") + start_time = time.time() + + # Run async evaluation with checkpointing + async def run_evaluation(): + sem = asyncio.Semaphore(max_concurrent) + + async def process(key: str): + idx, roll = map(int, key.split("/")) + example = dataset[idx] + async with sem: + try: + prompt = example["prompt"] + answer = example.get("answer", "") + task = example.get("task", "default") + info = example.get("info", {}) + + # Run single rollout + completion, state = await vf_env.rollout( + client=async_client, + model=model, + prompt=prompt, + answer=answer, + task=task, + info=info, + sampling_args=merged_sampling_args, + ) + + # Score the rollout + rollout_score = await vf_env.rubric.score_rollout( + prompt=prompt, + completion=completion, + answer=answer, + state=state, + task=task, + info=info, + ) + + # Record success + await cp.queue.put( + ( + "ok", + { + "key": key, + "idx": idx, + "rollout": roll, + "status": "ok", + "request": {"prompt": prompt, "sampling_args": merged_sampling_args}, + "completion": sanitize_tool_calls(completion), + "parsed": state.get("parsed", {}), + "metrics": { + "reward": rollout_score.reward, + **rollout_score.metrics, + }, + "timing": state.get("timing", {}), + "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ"), + }, + ) + ) + except Exception as e: + # Always skip-on-error (simplified design) + await cp.queue.put(("error", { + "key": key, + "idx": idx, + "rollout": roll, + "status": "error", + "stage": infer_stage(e), + "error": f"{type(e).__name__}: {e}", + "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ"), + })) + + writer_task = asyncio.create_task(cp.run(total_items=len(all_keys))) + try: + from tqdm.asyncio import tqdm_asyncio + + await tqdm_asyncio.gather( + *[process(k) for k in worklist], + total=len(worklist), + desc="Running rollouts", + ) + finally: + writer_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await writer_task + + # Execute async evaluation + try: + loop = asyncio.get_running_loop() + import nest_asyncio # type: ignore + + nest_asyncio.apply() + loop.run_until_complete(run_evaluation()) + except RuntimeError: + asyncio.run(run_evaluation()) + end_time = time.time() logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds") + + # Aggregate results from JSONL + summary = aggregate_from_jsonl(run_dir / "results.jsonl") + print("--- Evaluation ---") print(f"Environment: {env}") print(f"Model: {model}") print(f"Provider: {api_base_url}") print(f"Examples: {num_examples}") print(f"Rollouts per example: {rollouts_per_example}") - print("--- Example ---") - printable_prompts = [messages_to_printable(p) for p in results.prompt] - printable_completions = [messages_to_printable(c) for c in results.completion] - vf.print_prompt_completions_sample( - printable_prompts, printable_completions, results.reward, step=0 - ) - print("--- All ---") - print("Rewards:") - print( - f"reward: avg - {sum(results.reward) / len(results.reward):.3f}, std - {np.std(results.reward):.3f}" - ) - r = rollouts_per_example - n = len(results.reward) // r - for i in range(r): - # rounded to 3 decimal places - trials = [round(results.reward[(i * n) + j], 3) for j in range(n)] - out = f"r{i + 1}: {trials}" - print(out) - for k in results.metrics: - v = results.metrics[k] - print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}") + print(f"Output directory: {run_dir}") + print() + print("--- Summary ---") + if "avg_reward" in summary: + print( + f"reward: avg - {summary['avg_reward']:.3f}, std - {summary['std_reward']:.3f}" + ) + print(f"Total completed: {len(summary['rewards'])}") + print(f"Total failed: {cp.num_failed}") + else: + print("No results found.") + + # Print per-rollout breakdown if we have enough data + if "rewards" in summary and len(summary["rewards"]) >= rollouts_per_example: + rewards = summary["rewards"] + r = rollouts_per_example + n = len(rewards) // r + print("\nRewards by rollout:") for i in range(r): - # rounded to 3 decimal places - trials = [round(v[(i * n) + j], 3) for j in range(n)] - out = f"r{i + 1}: {trials}" - print(out) + trials = [round(rewards[(i * n) + j], 3) for j in range(n)] + print(f"r{i + 1}: {trials[:10]}{'...' if len(trials) > 10 else ''}") + # Print metrics breakdown + for k, v in summary.get("metrics", {}).items(): + if v: + print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}") + + # Note: --save-dataset and --save-to-hf-hub are not yet implemented with checkpointing + # The results are already saved to the output directory in JSONL format if save_dataset or save_to_hf_hub: - ids = [i // rollouts_per_example for i in range(n * rollouts_per_example)] - rewards = results.reward - tasks = results.task - data_dict = { - "id": ids, - "prompt": [sanitize_tool_calls(p) for p in printable_prompts], - "completion": [sanitize_tool_calls(c) for c in printable_completions], - "task": tasks, - "generation_ms": [s["timing"]["generation_ms"] for s in results.state], - "scoring_ms": [s["timing"]["scoring_ms"] for s in results.state], - "total_ms": [s["timing"]["total_ms"] for s in results.state], - } - if results.info[0] != {}: - data_dict["info"] = results.info - if results.answer[0] != "": - data_dict["answer"] = results.answer - data_dict["reward"] = rewards - for k in results.metrics: - v = results.metrics[k] - data_dict[k] = v - - dataset = Dataset.from_dict(data_dict) - metadata = { - "env": env, - "model": model, - "num_examples": n, - "rollouts_per_example": rollouts_per_example, - "sampling_args": merged_sampling_args, - "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "time_ms": (end_time - start_time) * 1000, - "avg_reward": sum(results.reward) / len(results.reward), - } - for k in results.metrics: - metadata[f"avg_{k}"] = sum(results.metrics[k]) / len(results.metrics[k]) - - uuid_str = str(uuid.uuid4())[:8] - env_model_str = f"{env}--{model.replace('/', '--')}" - if save_dataset: - module_name = env.replace("-", "_") - local_env_dir = Path(env_dir_path) / module_name - if local_env_dir.exists(): - results_path = ( - local_env_dir / "outputs" / "evals" / env_model_str / uuid_str - ) - else: - results_path = Path("./outputs") / "evals" / env_model_str / uuid_str - results_path.parent.mkdir(parents=True, exist_ok=True) - dataset.to_json(results_path / "results.jsonl") - with open(results_path / "metadata.json", "w") as f: - json.dump(metadata, f) - - logger.info(f"Saved dataset to {results_path}") - if save_to_hf_hub: - if hf_hub_dataset_name == "": - dataset_name = ( - f"{env}_{model.replace('/', '-')}_n{n}_r{rollouts_per_example}" - ) - else: - dataset_name = hf_hub_dataset_name - dataset.push_to_hub(dataset_name) - logger.info(f"Saved dataset to Hugging Face Hub: {dataset_name}") + logger.warning( + "--save-dataset and --save-to-hf-hub are not yet implemented with checkpointing. " + f"Results are saved in JSONL format at {run_dir}" + ) + + # Exit with appropriate code + if cp.num_failed > 0: + pending = len(cp.pending_keys(all_keys)) + if pending == 0: + sys.exit(1) # completed with failures + else: + sys.exit(2) # interrupted/partial def main(): @@ -337,6 +503,28 @@ def main(): default="", help="Name of dataset to save to Hugging Face Hub", ) + + # Checkpointing arguments (simplified) + checkpoint_group = parser.add_argument_group("checkpointing") + checkpoint_group.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory to write artifacts; if it already exists, resume automatically (default: auto-generated)", + ) + checkpoint_group.add_argument( + "--checkpoint-every", + type=int, + default=50, + help="Rewrite failures + manifest every N finished items (default: 50)", + ) + checkpoint_group.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for deterministic dataset ordering (default: 42)", + ) + args = parser.parse_args() # Build headers from repeated --header flags @@ -369,6 +557,10 @@ def main(): save_to_hf_hub=args.save_to_hf_hub, hf_hub_dataset_name=args.hf_hub_dataset_name, extra_headers=merged_headers, + # Simplified checkpoint parameters + output_dir=args.output_dir, + checkpoint_every=args.checkpoint_every, + seed=args.seed, ) diff --git a/verifiers/utils/checkpoint.py b/verifiers/utils/checkpoint.py new file mode 100644 index 000000000..35f9bcbea --- /dev/null +++ b/verifiers/utils/checkpoint.py @@ -0,0 +1,221 @@ +# verifiers/utils/checkpoint.py +from __future__ import annotations + +import asyncio +import contextlib +import hashlib +import json +import os +import tempfile +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple + + +def _json_sha256(obj: Any) -> str: + """Compute SHA256 hash of a JSON-serializable object.""" + b = json.dumps(obj, sort_keys=True, separators=(",", ":")).encode("utf-8") + return "sha256:" + hashlib.sha256(b).hexdigest() + + +def _atomic_write_json(path: Path, obj: Dict[str, Any]) -> None: + """Atomically write a JSON object to a file.""" + path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + "w", delete=False, dir=str(path.parent), encoding="utf-8" + ) as f: + json.dump(obj, f) + f.flush() + os.fsync(f.fileno()) + tmp = Path(f.name) + os.replace(tmp, path) + + +def _atomic_write_text(path: Path, text: str) -> None: + """Atomically write text to a file.""" + path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + "w", delete=False, dir=str(path.parent), encoding="utf-8" + ) as f: + f.write(text) + f.flush() + os.fsync(f.fileno()) + tmp = Path(f.name) + os.replace(tmp, path) + + +def _scan_success_keys(results_path: Path) -> Set[str]: + """Scan results.jsonl for successful completion keys (ground truth for resume).""" + done: Set[str] = set() + if not results_path.exists(): + return done + with results_path.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + line = line.strip() + if not line: + continue + with contextlib.suppress(Exception): + row = json.loads(line) + k = row.get("key") + if isinstance(k, str): + done.add(k) + return done + + +@dataclass +class RunConfig: + """Configuration for a checkpoint run.""" + + env_id: str + split: str + env_args: Dict[str, Any] + dataset_fingerprint: str + indices_sha256: str + model: str + sampling_args: Dict[str, Any] + num_examples: int + rollouts_per_example: int + seed: int + max_concurrent: int + verifiers_version: str + env_version: Optional[str] = None + + +class SimpleCheckpoint: + """ + Single-writer checkpoint manager with simplified semantics: + - Appends successes and failures immediately per-item (crash-safe) + - Rewrites failures.jsonl at checkpoints to contain only *current* failures + - Auto-resumes based on results.jsonl (ground truth) + - Always skip-on-error (failures don't crash the run) + """ + + def __init__( + self, + out_dir: Path, + run_cfg: RunConfig, + checkpoint_every: int = 50, + ): + self.out_dir = out_dir + self.out_dir.mkdir(parents=True, exist_ok=True) + self.results = out_dir / "results.jsonl" + self.failures = out_dir / "failures.jsonl" + self.manifest = out_dir / "manifest.json" + self.cfg = asdict(run_cfg) + self.signature = _json_sha256(self.cfg) + + # Validate signature if resuming + if self.manifest.exists(): + prev = json.loads(self.manifest.read_text()) + if prev.get("signature") != self.signature: + raise SystemExit( + "Output directory contains a different run (signature mismatch). " + "Choose a new --output-dir to start fresh." + ) + + # Resume: successes define ground truth of 'done' + self.completed: Set[str] = _scan_success_keys(self.results) + + # Current failure records: key -> last error record + self.fail_records: Dict[str, Dict[str, Any]] = {} + + # Single writer queue + self.queue: "asyncio.Queue[Tuple[str, Dict[str, Any]]]" = asyncio.Queue() + + # Open files for append + self._ok = self.results.open("a", encoding="utf-8", buffering=1) + self._ko = self.failures.open("a", encoding="utf-8", buffering=1) + + self._checkpoint_every = max(1, int(checkpoint_every)) + self._since = 0 + + # Initial manifest + self._write_manifest(total_items=0) + + @property + def num_done(self) -> int: + """Total items processed (successes + failures).""" + return len(self.completed) + len(self.fail_records) + + @property + def num_failed(self) -> int: + """Current failure count.""" + return len(self.fail_records) + + def pending_keys(self, all_keys: Iterable[str]) -> List[str]: + """Return keys not yet completed (failures are retried automatically).""" + return [k for k in all_keys if k not in self.completed] + + async def run(self, total_items: int) -> None: + """Main writer loop: consume queue and write items.""" + try: + while True: + kind, rec = await self.queue.get() + k = rec["key"] + + if kind == "ok": + # Append success immediately + self._ok.write(json.dumps(rec, ensure_ascii=False) + "\n") + self._ok.flush() + os.fsync(self._ok.fileno()) + self.completed.add(k) + # If it previously failed, forget the failure + self.fail_records.pop(k, None) + else: + # Append failure immediately for crash-safety + self._ko.write(json.dumps(rec, ensure_ascii=False) + "\n") + self._ko.flush() + os.fsync(self._ko.fileno()) + self.fail_records[k] = rec + + self._since += 1 + if self._since >= self._checkpoint_every: + self._checkpoint(total_items) + finally: + # Final checkpoint on cancellation/exit + self._checkpoint(total_items) + self._ok.close() + self._ko.close() + + def _checkpoint(self, total_items: int) -> None: + """Rewrite failures.jsonl as snapshot of current failures + update manifest.""" + # 1) Rewrite failures.jsonl as a *snapshot* of current failures + self._ko.close() # close append handle before replacing the file + + snapshot = "" + if self.fail_records: + snapshot = "\n".join( + json.dumps(v, ensure_ascii=False) for v in self.fail_records.values() + ) + "\n" + _atomic_write_text(self.failures, snapshot) + + # Reopen append handle for subsequent immediate failure writes + self._ko = self.failures.open("a", encoding="utf-8", buffering=1) + + # 2) Update manifest atomically + self._write_manifest(total_items) + self._since = 0 + + def _write_manifest(self, total_items: int) -> None: + """Write manifest with current counters.""" + man = { + "version": 1, + "signature": self.signature, + "config": self.cfg, + "counters": { + "total": total_items, + "done": len(self.completed), + "failed": len(self.fail_records), + }, + "paths": { + "results": str(self.results), + "failures": str(self.failures), + }, + "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ"), + } + _atomic_write_json(self.manifest, man) + + def immediate_flush(self, total_items: int): + """Force an immediate checkpoint (for graceful shutdown).""" + self._checkpoint(total_items)