diff --git a/garden_ai/backend_client.py b/garden_ai/backend_client.py index 51893662..925029c6 100644 --- a/garden_ai/backend_client.py +++ b/garden_ai/backend_client.py @@ -6,6 +6,10 @@ from garden_ai.constants import GardenConstants from garden_ai.gardens import Garden +from garden_ai.schemas.benchmark import ( + BenchmarkResultCreateRequest, + BenchmarkResultResponse, +) from garden_ai.schemas.garden import GardenMetadata from garden_ai.schemas.hpc import HpcInvocationCreateRequest from garden_ai.schemas.modal import ( @@ -182,3 +186,9 @@ def search_gardens(self, payload: dict) -> dict: def create_hpc_invocation(self, payload: HpcInvocationCreateRequest) -> dict: response = self._post("/hpc/invocations", payload.model_dump(mode="json")) return response + + def publish_benchmark_result( + self, payload: BenchmarkResultCreateRequest + ) -> BenchmarkResultResponse: + response = self._post("/benchmarks", payload.model_dump(mode="json")) + return BenchmarkResultResponse(**response) diff --git a/garden_ai/benchmarks/__init__.py b/garden_ai/benchmarks/__init__.py new file mode 100644 index 00000000..c3281dcd --- /dev/null +++ b/garden_ai/benchmarks/__init__.py @@ -0,0 +1,110 @@ +"""Garden AI benchmarking framework. + +This module provides interfaces for running standardized benchmarks on +models hosted in Garden AI or developed locally. + +Available benchmarks: + - MatbenchDiscovery: Materials discovery benchmark suite +""" + +from typing import Any, Dict, Optional + +from garden_ai.client import GardenClient +from garden_ai.schemas.benchmark import BenchmarkResultCreateRequest + +from .matbench_discovery.enums import DatasetSize, MatbenchTask +from .matbench_discovery.tasks import MatbenchDiscovery + +__all__ = [ + "MatbenchDiscovery", + "MatbenchTask", + "DatasetSize", + "publish_benchmark_result", +] + + +def publish_benchmark_result( + result: Dict[str, Any], + model_name: str, + garden_doi: Optional[str] = None, + benchmark_name: Optional[str] = None, + task_name: Optional[str] = None, +) -> Dict[str, Any]: + """Publish benchmark results to the Garden AI backend. + + This function takes the output from a benchmark task (e.g., MatbenchDiscovery.IS2RE.remote()) + and publishes it to the Garden backend for tracking and leaderboard purposes. + + Args: + result: The output dictionary from a benchmark task. Should contain: + - 'metrics': Dictionary of benchmark metrics (F1, DAF, MAE, etc.) + - 'run_metadata': Optional run metadata (hardware, timing, cost) + - '_benchmark_info': Auto-injected benchmark/task info (if from wrapped method) + model_name: The specific name/variant of the model (e.g., "mace-mp-0-medium", "chgnet-v0.3.0"). + This is required to identify the model on the leaderboard. + garden_doi: Optional DOI for the Garden publication associated with this benchmark result. + benchmark_name: Override for benchmark name (defaults to auto-detected from result) + task_name: Override for task name (defaults to auto-detected from result) + + Returns: + Dictionary containing the response from the backend, including the result ID. + + Raises: + ValueError: If benchmark_name or task_name cannot be determined. + requests.HTTPError: If the backend request fails. + + Example: + ```python + from garden_ai.benchmarks import MatbenchDiscovery, publish_benchmark_result + + # Run a benchmark + output = MatbenchDiscovery.IS2RE.remote(...) 
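+        # `output` is a dict with "metrics", "run_metadata", and an
+        # auto-injected "_benchmark_info" block identifying the benchmark/task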
+ + # Publish the results + response = publish_benchmark_result(output, model_name="mace-medium", garden_doi="10.26311/example.doi") + print(f"Published with ID: {response['id']}") + ``` + """ + # Extract benchmark info from result or use provided overrides + benchmark_info = result.get("_benchmark_info", {}) + + final_benchmark_name = benchmark_name or benchmark_info.get("benchmark_name") + final_task_name = task_name or benchmark_info.get("task_name") + + if not final_benchmark_name: + raise ValueError( + "benchmark_name is required. Either pass it explicitly or use a result " + "from a MatbenchDiscovery task method (e.g., MatbenchDiscovery.IS2RE.remote())." + ) + + if not final_task_name: + raise ValueError( + "task_name is required. Either pass it explicitly or use a result " + "from a MatbenchDiscovery task method (e.g., MatbenchDiscovery.IS2RE.remote())." + ) + + # Inject model name into run_metadata + if "run_metadata" not in result: + result["run_metadata"] = {} + if "model" not in result["run_metadata"]: + result["run_metadata"]["model"] = {} + + result["run_metadata"]["model"]["variant"] = model_name + + # Inject garden_doi if provided + if garden_doi: + result["run_metadata"]["garden_doi"] = garden_doi + + # Create the request payload + # Note: We pass the modified result (containing metrics and metadata) as 'metrics' + # This assumes the backend handles the unified blob or we rely on the schema field description. + payload = BenchmarkResultCreateRequest( + benchmark_name=final_benchmark_name, + benchmark_task_name=final_task_name, + metrics=result, + ) + + # Get authenticated client and publish + client = GardenClient() + response = client.backend_client.publish_benchmark_result(payload) + return response.model_dump() diff --git a/garden_ai/benchmarks/matbench_discovery/__init__.py b/garden_ai/benchmarks/matbench_discovery/__init__.py new file mode 100644 index 00000000..1a5b9516 --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/__init__.py @@ -0,0 +1,10 @@ +"""Matbench Discovery benchmark adapter for Garden AI.""" + +from .enums import DatasetSize, MatbenchTask +from .tasks import MatbenchDiscovery + +__all__ = [ + "MatbenchDiscovery", + "MatbenchTask", + "DatasetSize", +] diff --git a/garden_ai/benchmarks/matbench_discovery/enums.py b/garden_ai/benchmarks/matbench_discovery/enums.py new file mode 100644 index 00000000..fed3b514 --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/enums.py @@ -0,0 +1,53 @@ +"""Enums for Matbench Discovery benchmark tasks.""" + +from enum import Enum + + +class MatbenchTask(Enum): + """Available Matbench Discovery benchmark tasks.""" + + IS2RE = "IS2RE" # Initial Structure to Relaxed Energy + RS2RE = "RS2RE" # Relaxed Structure to Relaxed Energy + S2EFS = "S2EFS" # Structure to Energy, Forces, Stress + S2EF = "S2EF" # Structure to Energy, Force + S2EFSM = "S2EFSM" # Structure to Energy, Force, Stress, Magmoms + IS2E = "IS2E" # Initial Structure to Energy + S2E = "S2E" # Structure to Energy + S2RE = "S2RE" # Structure to Relaxed Energy + RP2RE = "RP2RE" # Relaxed Prototype to Relaxed Energy + IP2E = "IP2E" # Initial Prototype to Energy + + +class DatasetSize(str, Enum): + """Predefined dataset sizes for Matbench Discovery benchmarks. + + These correspond to different subsets of the WBM test set that are commonly + used for evaluating materials discovery models. 
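+
+    Any member (or a seeded ``DatasetConfig`` returned by ``.seed()``) can be
+    passed as the ``num_structures`` argument of a benchmark task, e.g.
+    ``MatbenchDiscovery.IS2RE.remote(..., num_structures=DatasetSize.RANDOM_10K)``
+    or ``num_structures=DatasetSize.RANDOM_100.seed(7)``.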
+ """ + + FULL = "full" + """Full WBM test set (~257k structures)""" + + UNIQUE_PROTOS = "unique_protos" + """Unique prototypes subset (~215k structures) - removes duplicate prototypes""" + + RANDOM_10K = "random_10k" + """Random 10k structures from the unique prototypes subset (fixed seed)""" + + RANDOM_100 = "random_100" + """Random 100 structures for quick testing (fixed seed)""" + + def seed(self, seed: int) -> "DatasetConfig": + """Return a configuration with a custom random seed.""" + return DatasetConfig(self, seed) + + +class DatasetConfig: + """Configuration for a dataset subset with a specific random seed.""" + + def __init__(self, subset: DatasetSize, seed: int): + self.subset = subset + self.seed = seed + + def __repr__(self): + return f"{self.subset.name}(seed={self.seed})" diff --git a/garden_ai/benchmarks/matbench_discovery/examples/local_execution.py b/garden_ai/benchmarks/matbench_discovery/examples/local_execution.py new file mode 100644 index 00000000..1f482904 --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/examples/local_execution.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +"""Matbench Discovery Benchmark - Local Execution Example""" + +from garden_ai.benchmarks.matbench_discovery import MatbenchDiscovery + + +def create_mattersim_model(device): + from mattersim.forcefield import MatterSimCalculator + + return MatterSimCalculator(device=device) + + +print("Running MatterSim benchmark locally...") + +# Run IS2RE task locally +# Note: Requires a GPU or MPS if using MatterSim, or CPU if specified/supported +output = MatbenchDiscovery.IS2RE.local( + model_factory=create_mattersim_model, + model_packages="mattersim", + num_structures="random_100", +) + +if "error" in output.get("metrics", {}): + print(f"Error: {output['metrics']['error']}") +else: + print("Benchmark Results:", output.get("metrics")) diff --git a/garden_ai/benchmarks/matbench_discovery/examples/matbench_equiformerv2.py b/garden_ai/benchmarks/matbench_discovery/examples/matbench_equiformerv2.py new file mode 100644 index 00000000..b32cfed0 --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/examples/matbench_equiformerv2.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +"""Matbench Discovery Benchmark - EquiformerV2 Example""" + +from garden_ai.benchmarks.matbench_discovery import MatbenchDiscovery + + +def create_equiformerv2_model(device): + from fairchem.core.calculate.ase_calculator import Calculator # type: ignore + + # Use pre-trained checkpoint - will auto-download from HuggingFace + return Calculator( + model_name="EquiformerV2-31M-S2EF-OC20-All+MD", cpu=(device == "cpu") + ) + + +# Run S2EFS task (structure to energy/forces/stress) +output = MatbenchDiscovery.S2EFS.remote( + endpoint="anvil", + account="your-account-here", + model_factory=create_equiformerv2_model, + model_packages="fairchem-core", + num_structures="random_10k", +) + +if "error" in output.get("metrics", {}): + print(f"Error: {output['metrics']['error']}") +else: + print("Benchmark Results:", output.get("metrics")) diff --git a/garden_ai/benchmarks/matbench_discovery/examples/matbench_mace_multi_gpu.py b/garden_ai/benchmarks/matbench_discovery/examples/matbench_mace_multi_gpu.py new file mode 100644 index 00000000..90e7134b --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/examples/matbench_mace_multi_gpu.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +"""Matbench Discovery Benchmark - MACE Multi-GPU Example""" + +from rich import print + +from garden_ai.benchmarks.matbench_discovery import MatbenchDiscovery + + 
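+# Note: the factory's source is extracted with inspect.getsource() and
+# re-executed on the remote worker, so keep it self-contained: do the model
+# imports inside the function body and accept only the device string.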
+def create_mace_model(device): + from mace.calculators import mace_mp + + return mace_mp(model="medium-mpa-0", device=device, default_dtype="float64") + + +print("Running MACE benchmark on endpoint anvil...") + +results = MatbenchDiscovery.IS2RE.remote( + endpoint="anvil", + account="cis250461-gpu", + model_factory=create_mace_model, + model_packages=[ + "mace-torch", + "cuequivariance", + "cuequivariance-torch", + "cuequivariance-ops-torch-cu12", + ], + num_structures="random_100", +) + +if "error" in results.get("metrics", {}): + print(f"Error: {results['metrics']['error']}") +else: + print("Benchmark Results:", results) diff --git a/garden_ai/benchmarks/matbench_discovery/examples/matbench_mattersim.py b/garden_ai/benchmarks/matbench_discovery/examples/matbench_mattersim.py new file mode 100644 index 00000000..f9a5c4c8 --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/examples/matbench_mattersim.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +"""Matbench Discovery Benchmark - MatterSim Example""" + +from garden_ai.benchmarks.matbench_discovery import MatbenchDiscovery + + +def create_mattersim_model(device): + from mattersim.forcefield import MatterSimCalculator + + return MatterSimCalculator(device=device) + + +output = MatbenchDiscovery.IS2RE.remote( + endpoint="anvil", + account="your-account-here", + model_factory=create_mattersim_model, + model_packages="mattersim", + num_structures="random_100", +) + +if "error" in output.get("metrics", {}): + print(f"Error: {output['metrics']['error']}") +else: + print("Benchmark Results:", output) diff --git a/garden_ai/benchmarks/matbench_discovery/examples/matbench_sevennet.py b/garden_ai/benchmarks/matbench_discovery/examples/matbench_sevennet.py new file mode 100644 index 00000000..da69d7ab --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/examples/matbench_sevennet.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +"""Matbench Discovery Benchmark - SevenNet Example""" + +from garden_ai.benchmarks.matbench_discovery import MatbenchDiscovery + + +def create_sevennet_model(device): + from sevenn.calculator import SevenNetCalculator + + return SevenNetCalculator(model="7net-0", device=device) + + +output = MatbenchDiscovery.IS2RE.remote( + endpoint="anvil", + account="your-account-here", + model_factory=create_sevennet_model, + model_packages="sevenn", + num_structures="random_100", +) + +if "error" in output.get("metrics", {}): + print(f"Error: {output['metrics']['error']}") +else: + print("Benchmark Results:", output) diff --git a/garden_ai/benchmarks/matbench_discovery/tasks.py b/garden_ai/benchmarks/matbench_discovery/tasks.py new file mode 100644 index 00000000..cdf44cca --- /dev/null +++ b/garden_ai/benchmarks/matbench_discovery/tasks.py @@ -0,0 +1,1457 @@ +# /// script +# requires-python = "==3.12" +# dependencies = [ +# "groundhog-hpc", +# "ase", +# "numpy", +# "pandas", +# "scikit-learn", +# "torch", +# "matbench-discovery", +# "bibtexparser<1.4.3", +# ] +# +# [tool.hog.anvil] +# endpoint = "5aafb4c1-27b2-40d8-a038-a0277611868f" +# qos = "gpu" +# partition = "gpu" +# cores_per_node = 16 +# mem_per_mode = 32 +# scheduler_options = "#SBATCH --gpus-per-node=4\n" +# requirements = "" +# +# [tool.hog.sophia] +# endpoint = "8d07224c-ceaa-4b7f-946d-fae3f7423d5b" +# account = "Garden-Ai" +# queue = "by-gpu" +# /// + +from __future__ import annotations + +import concurrent.futures +import json +import logging +import multiprocessing +import os +import random +import sys +import time +from enum import Enum +from typing import 
TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence + +import groundhog_hpc as hog +import numpy as np + +if TYPE_CHECKING: + import pandas as pd + + +class DatasetSize(str, Enum): + """Predefined dataset sizes for Matbench Discovery benchmarks. + + These correspond to different subsets of the WBM test set that are commonly + used for evaluating materials discovery models. + """ + + FULL = "full" + """Full WBM test set (~257k structures)""" + + UNIQUE_PROTOS = "unique_protos" + """Unique prototypes subset (~215k structures) - removes duplicate prototypes""" + + RANDOM_10K = "random_10k" + """Random 10k structures from the unique prototypes subset (fixed seed)""" + + RANDOM_100 = "random_100" + """Random 100 structures for quick testing (fixed seed)""" + + def seed(self, seed: int) -> "DatasetConfig": + """Return a configuration with a custom random seed.""" + return DatasetConfig(self, seed) + + +class DatasetConfig: + """Configuration for a dataset subset with a specific random seed.""" + + def __init__(self, subset: DatasetSize, seed: int): + self.subset = subset + self.seed = seed + + def __repr__(self): + return f"{self.subset.name}(seed={self.seed})" + + +def setup_logging(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] [PID:%(process)d] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + stream=sys.stdout, + force=True, + ) + return logging.getLogger("benchmark_runner") + + +def setup_device(gpu_id: Optional[int] = None) -> str: + """Setup compute device for this process.""" + try: + import torch + + if torch.cuda.is_available(): + return f"cuda:{gpu_id}" if gpu_id is not None else "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + except ImportError: + pass + return "cpu" + + +def convert_numpy_types(obj): + """Convert numpy types to Python native types for JSON serialization.""" + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {k: convert_numpy_types(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_types(item) for item in obj] + return obj + + +# Meta metrics functions - will be injected from source for remote execution +get_hardware_info = None +extract_model_info = None +calculate_run_metadata = None + + +def _inject_meta_metrics(source: str) -> None: + """Inject meta_metrics functions from source code for remote execution.""" + global get_hardware_info, extract_model_info, calculate_run_metadata + namespace: Dict[str, Any] = {} + exec(source, namespace) + get_hardware_info = namespace["get_hardware_info"] + extract_model_info = namespace["extract_model_info"] + calculate_run_metadata = namespace["calculate_run_metadata"] + + +def _get_meta_metrics_source() -> str: + """Get source code of meta_metrics module (called locally).""" + import inspect + + from garden_ai.benchmarks.utils import meta_metrics + + return inspect.getsource(meta_metrics) + + +_MODEL_CACHE = None + + +# Helper functions from matbench-discovery/metrics/geo_opt.py and phonons.py + + +def calc_rmsd( + coords_true: np.ndarray, + coords_pred: np.ndarray, +) -> float: + """Calculate the Root Mean Square Deviation (RMSD) between two sets of coordinates. + Assumes atoms are in the same order. 
+ """ + return np.sqrt(((coords_true - coords_pred) ** 2).mean()) + + +# Metrics calculations lifted from https://github.com/janosh/matbench-discovery/blob/main/matbench_discovery/metrics/discovery.py +def classify_stable( + each_true: Sequence[float] | pd.Series | np.ndarray, + each_pred: Sequence[float] | pd.Series | np.ndarray, + *, + stability_threshold: float = 0.0, + fillna: bool = True, +) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]: + import pandas as pd + + if len(each_true) != len(each_pred): + raise ValueError(f"{len(each_true)=} != {len(each_pred)=}") + + each_true_arr, each_pred_arr = pd.Series(each_true), pd.Series(each_pred) + + if stability_threshold is None or np.isnan(stability_threshold): + raise ValueError("stability_threshold must be a real number") + actual_pos = each_true_arr <= (stability_threshold or 0) + actual_neg = each_true_arr > (stability_threshold or 0) + + model_pos = each_pred_arr <= (stability_threshold or 0) + model_neg = each_pred_arr > (stability_threshold or 0) + + if fillna: + nan_mask = np.isnan(each_pred) + model_pos[nan_mask] = False + model_neg[nan_mask] = True + + n_pos, n_neg, total = model_pos.sum(), model_neg.sum(), len(each_pred) + if n_pos + n_neg != total: + raise ValueError( + f"after filling NaNs, the sum of positive ({n_pos}) and negative " + f"({n_neg}) predictions should add up to {total=}" + ) + + true_pos = actual_pos & model_pos + false_neg = actual_pos & model_neg + false_pos = actual_neg & model_pos + true_neg = actual_neg & model_neg + + return true_pos, false_neg, false_pos, true_neg + + +# This is also coptied from the matbench-discovery repo +def stable_metrics( + each_true: Sequence[float] | pd.Series | np.ndarray, + each_pred: Sequence[float] | pd.Series | np.ndarray, + *, + stability_threshold: float = 0.0, + fillna: bool = True, + prevalence: float | None = None, +) -> dict[str, float]: + n_true_pos, n_false_neg, n_false_pos, n_true_neg = map( + sum, + classify_stable( + each_true, each_pred, stability_threshold=stability_threshold, fillna=fillna + ), + ) + + n_total_pos = n_true_pos + n_false_neg + n_total_neg = n_true_neg + n_false_pos + if prevalence is None: + prevalence = ( + n_total_pos / (n_total_pos + n_total_neg) + if (n_total_pos + n_total_neg) > 0 + else float("nan") + ) + precision = ( + n_true_pos / (n_true_pos + n_false_pos) + if (n_true_pos + n_false_pos) > 0 + else float("nan") + ) + recall = n_true_pos / n_total_pos if n_total_pos > 0 else float("nan") + + TPR = recall + FPR = n_false_pos / n_total_neg if n_total_neg > 0 else float("nan") + TNR = n_true_neg / n_total_neg if n_total_neg > 0 else float("nan") + FNR = n_false_neg / n_total_pos if n_total_pos > 0 else float("nan") + + if FPR > 0 and TNR > 0 and FPR + TNR != 1: + if abs(FPR + TNR - 1) > 1e-6: + raise ValueError(f"{FPR=} {TNR=} don't add up to 1") + + if TPR > 0 and FNR > 0 and TPR + FNR != 1: + if abs(TPR + FNR - 1) > 1e-6: + raise ValueError(f"{TPR=} {FNR=} don't add up to 1") + + is_nan = np.isnan(each_true) | np.isnan(each_pred) + each_true, each_pred = np.array(each_true)[~is_nan], np.array(each_pred)[~is_nan] + + if precision + recall == 0: + f1_score = float("nan") + else: + f1_score = 2 * (precision * recall) / (precision + recall) + + from sklearn.metrics import r2_score # type: ignore + + # Return the standard discovery metrics + return dict( + F1=f1_score, + DAF=precision / prevalence if prevalence > 0 else float("nan"), + Precision=precision, + Recall=recall, + Accuracy=( + (n_true_pos + n_true_neg) / (n_total_pos + 
n_total_neg) + if (n_total_pos + n_total_neg > 0) + else float("nan") + ), + TPR=TPR, + FPR=FPR, + TNR=TNR, + FNR=FNR, + MAE=np.abs(each_true - each_pred).mean(), + RMSE=((each_true - each_pred) ** 2).mean() ** 0.5, + **{ + "R^2": r2_score(each_true, each_pred) + if len(each_true) > 1 + else float("nan") + }, + ) + + +def _process_batch_common( + batch_id: int, + structures: List[Any], + model_config: Dict[str, Any], + num_threads: int, + compute_fn: Callable[[Any, Any], Dict[str, Any]], + task_name: str, + model_factory_source: str, +) -> Dict[str, Any]: + import logging + import os + import re + import time + + gpu_id = model_config.get("gpu_id") + if gpu_id is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + device = "cuda" + else: + device = setup_device(gpu_id) + + import torch + + os.environ["OMP_NUM_THREADS"] = str(num_threads) + torch.set_num_threads(num_threads) + + worker_logger = logging.getLogger(f"worker_{batch_id}") + worker_logger.info( + f"Started {task_name} on {device} with {len(structures)} structures. Threads: {num_threads}" + ) + + global _MODEL_CACHE + try: + if _MODEL_CACHE is None: + # Reconstruct model_factory from source code + func_name_match = re.search(r"def\s+(\w+)\s*\(", model_factory_source) + if not func_name_match: + raise ValueError( + "Could not extract function name from model_factory source" + ) + func_name = func_name_match.group(1) + + # Execute the source to define the function + local_namespace: Dict[str, Any] = {} + exec(model_factory_source, local_namespace) + model_factory = local_namespace[func_name] + + model = model_factory(device) + _MODEL_CACHE = model + else: + model = _MODEL_CACHE + except Exception as e: + worker_logger.error(f"Failed to initialize model: {e}") + raise RuntimeError(f"Model initialization failed: {e}") from e + + results = {} + batch_start = time.time() + + for i, (struct_id, atoms) in enumerate(structures): + try: + result = compute_fn(model, atoms) + results[struct_id] = result + + if (i + 1) % 10 == 0: + elapsed = time.time() - batch_start + rate = (i + 1) / elapsed if elapsed > 0 else 0 + worker_logger.info( + f"Progress: {i + 1}/{len(structures)} ({rate:.2f} struct/s)" + ) + + except Exception as e: + worker_logger.warning(f"Structure {struct_id} failed: {e}") + results[struct_id] = {"error": str(e)} + + return results + + +def get_material_ids_for_subset( + subset_type: str, seed: int = 42 +) -> Optional[List[str]]: + if subset_type == "full": + return None + + import pandas as pd + from matbench_discovery.data import DataFiles # type: ignore + + df = pd.read_csv(DataFiles.wbm_summary.path) + + if subset_type == "unique_protos": + df_filtered = df.query("unique_prototype") + return df_filtered["material_id"].tolist() + + elif subset_type == "random_10k": + df_filtered = df.query("unique_prototype") + df_sampled = df_filtered.sample(n=10000, random_state=seed) + return df_sampled["material_id"].tolist() + + elif subset_type == "random_100": + df_filtered = df.query("unique_prototype") + df_sampled = df_filtered.sample(n=100, random_state=seed) + return df_sampled["material_id"].tolist() + + else: + raise ValueError(f"Unknown subset_type: {subset_type}") + + +def _load_dataset_common( + config: Dict[str, Any], + zip_path: str, + read_format: str = "extxyz", + read_index: Optional[str | slice] = None, +) -> List[Any]: + from io import TextIOWrapper + from zipfile import ZipFile + + from ase.io import read + + dataset_subset = config.get("dataset_subset", "full") + dataset_seed = config.get("dataset_seed", 
42) + mat_ids = get_material_ids_for_subset(dataset_subset, seed=dataset_seed) + + structures = [] + + with ZipFile(zip_path, "r") as zf: + if mat_ids is None: + file_list = sorted( + zf.namelist(), + key=lambda x: int(x.split(".")[0]) + if x.split(".")[0].isdigit() + else float("inf"), + ) + # Only limit structures if explicitly specified (not when using full dataset) + if "num_structures" in config: + num_structures = config["num_structures"] + if isinstance(num_structures, int): + file_list = file_list[:num_structures] + else: + mat_id_set = set(mat_ids) + file_list = [ + f for f in zf.namelist() if f.replace(".extxyz", "") in mat_id_set + ] + + for filename in file_list: + with zf.open(filename) as f: + text_stream = TextIOWrapper(f, encoding="utf-8") + if read_index is not None: + atoms_list = read(text_stream, format=read_format, index=read_index) + if isinstance(atoms_list, list) and atoms_list: + structures.append((filename, atoms_list[-1])) + elif not isinstance(atoms_list, list): + structures.append((filename, atoms_list)) + else: + structures.append( + (filename, read(text_stream, format=read_format)) # type: ignore[arg-type] + ) + + return structures + + +# Task-specific helpers +def process_batch_relaxation( + batch_id: int, + structures: List[Any], + model_config: Dict[str, Any], + num_threads: int, + model_factory_source: str, +) -> Dict[str, Any]: + from ase.optimize import FIRE + + def compute(model, atoms): + atoms.calc = model + opt = FIRE(atoms, logfile=None) + opt.run(fmax=0.05, steps=500) + energy = atoms.get_potential_energy() + return {"energy": energy} + + return _process_batch_common( + batch_id, + structures, + model_config, + num_threads, + compute, + "relaxation", + model_factory_source, + ) + + +def process_batch_static( + batch_id: int, + structures: List[Any], + model_config: Dict[str, Any], + num_threads: int, + model_factory_source: str, +) -> Dict[str, Any]: + def compute(model, atoms): + atoms.calc = model + energy = atoms.get_potential_energy() + return {"energy": energy} + + return _process_batch_common( + batch_id, + structures, + model_config, + num_threads, + compute, + "static calculation", + model_factory_source, + ) + + +def process_batch_forces( + batch_id: int, + structures: List[Any], + model_config: Dict[str, Any], + num_threads: int, + model_factory_source: str, +) -> Dict[str, Any]: + def compute(model, atoms): + atoms.calc = model + energy = atoms.get_potential_energy() + forces = atoms.get_forces().tolist() + stress = atoms.get_stress().tolist() + return {"energy": energy, "forces": forces, "stress": stress} + + return _process_batch_common( + batch_id, + structures, + model_config, + num_threads, + compute, + "forces calculation", + model_factory_source, + ) + + +def load_dataset_wbm_initial(config: Dict[str, Any]) -> List[Any]: + from matbench_discovery.data import DataFiles # type: ignore + + return _load_dataset_common(config, DataFiles.wbm_initial_atoms.path) + + +def load_dataset_wbm_relaxed(config: Dict[str, Any]) -> List[Any]: + from matbench_discovery.data import DataFiles # type: ignore + + return _load_dataset_common(config, DataFiles.wbm_relaxed_atoms.path) + + +def load_dataset_mp_trj(config: Dict[str, Any]) -> List[Any]: + from matbench_discovery.data import DataFiles # type: ignore + + return _load_dataset_common(config, DataFiles.mp_trj_extxyz.path, read_index=":") + + +def calculate_metrics_energy( + results: Dict[str, Any], config: Dict[str, Any] +) -> Dict[str, Any]: + from io import TextIOWrapper + from zipfile import 
ZipFile + + from ase.io import read + from matbench_discovery.data import DataFiles, df_wbm # type: ignore + + if len(results) == 0: + return {"error": "No results to evaluate"} + + model_energies = {} + rmsd_list = [] + + # Calculate RMSD if positions are returned (e.g. for IS2RE) + try: + # Check if any result has positions + first_res = next(iter(results.values())) + if isinstance(first_res, dict) and "positions" in first_res: + with ZipFile(DataFiles.wbm_relaxed_atoms.path, "r") as zf: + for sid, res in results.items(): + if isinstance(res, dict) and "positions" in res: + try: + # Load GT structure + # sid is the filename in the zip (e.g. "material_id.extxyz") + with zf.open(sid) as f: + text_stream = TextIOWrapper(f, encoding="utf-8") + # Read first frame (should be only one for relaxed) + gt_atoms = read(text_stream, format="extxyz") + + pred_pos = np.array(res["positions"]) + gt_pos = gt_atoms.get_positions() # type: ignore + + if pred_pos.shape == gt_pos.shape: + # Use helper function + rmsd = calc_rmsd(gt_pos, pred_pos) + rmsd_list.append(rmsd) + except Exception: + pass + except Exception as e: + print(f"Warning: RMSD calculation failed: {e}") + + for sid, res in results.items(): + if isinstance(res, dict) and res.get("energy") is not None: + mat_id = sid.replace(".extxyz", "") + model_energies[mat_id] = res["energy"] + + if not model_energies: + return {"error": "No valid energies found in results"} + + df_wbm_indexed = df_wbm.set_index("material_id") + common_ids = list(set(model_energies.keys()) & set(df_wbm_indexed.index)) + + if not common_ids: + return {"error": "No matching IDs between results and ground truth"} + + df_subset = df_wbm_indexed.loc[common_ids] + y_pred = np.array([model_energies[mid] for mid in common_ids]) + n_atoms = df_subset["n_sites"].values + + # CRITICAL FIX: Compute formation energy error, not total energy error + # Formation energy is defined as: E_formation = E_total - Σ(n_i × E_ref_i) + # where E_ref_i are elemental reference energies in their standard states + + # Get ground truth formation energy per atom (uncorrected, matches model prediction level) + y_true_form = df_subset["e_form_per_atom_uncorrected"].values # eV/atom + + # Compute reference energy per atom from known DFT data + # E_ref_per_atom = E_total_per_atom - E_form_per_atom + y_true_total = df_subset["uncorrected_energy"].values + ref_energy_per_atom = (y_true_total / n_atoms) - y_true_form + + # Compute model's predicted formation energy per atom + # E_form_pred = E_total_pred / n_atoms - E_ref_per_atom + y_pred_form = (y_pred / n_atoms) - ref_energy_per_atom + + # Formation energy error (this is what affects stability predictions!) 
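+    # Illustrative arithmetic (made-up numbers): a 2-atom structure with
+    # uncorrected_energy = -12.0 eV and e_form_per_atom_uncorrected = -0.5 eV/atom
+    # gives ref_energy_per_atom = (-12.0 / 2) - (-0.5) = -5.5 eV/atom; a predicted
+    # total of -11.8 eV then yields y_pred_form = (-11.8 / 2) - (-5.5) = -0.4
+    # eV/atom, so the error below is +0.1 eV/atom and shifts each_pred up by that.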
+ e_form_error = y_pred_form - y_true_form + + # Predict energy above hull by adding formation energy error to ground truth hull distance + each_true = df_subset["e_above_hull_mp2020_corrected_ppd_mp"].values + each_pred = each_true + e_form_error + + df_unique = df_wbm.query("unique_prototype") + stable_count = (df_unique["e_above_hull_mp2020_corrected_ppd_mp"] <= 0).sum() + global_prevalence = stable_count / len(df_unique) + + metrics = stable_metrics(each_true, each_pred, prevalence=global_prevalence) + metrics["num_evaluated"] = len(common_ids) + + # Inject RMSD + metrics["RMSD"] = float(np.mean(rmsd_list)) if rmsd_list else float("nan") + + return metrics + + +def calculate_metrics_forces( + results: Dict[str, Any], config: Dict[str, Any] +) -> Dict[str, Any]: + from io import TextIOWrapper + from zipfile import ZipFile + + from ase.io import read + from matbench_discovery.data import DataFiles # type: ignore + + # We will use the standard stable_metrics for energy predictions in the trajectory + all_e_pred: List[float] = [] + all_e_true: List[float] = [] + + zip_path = DataFiles.mp_trj_extxyz.path + + with ZipFile(zip_path, "r") as zf: + for sid, res in results.items(): + if "error" in res: + continue + try: + if isinstance(res, dict) and "energy" in res: + with zf.open(sid) as f: + text_stream = TextIOWrapper(f, encoding="utf-8") + atoms_list = read(text_stream, format="extxyz", index=":") + gt_atoms = atoms_list[-1] # type: ignore + + e_pred = res["energy"] + e_true = gt_atoms.get_potential_energy() # type: ignore + n_atoms = len(gt_atoms) # type: ignore + + # Normalize per atom + all_e_pred.append(e_pred / n_atoms) + all_e_true.append(e_true / n_atoms) + except Exception: + pass + + if not all_e_true: + return {"error": "No valid energy comparisons found"} + + each_true = np.array(all_e_true) + each_pred = np.array(all_e_pred) + + # Calculate standard discovery metrics on energies + metrics = stable_metrics(each_true, each_pred) + metrics["num_evaluated"] = len(all_e_true) + return metrics + + +def run_benchmark_hog( + config: Dict[str, Any], + model_packages: str | List[str], + model_factory_source: str, + meta_metrics_source: str, + load_dataset_fn: Any, + process_fn: Any, + calc_metrics_fn: Any, +) -> Dict[str, Any]: + logger = setup_logging() + logger.info("Starting benchmark runner...") + + # Inject meta_metrics functions from source + _inject_meta_metrics(meta_metrics_source) + + # Collect hardware and model info + assert get_hardware_info is not None, "meta_metrics not injected" + assert extract_model_info is not None, "meta_metrics not injected" + hardware_info = get_hardware_info() + model_info = extract_model_info(model_packages) + logger.info(f"Hardware: {hardware_info}") + logger.info(f"Model: {model_info}") + + # Install model packages if specified + if model_packages: + import subprocess + + packages = ( + model_packages if isinstance(model_packages, list) else [model_packages] + ) + logger.info(f"Installing model packages: {packages}") + try: + result = subprocess.run( + ["uv", "pip", "install"] + packages, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + if result.returncode != 0: + error_msg = ( + f"Failed to install model packages: {packages}\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + logger.info("Model packages installed successfully") + except subprocess.TimeoutExpired: + error_msg = f"Model package installation timed out after 300s: {packages}" + 
logger.error(error_msg) + raise RuntimeError(error_msg) + except Exception as e: + if isinstance(e, RuntimeError): + raise # Re-raise our own errors + error_msg = f"Could not install model packages: {e}" + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + # Fix SSL certificate issues on HPC nodes using certifi + try: + import ssl + + import certifi + + os.environ["SSL_CERT_FILE"] = certifi.where() + os.environ["REQUESTS_CA_BUNDLE"] = certifi.where() + ssl._create_default_https_context = ssl.create_default_context + logger.info(f"SSL certificates configured: {certifi.where()}") + except ImportError: + logger.warning("certifi not available, SSL issues may occur") + + checkpoint_path = config.get("checkpoint_path") + results = {} + prior_elapsed = 0.0 # Cumulative time from previous sessions + + if checkpoint_path and os.path.exists(checkpoint_path): + logger.info(f"Loading checkpoint from {checkpoint_path}") + try: + with open(checkpoint_path) as f: + checkpoint_data = json.load(f) + + # Handle new format with metadata vs old format (plain results dict) + if "_checkpoint_meta" in checkpoint_data: + results = checkpoint_data.get("results", {}) + meta = checkpoint_data["_checkpoint_meta"] + prior_elapsed = meta.get("elapsed_seconds", 0.0) + logger.info( + f"Found {len(results)} processed items in checkpoint " + f"(prior elapsed: {prior_elapsed:.1f}s)" + ) + else: + # Backward compatibility: old format is plain results dict + results = checkpoint_data + logger.info( + f"Found {len(results)} processed items in checkpoint (legacy format)" + ) + except Exception as e: + logger.warning(f"Failed to load checkpoint: {e}. Starting fresh.") + + try: + all_items = load_dataset_fn(config) + logger.info(f"Loaded {len(all_items)} total items") + except Exception as e: + logger.error(f"Failed to load dataset: {e}") + import traceback + + traceback.print_exc() + raise + + items_to_process = [ + (item_id, item) for item_id, item in all_items if str(item_id) not in results + ] + + try: + import torch + + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + num_gpus = 1 + else: + num_gpus = 0 + except ImportError: + num_gpus = 0 + + use_multi_gpu = config.get("use_multi_gpu", True) and num_gpus > 1 + # Use sched_getaffinity to get cores available to this job, not total cores on node + try: + total_cores = len(os.sched_getaffinity(0)) # type: ignore[attr-defined] + except AttributeError: + # Fallback for systems without sched_getaffinity (e.g., macOS) + total_cores = os.cpu_count() or 1 + num_workers = num_gpus if use_multi_gpu else 1 + available_cores = max(1, total_cores - 2) if total_cores > 4 else total_cores + threads_per_worker = max(1, available_cores // num_workers) + + # MPS (Apple Silicon) performance degrades with high thread counts due to contention + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + threads_per_worker = 1 + + logger.info( + f"Resources: {num_gpus} GPUs, {total_cores} Cores. Using {num_workers} workers ({threads_per_worker} threads/worker)" + ) + + if not items_to_process: + logger.info( + "All items already processed! Calculating metrics from checkpoint..." 
+ ) + + # Calculate metrics from checkpoint results + try: + metrics = calc_metrics_fn(results, config) + logger.info(f"Metrics calculated: {metrics}") + except Exception as e: + logger.error(f"Failed to calculate metrics: {e}") + import traceback + + traceback.print_exc() + metrics = {"error": f"Metrics calculation failed: {e}"} + + # Use cumulative values from checkpoint metadata + assert calculate_run_metadata is not None, "meta_metrics not injected" + run_metadata = calculate_run_metadata( + hardware_info=hardware_info, + model_info=model_info, + total_elapsed=prior_elapsed, + num_workers=num_workers, + num_structures_total=len(all_items), + num_structures_processed=len(results), + ) + return {"metrics": metrics, "run_metadata": run_metadata} + + logger.info(f"Processing {len(items_to_process)} remaining items") + + random.seed(42) + random.shuffle(items_to_process) + + start_time = time.time() + chunk_size = 1000 * num_workers + chunks = [ + items_to_process[i : i + chunk_size] + for i in range(0, len(items_to_process), chunk_size) + ] + + ctx = multiprocessing.get_context("spawn") + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers, mp_context=ctx + ) as executor: + for chunk_idx, chunk in enumerate(chunks): + chunk_start = time.time() + logger.info( + f"Starting chunk {chunk_idx + 1}/{len(chunks)} ({len(chunk)} items)" + ) + + futures = [] + batch_size = (len(chunk) + num_workers - 1) // num_workers + + for i in range(num_workers): + start = i * batch_size + end = min((i + 1) * batch_size, len(chunk)) + if start < end: + batch = chunk[start:end] + worker_config = config.copy() + worker_config["gpu_id"] = i if use_multi_gpu else None + futures.append( + executor.submit( + process_fn, + i, + batch, + worker_config, + threads_per_worker, + model_factory_source, + ) + ) + + chunk_results = {} + for future in concurrent.futures.as_completed(futures): + try: + batch_res = future.result() + chunk_results.update(batch_res) + except Exception as e: + logger.error(f"Worker failed in chunk {chunk_idx}: {e}") + raise RuntimeError( + "Aborting benchmark due to worker failure" + ) from e + + results.update(chunk_results) + + if checkpoint_path: + try: + tmp_path = checkpoint_path + ".tmp" + # Calculate cumulative elapsed time for checkpoint + current_elapsed = time.time() - start_time + cumulative_elapsed = prior_elapsed + current_elapsed + + # Save checkpoint with metadata for resume + checkpoint_data = { + "results": convert_numpy_types(results), + "_checkpoint_meta": { + "elapsed_seconds": cumulative_elapsed, + "structures_processed": len(results), + }, + } + with open(tmp_path, "w") as f: + json.dump(checkpoint_data, f, indent=2) + os.replace(tmp_path, checkpoint_path) + logger.info(f"Checkpoint saved to {checkpoint_path}") + except Exception as e: + logger.error(f"Failed to save checkpoint: {e}") + raise RuntimeError( + f"Critical: Failed to save checkpoint to {checkpoint_path}. " + f"Aborting to prevent loss of progress. Error: {e}" + ) from e + + elapsed = time.time() - chunk_start + logger.info(f"Chunk {chunk_idx + 1} complete in {elapsed:.1f}s") + + session_elapsed = time.time() - start_time + total_elapsed = prior_elapsed + session_elapsed + logger.info( + f"Session complete in {session_elapsed:.1f}s. " + f"Total elapsed: {total_elapsed:.1f}s." 
+ ) + + logger.info("Calculating metrics...") + try: + metrics = calc_metrics_fn(results, config) + logger.info(f"Metrics calculated: {metrics}") + except Exception as e: + logger.error(f"Failed to calculate metrics: {e}") + import traceback + + traceback.print_exc() + metrics = {"error": f"Metrics calculation failed: {e}"} + + # Calculate run metadata using cumulative values + assert calculate_run_metadata is not None, "meta_metrics not injected" + run_metadata = calculate_run_metadata( + hardware_info=hardware_info, + model_info=model_info, + total_elapsed=total_elapsed, + num_workers=num_workers, + num_structures_total=len(all_items), + num_structures_processed=len(results), + ) + logger.info(f"Run metadata: {run_metadata}") + + output = {"metrics": metrics, "run_metadata": run_metadata} + output = convert_numpy_types(output) + return output + + +class BenchmarkMethod: + """Wrapper around groundhog Method that handles source extraction for remote execution.""" + + BENCHMARK_NAME = "matbench_discovery" + + def __init__(self, hog_method, task_name: str): + """Initialize wrapper with the underlying groundhog Method. + + Args: + hog_method: The underlying groundhog method to wrap. + task_name: Name of the benchmark task (e.g., 'IS2RE', 'S2EFS'). + """ + self._hog_method = hog_method + self._task_name = task_name + + def _extract_sources(self, kwargs): + """Extract source code from model_factory and meta_metrics for remote execution.""" + import inspect + + # Extract model_factory source + if "model_factory" in kwargs: + factory = kwargs["model_factory"] + if callable(factory) and not isinstance(factory, str): + try: + kwargs["model_factory"] = inspect.getsource(factory) + except (OSError, TypeError) as e: + raise ValueError( + f"Could not extract source code from model_factory. " + f"Ensure the function is defined in a file (not interactive/lambda). " + f"Error: {e}" + ) + + # Extract meta_metrics source (runs locally where garden_ai is available) + kwargs["meta_metrics_source"] = _get_meta_metrics_source() + + return kwargs + + def _get_checkpoint_info_for_display(self, kwargs, is_remote: bool): + """Get checkpoint information to display to the user. 
+ + Args: + kwargs: Method keyword arguments + is_remote: True if this is a remote/submit call, False for local + + Returns: + Tuple of (display_message, checkpoint_identifier, is_resuming) + """ + checkpoint_path = kwargs.get("checkpoint_path") + checkpoint_name = kwargs.get("checkpoint_name") + + if checkpoint_path: + # User provided explicit path + if is_remote: + msg = f"Resuming from checkpoint on remote system: {checkpoint_path}" + else: + msg = f"Resuming from checkpoint: {checkpoint_path}" + return msg, checkpoint_path, True + + # Generate checkpoint name + if not checkpoint_name: + model_packages = kwargs.get("model_packages", "") + num_structures = kwargs.get("num_structures", "full") + + # Determine subset string for checkpoint name + subset = "full" + if isinstance(num_structures, str): + subset = num_structures + elif hasattr(num_structures, "value"): # DatasetSize enum + subset = num_structures.value + elif hasattr(num_structures, "subset"): # DatasetConfig + subset = num_structures.subset.value + elif isinstance(num_structures, int): + subset = "full" if num_structures >= 200000 else f"num_{num_structures}" + + # Extract model name from packages + model_str = "unknown" + if isinstance(model_packages, list): + model_str = "_".join( + pkg.split("/")[-1].split("@")[0] for pkg in model_packages[:2] + ) + elif isinstance(model_packages, str): + model_str = model_packages.split("/")[-1].split("@")[0] + + # Generate timestamp and uuid like in _generate_checkpoint_name + import time + import uuid + + timestamp = time.strftime("%Y%m%d_%H%M%S") + short_uuid = str(uuid.uuid4())[:8] + checkpoint_name = ( + f"matbench_{model_str}_{subset}_{timestamp}_{short_uuid}.json" + ) + + # Construct display message + if is_remote: + msg = f"Checkpoint will be saved on remote system: ~/.garden/benchmarks/{checkpoint_name}" + identifier = f"~/.garden/benchmarks/{checkpoint_name}" + else: + local_path = os.path.expanduser(f"~/.garden/benchmarks/{checkpoint_name}") + msg = f"Checkpoint will be saved locally: {local_path}" + identifier = local_path + + return msg, identifier, False + + def _print_checkpoint_info(self, kwargs, is_remote: bool): + """Print checkpoint information before execution. 
+ + Args: + kwargs: Method keyword arguments + is_remote: True if this is a remote/submit call, False for local + """ + msg, identifier, is_resuming = self._get_checkpoint_info_for_display( + kwargs, is_remote + ) + + print("=" * 80) + if is_resuming: + print(f"📂 {msg}") + else: + print(f"💾 {msg}") + + if is_remote: + print(" To resume this benchmark if it fails, use:") + print(f' checkpoint_path="{identifier}"') + print(" Note: Checkpoint is on the remote system, not your local machine") + else: + print(" To resume this benchmark if it fails, use:") + print(f' checkpoint_path="{identifier}"') + print("=" * 80) + + def _add_benchmark_info(self, result: Dict[str, Any]) -> Dict[str, Any]: + """Add benchmark metadata to the result for publishing.""" + if isinstance(result, dict): + result["_benchmark_info"] = { + "benchmark_name": self.BENCHMARK_NAME, + "task_name": self._task_name, + } + return result + + def remote(self, *args, **kwargs): + """Execute remotely with automatic source extraction.""" + kwargs = self._extract_sources(kwargs) + self._print_checkpoint_info(kwargs, is_remote=True) + result = self._hog_method.remote(*args, **kwargs) + return self._add_benchmark_info(result) + + def local(self, *args, **kwargs): + """Execute locally with automatic source extraction.""" + kwargs = self._extract_sources(kwargs) + self._print_checkpoint_info(kwargs, is_remote=False) + result = self._hog_method.local(*args, **kwargs) + return self._add_benchmark_info(result) + + def submit(self, *args, **kwargs): + """Submit for async execution with automatic source extraction.""" + kwargs = self._extract_sources(kwargs) + self._print_checkpoint_info(kwargs, is_remote=True) + return self._hog_method.submit(*args, **kwargs) + + def __call__(self, *args, **kwargs): + """Direct call (for local execution within groundhog).""" + return self._hog_method(*args, **kwargs) + + +class _MatbenchDiscoveryBase: + """Matbench Discovery tasks using Groundhog HPC.""" + + REPO_URL = "https://github.com/janosh/matbench-discovery" + REPO_REF = "main" + + @staticmethod + def _prepare_runner_config( + num_structures: int | "DatasetSize" | "DatasetConfig" | str, + repo_url: str = REPO_URL, + repo_ref: str = REPO_REF, + ) -> Dict[str, Any]: + """Prepare the runner configuration based on num_structures.""" + # Need to handle DatasetSize/Config which might be passed as objects or values + # Since we are in the script, we might not have the enums imported if they are not in this file. + # But the user passes them. + # If they are passed as arguments, they are serialized. + # We need to extract value. + + # Simple heuristic: if it has 'value' attr, use it. 
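+        # For example: "random_100" -> subset "random_100"; DatasetSize.RANDOM_10K
+        # -> subset "random_10k"; DatasetSize.RANDOM_10K.seed(7) -> subset
+        # "random_10k" with dataset_seed=7; a plain int keeps the full subset and
+        # is applied later as a simple limit in the dataset loader.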
+ subset = "full" + seed = 42 + + if isinstance(num_structures, str): + # String value like "random_100" - use directly as subset + subset = num_structures + elif hasattr(num_structures, "value"): # Enum + subset = num_structures.value + # Check for seed method/attr if it's our custom Config + if hasattr(num_structures, "seed"): + if callable(num_structures.seed): + pass # It's the method + else: + seed = num_structures.seed + elif hasattr(num_structures, "subset"): # DatasetConfig + subset = num_structures.subset.value # type: ignore[union-attr] + seed = num_structures.seed # type: ignore[union-attr] + elif isinstance(num_structures, int): + subset = "full" + # We handle int as limit in load_dataset + return { + "repo_url": repo_url, + "repo_ref": repo_ref, + "num_structures": num_structures, + "dataset_subset": "full", + } + + return { + "repo_url": repo_url, + "repo_ref": repo_ref, + "dataset_subset": subset, + "dataset_seed": seed, + } + + @staticmethod + def _generate_checkpoint_name( + model_packages: str | List[str], runner_config: Dict[str, Any] + ) -> str: + import time + import uuid + + model_str = ( + str(model_packages) + .replace("[", "") + .replace("]", "") + .replace("'", "") + .replace('"', "") + .replace(",", "_") + .replace(" ", "") + ) + subset_str = runner_config.get("dataset_subset", "custom") + timestamp = int(time.time()) + short_uuid = str(uuid.uuid4())[:8] + return f"matbench_{model_str}_{subset_str}_{timestamp}_{short_uuid}.json" + + @staticmethod + def _run_task( + model_factory: Any, + model_packages: str | List[str], + num_structures: int | str | DatasetSize | DatasetConfig, + checkpoint_name: str | None, + checkpoint_path: str | None, + process_fn: Any, + load_dataset_fn: Any, + calc_metrics_fn: Any, + sys_path: List[str] | None = None, + meta_metrics_source: str | None = None, + ) -> Dict[str, Any]: + import inspect + + # Handle model_factory as either a callable or source string + if isinstance(model_factory, str): + model_factory_source = model_factory + else: + try: + model_factory_source = inspect.getsource(model_factory) + except (OSError, TypeError) as e: + raise ValueError( + f"Could not extract source code from model_factory. " + f"For remote execution, use: inspect.getsource(your_factory). " + f"Error: {e}" + ) + + # Add custom sys.path if provided + if sys_path: + import sys + + for p in sys_path: + if p not in sys.path: + sys.path.append(p) + + runner_config = MatbenchDiscovery._prepare_runner_config(num_structures) + + if not checkpoint_name and not checkpoint_path: + checkpoint_name = MatbenchDiscovery._generate_checkpoint_name( + model_packages, runner_config + ) + + if checkpoint_path: + # Always expand tilde to home directory + final_checkpoint_path = os.path.expanduser(checkpoint_path) + if os.path.exists(final_checkpoint_path): + print(f"Resuming from checkpoint: {final_checkpoint_path}") + else: + print( + f"WARNING: Checkpoint file not found at {final_checkpoint_path}. " + f"Starting fresh and will save checkpoints to this path." 
+ ) + # Ensure directory exists for new checkpoint + os.makedirs(os.path.dirname(final_checkpoint_path), exist_ok=True) + else: + print( + f"Checkpoint will be saved to: ~/.garden/benchmarks/{checkpoint_name}" + ) + final_checkpoint_path = os.path.expanduser( + f"~/.garden/benchmarks/{checkpoint_name}" + ) + os.makedirs(os.path.dirname(final_checkpoint_path), exist_ok=True) + + # Validate we can write to the checkpoint path early to fail fast + try: + test_file = final_checkpoint_path + ".write_test" + with open(test_file, "w") as f: + f.write("test") + os.remove(test_file) + except Exception as e: + raise RuntimeError( + f"Cannot write to checkpoint path: {final_checkpoint_path}. " + f"Check permissions and disk space. Error: {e}" + ) from e + + runner_config["checkpoint_path"] = final_checkpoint_path + + # meta_metrics_source is injected by BenchmarkMethod wrapper + if meta_metrics_source is None: + raise ValueError("meta_metrics_source required for benchmark execution") + + return run_benchmark_hog( + runner_config, + model_packages, + model_factory_source, + meta_metrics_source, + load_dataset_fn, + process_fn, + calc_metrics_fn, + ) + + @hog.method() + def IS2RE( + model_factory: Any, + model_packages: str | List[str], + num_structures: int | str | DatasetSize | DatasetConfig = "full", + checkpoint_name: str | None = None, + checkpoint_path: str | None = None, + sys_path: List[str] | None = None, + meta_metrics_source: str | None = None, + ) -> Dict[str, Any]: + """Initial Structure to Relaxed Energy.""" + return MatbenchDiscovery._run_task( + model_factory, + model_packages, + num_structures, + checkpoint_name, + checkpoint_path, + process_batch_relaxation, + load_dataset_wbm_initial, + calculate_metrics_energy, + sys_path=sys_path, + meta_metrics_source=meta_metrics_source, + ) + + @hog.method() + def RS2RE( + model_factory: Any, + model_packages: str | List[str], + num_structures: int | str | DatasetSize | DatasetConfig = "full", + checkpoint_name: str | None = None, + checkpoint_path: str | None = None, + sys_path: List[str] | None = None, + meta_metrics_source: str | None = None, + ) -> Dict[str, Any]: + """Relaxed Structure to Relaxed Energy.""" + return MatbenchDiscovery._run_task( + model_factory, + model_packages, + num_structures, + checkpoint_name, + checkpoint_path, + process_batch_static, + load_dataset_wbm_relaxed, + calculate_metrics_energy, + sys_path=sys_path, + meta_metrics_source=meta_metrics_source, + ) + + @hog.method() + def S2EFS( + model_factory: Any, + model_packages: str | List[str], + num_structures: int | str | DatasetSize | DatasetConfig = "full", + checkpoint_name: str | None = None, + checkpoint_path: str | None = None, + sys_path: List[str] | None = None, + meta_metrics_source: str | None = None, + ) -> Dict[str, Any]: + """Structure to Energy, Forces, Stress.""" + return MatbenchDiscovery._run_task( + model_factory, + model_packages, + num_structures, + checkpoint_name, + checkpoint_path, + process_batch_forces, + load_dataset_mp_trj, + calculate_metrics_forces, + sys_path=sys_path, + meta_metrics_source=meta_metrics_source, + ) + + # Aliases + @hog.method() + def S2EF(*args, **kwargs): + return _MatbenchDiscoveryBase.S2EFS(*args, **kwargs) + + @hog.method() + def S2EFSM(*args, **kwargs): + return _MatbenchDiscoveryBase.S2EFS(*args, **kwargs) + + @hog.method() + def IS2E(*args, **kwargs): + # IS2E is Initial Structure to Energy (Static). 
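+        # Same static single-point evaluation as RS2RE, but over the *initial*
+        # (unrelaxed) WBM structures rather than the relaxed ones.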
+ return _MatbenchDiscoveryBase._run_task( + *args, + **kwargs, + process_fn=process_batch_static, + load_dataset_fn=load_dataset_wbm_initial, + calc_metrics_fn=calculate_metrics_energy, + ) + + @hog.method() + def S2E(*args, **kwargs): + # Structure to Energy (Relaxed Structure to Energy) -> RS2RE + return _MatbenchDiscoveryBase.RS2RE(*args, **kwargs) + + @hog.method() + def S2RE(*args, **kwargs): + # Structure to Relaxed Energy -> IS2RE + return _MatbenchDiscoveryBase.IS2RE(*args, **kwargs) + + @hog.method() + def RP2RE(*args, **kwargs): + return _MatbenchDiscoveryBase.IS2RE(*args, **kwargs) + + @hog.method() + def IP2E(*args, **kwargs): + return _MatbenchDiscoveryBase.IS2E(*args, **kwargs) + + +class MatbenchDiscovery: + """Matbench Discovery benchmark tasks. + + This class provides wrapped methods that automatically handle model_factory + source extraction for remote execution. Users can pass callable functions + directly without needing to call inspect.getsource() themselves. + + Example: + def create_mace_model(device): + from mace.calculators import mace_mp + return mace_mp(model="medium", device=device) + + results = MatbenchDiscovery.IS2RE.remote( + endpoint="your-endpoint-id", + model_factory=create_mace_model, + model_packages="mace-torch", + ) + """ + + REPO_URL = _MatbenchDiscoveryBase.REPO_URL + REPO_REF = _MatbenchDiscoveryBase.REPO_REF + + # Internal methods (needed for remote execution compatibility) + _prepare_runner_config = _MatbenchDiscoveryBase._prepare_runner_config + _generate_checkpoint_name = _MatbenchDiscoveryBase._generate_checkpoint_name + _run_task = _MatbenchDiscoveryBase._run_task + + # Main benchmark tasks - wrapped for automatic model_factory source extraction + IS2RE = BenchmarkMethod(_MatbenchDiscoveryBase.IS2RE, "IS2RE") + RS2RE = BenchmarkMethod(_MatbenchDiscoveryBase.RS2RE, "RS2RE") + S2EFS = BenchmarkMethod(_MatbenchDiscoveryBase.S2EFS, "S2EFS") + + # Aliases + S2EF = BenchmarkMethod(_MatbenchDiscoveryBase.S2EF, "S2EF") + S2EFSM = BenchmarkMethod(_MatbenchDiscoveryBase.S2EFSM, "S2EFSM") + IS2E = BenchmarkMethod(_MatbenchDiscoveryBase.IS2E, "IS2E") + S2E = BenchmarkMethod(_MatbenchDiscoveryBase.S2E, "S2E") + S2RE = BenchmarkMethod(_MatbenchDiscoveryBase.S2RE, "S2RE") + RP2RE = BenchmarkMethod(_MatbenchDiscoveryBase.RP2RE, "RP2RE") + IP2E = BenchmarkMethod(_MatbenchDiscoveryBase.IP2E, "IP2E") diff --git a/garden_ai/benchmarks/utils/meta_metrics.py b/garden_ai/benchmarks/utils/meta_metrics.py new file mode 100644 index 00000000..951d549f --- /dev/null +++ b/garden_ai/benchmarks/utils/meta_metrics.py @@ -0,0 +1,162 @@ +"""Meta-level benchmark metrics utilities. + +Shared utilities for collecting hardware info, estimating costs, and extracting +model metadata that can be reused across different benchmark implementations. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, List + +# GPU hourly cost estimates (USD) - Modal pricing (https://modal.com/pricing) +GPU_HOURLY_COSTS = { + "B200": 6.25, # $0.001736/sec + "H200": 4.54, # $0.001261/sec + "H100": 3.95, # $0.001097/sec + "A100-80GB": 2.50, # $0.000694/sec (80GB variant) + "A100": 2.10, # $0.000583/sec (40GB variant) + "L40S": 1.95, # $0.000542/sec + "A10": 1.10, # $0.000306/sec + "L4": 0.80, # $0.000222/sec + "T4": 0.59, # $0.000164/sec + "default": 2.00, # Fallback for unknown GPUs +} + +# Model name inference from package names +MODEL_PACKAGE_NAMES = { + "mace": "MACE", + "mattersim": "MatterSim", + "sevennet": "SevenNet", + "chgnet": "CHGNet", + "equiformer": "EquiformerV2", + "orb": "Orb", + "m3gnet": "M3GNet", + "alignn": "ALIGNN", +} + + +def get_hardware_info() -> Dict[str, Any]: + """Collect hardware information about the execution environment. + + Returns: + Dictionary containing: + - device_type: "cuda", "mps", or "cpu" + - num_gpus: Number of GPUs available + - gpu_names: List of GPU names + - gpu_memory_gb: Memory of first GPU in GB (if available) + """ + info = {"device_type": "cpu", "num_gpus": 0, "gpu_names": [], "gpu_memory_gb": None} + try: + import torch + + if torch.cuda.is_available(): + info["device_type"] = "cuda" + num_gpus = torch.cuda.device_count() + info["num_gpus"] = num_gpus + info["gpu_names"] = [torch.cuda.get_device_name(i) for i in range(num_gpus)] + if num_gpus > 0: + props = torch.cuda.get_device_properties(0) + info["gpu_memory_gb"] = round(props.total_memory / (1024**3), 1) + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + info["device_type"] = "mps" + info["num_gpus"] = 1 + info["gpu_names"] = ["Apple Metal Performance Shaders"] + except ImportError: + pass + return info + + +def get_gpu_hourly_cost(gpu_name: str) -> float: + """Estimate hourly cost for a GPU based on its name. + + Args: + gpu_name: GPU name string (e.g., "NVIDIA A100-SXM4-40GB") + + Returns: + Estimated hourly cost in USD + """ + gpu_name_upper = gpu_name.upper() + for key in GPU_HOURLY_COSTS: + if key != "default" and key.upper() in gpu_name_upper: + return GPU_HOURLY_COSTS[key] + return GPU_HOURLY_COSTS["default"] + + +def extract_model_info(model_packages: str | List[str]) -> Dict[str, Any]: + """Extract model info from package specification. + + Args: + model_packages: Package name(s) used to install the model + + Returns: + Dictionary containing: + - model_name: Inferred model name or "unknown" + - model_packages: List of package names + """ + packages = model_packages if isinstance(model_packages, list) else [model_packages] + model_name = "unknown" + for pkg in packages: + pkg_lower = pkg.lower() + for key, name in MODEL_PACKAGE_NAMES.items(): + if key in pkg_lower: + model_name = name + break + if model_name != "unknown": + break + return {"model_name": model_name, "model_packages": packages} + + +def calculate_run_metadata( + hardware_info: Dict[str, Any], + model_info: Dict[str, Any], + total_elapsed: float, + num_workers: int, + num_structures_total: int, + num_structures_processed: int, +) -> Dict[str, Any]: + """Calculate run metadata including timing, cost, and hardware info. 
+ + Args: + hardware_info: Output from get_hardware_info() + model_info: Output from extract_model_info() + total_elapsed: Total benchmark runtime in seconds + num_workers: Number of worker processes used + num_structures_total: Total structures in dataset + num_structures_processed: Structures processed in this run + + Returns: + Complete run_metadata dictionary + """ + throughput = num_structures_total / total_elapsed if total_elapsed > 0 else 0 + + # Calculate cost estimate + gpu_hourly_cost = ( + get_gpu_hourly_cost(hardware_info["gpu_names"][0]) + if hardware_info["gpu_names"] + else 0 + ) + total_gpu_hours = (total_elapsed / 3600) * num_workers + total_cost = total_gpu_hours * gpu_hourly_cost + cost_per_1k = ( + (total_cost / num_structures_total) * 1000 if num_structures_total > 0 else 0 + ) + + return { + "model": model_info, + "hardware": hardware_info, + "timing": { + "total_seconds": round(total_elapsed, 2), + "throughput_per_second": round(throughput, 3), + "num_workers": num_workers, + }, + "cost": { + "gpu_hourly_rate_usd": gpu_hourly_cost, + "total_gpu_hours": round(total_gpu_hours, 4), + "estimated_cost_usd": round(total_cost, 4), + "estimated_cost_per_1000_structures_usd": round(cost_per_1k, 4), + }, + "dataset": { + "num_structures_total": num_structures_total, + "num_structures_processed": num_structures_processed, + }, + } diff --git a/garden_ai/client.py b/garden_ai/client.py index 87b8e0e4..2b3ee425 100644 --- a/garden_ai/client.py +++ b/garden_ai/client.py @@ -23,9 +23,9 @@ from globus_sdk.authorizers import GlobusAuthorizer from globus_sdk.scopes import ScopeBuilder from globus_sdk.tokenstorage import SimpleJSONFileAdapter -from modal.cli._traceback import setup_rich_traceback from rich import print from rich.prompt import Prompt +from rich.traceback import install from garden_ai.backend_client import BackendClient from garden_ai.constants import GardenConstants @@ -34,8 +34,7 @@ from garden_ai.hpc.gardens.mlip_garden import MLIPGarden logger = logging.getLogger() -# modal helper replacement for rich.traceback.install -setup_rich_traceback() +install() class AuthException(Exception): diff --git a/garden_ai/py.typed b/garden_ai/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/garden_ai/schemas/benchmark.py b/garden_ai/schemas/benchmark.py new file mode 100644 index 00000000..283f7fd3 --- /dev/null +++ b/garden_ai/schemas/benchmark.py @@ -0,0 +1,36 @@ +"""Benchmark-related schemas for API requests/responses.""" + +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + + +class BenchmarkResultCreateRequest(BaseModel): + """Request schema for publishing benchmark results to the backend.""" + + benchmark_name: str = Field( + ..., + description="Name of the benchmark suite (e.g., 'matbench_discovery')", + ) + benchmark_task_name: str = Field( + ..., + description="Name of the specific task within the benchmark (e.g., 'IS2RE', 'S2EFS')", + ) + metrics: Dict[str, Any] = Field( + ..., + description="Dictionary of benchmark metrics (F1, DAF, MAE, etc.)", + ) + run_metadata: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional run metadata (hardware info, timing, cost estimates)", + ) + + +class BenchmarkResultResponse(BaseModel): + """Response schema from the benchmark result creation endpoint.""" + + id: int = Field(..., description="Unique identifier for the benchmark result") + benchmark_name: str + benchmark_task_name: str + metrics: Dict[str, Any] + run_metadata: Optional[Dict[str, Any]] = None diff --git 
a/pyproject.toml b/pyproject.toml index 0075640e..5e44603e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ # used transitively by modal -> grpclib, force 4.3.0 to reslove CVE-2025-57804 # Can remove once we upgrade to more current modal sdk version "h2>=4.3.0", - "groundhog-hpc>=0.5.0", + "groundhog-hpc>=0.5.6", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index f8fb93e5..d17d8a4f 100644 --- a/uv.lock +++ b/uv.lock @@ -1098,7 +1098,7 @@ requires-dist = [ { name = "gitpython", specifier = ">=3.1.35,<4.0.0" }, { name = "globus-compute-sdk", specifier = ">=4.0.0" }, { name = "globus-sdk", specifier = ">=3.34.0,<4.0.0" }, - { name = "groundhog-hpc", specifier = ">=0.5.0" }, + { name = "groundhog-hpc", specifier = ">=0.5.6" }, { name = "h2", specifier = ">=4.3.0" }, { name = "huggingface-hub", specifier = "==0.18.0" }, { name = "ipython", specifier = "<8.13" }, @@ -1251,7 +1251,7 @@ wheels = [ [[package]] name = "groundhog-hpc" -version = "0.5.4" +version = "0.5.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "globus-compute-sdk" }, @@ -1265,9 +1265,9 @@ dependencies = [ { name = "typer" }, { name = "uv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/4a/79c3bef59e0e4e538875949cec290d26cf5cefd601135e190723f9fc89de/groundhog_hpc-0.5.4.tar.gz", hash = "sha256:1f9ef486a6b62a3f28168689425b9b838c1abe92d76291a392299c70a4f5a0ec", size = 31705, upload-time = "2025-11-06T23:31:08.795Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/e7/adf855aaded946d2cff12851320c7b53114fed40d4b833efaf0081bc3aea/groundhog_hpc-0.5.6.tar.gz", hash = "sha256:cc5a25c0dfc6a0ddc641e631cc7dae1466e81b8f24984f102eb8300cf6340b42", size = 32346, upload-time = "2025-12-09T18:49:49.554Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/c4/abece517b27357edc102891244233e68a12de57c0cee7bd4a404ad86bb74/groundhog_hpc-0.5.4-py3-none-any.whl", hash = "sha256:287c91211f2d64fb89b84b3be0cd26611cd3c1a29c76b83efe1219f3ad8fc53f", size = 44364, upload-time = "2025-11-06T23:31:07.512Z" }, + { url = "https://files.pythonhosted.org/packages/7c/13/702590a7f6064c01609379225c678b42e6c1a56e72e85a8d14a23ec9213a/groundhog_hpc-0.5.6-py3-none-any.whl", hash = "sha256:d6347031c1f779e24379fd9619ca59dc2dce8df521fcd0c6cb51b00a7e807eab", size = 45086, upload-time = "2025-12-09T18:49:50.346Z" }, ] [[package]]
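To tie the new pieces together, here is a minimal usage sketch, assuming a valid endpoint ID; the `create_mace_model` factory and `"mace-torch"` package follow the `MatbenchDiscovery` docstring example above, and the result keys (`metrics`, `run_metadata`) are assumed to match what the runner and the `BenchmarkResultCreateRequest` schema expect:

```python
from garden_ai.benchmarks import MatbenchDiscovery
from garden_ai.schemas.benchmark import BenchmarkResultCreateRequest


def create_mace_model(device):
    # Built on the remote worker, so the model import stays inside the factory.
    from mace.calculators import mace_mp

    return mace_mp(model="medium", device=device)


# Placeholder endpoint ID and structure count.
results = MatbenchDiscovery.IS2RE.remote(
    endpoint="your-endpoint-id",
    model_factory=create_mace_model,
    model_packages="mace-torch",
    num_structures=1000,
)

# Cost estimate, if the runner attached run_metadata:
# (elapsed_seconds / 3600) * num_workers * gpu_hourly_rate.
print(results.get("run_metadata", {}).get("cost"))

# Validate the payload against the new schema before sending it to the backend.
payload = BenchmarkResultCreateRequest(
    benchmark_name="matbench_discovery",
    benchmark_task_name="IS2RE",
    metrics=results.get("metrics", results),
    run_metadata=results.get("run_metadata"),
)
```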