Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get rid of SurrogateRunner and SurrogateBenchmarkProblem #2954

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions ax/benchmark/problems/surrogate.py

This file was deleted.

142 changes: 2 additions & 140 deletions ax/benchmark/runners/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@
# pyre-strict

from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from typing import Any
from dataclasses import dataclass

import torch
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.observation import ObservationFeatures
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import assert_is_instance, none_throws
from pyre_extensions import none_throws
from torch import Tensor


Expand Down Expand Up @@ -100,137 +96,3 @@ def __eq__(self, other: Base) -> bool:

# Don't check surrogate, datasets, or callable
return self.name == other.name


@dataclass
class SurrogateRunner(BenchmarkRunner):
    """Runner for surrogate benchmark problems.

    Args:
        name: The name of the runner.
        outcome_names: The names of the outcomes of the Surrogate.
        _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
            for generating observations. If `None`, `get_surrogate_and_datasets`
            must not be None and will be used to generate the surrogate when it
            is needed.
        _datasets: Either `None`, or the `SupervisedDataset`s used to fit
            the surrogate model. If `None`, `get_surrogate_and_datasets` must
            not be None and will be used to generate the datasets when they are
            needed.
        noise_stds: Noise standard deviations to add to the surrogate output(s).
            If a single number is provided, noise with that standard deviation
            is added to all outputs. Alternatively, a dictionary mapping outcome
            names to noise standard deviations can be provided to specify
            different noise levels for different outputs.
        get_surrogate_and_datasets: Function that returns the surrogate and
            datasets, to allow for lazy construction. If
            `get_surrogate_and_datasets` is not provided, `surrogate` and
            `datasets` must be provided, and vice versa.
        search_space_digest: Used to get the target task and fidelity at
            which the oracle is evaluated.
    """

    name: str
    _surrogate: TorchModelBridge | None = None
    _datasets: list[SupervisedDataset] | None = None
    noise_stds: float | dict[str, float] = 0.0
    get_surrogate_and_datasets: (
        None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
    ) = None
    # Trial index -> status recorded by `run`.
    statuses: dict[int, TrialStatus] = field(default_factory=dict)

    def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
        """Validate that the surrogate and datasets are given or can be built.

        Raises:
            ValueError: If `get_surrogate_and_datasets` is None and either
                `_surrogate` or `_datasets` is also None.
        """
        super().__post_init__(search_space_digest=search_space_digest)
        if self.get_surrogate_and_datasets is None and (
            self._surrogate is None or self._datasets is None
        ):
            raise ValueError(
                "If `get_surrogate_and_datasets` is None, `_surrogate` "
                "and `_datasets` must not be None, and vice versa."
            )

    def set_surrogate_and_datasets(self) -> None:
        """Populate `_surrogate` and `_datasets` via `get_surrogate_and_datasets`."""
        self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()

    @property
    def surrogate(self) -> TorchModelBridge:
        """The surrogate model, constructed on first access if necessary."""
        if self._surrogate is None:
            self.set_surrogate_and_datasets()
        return none_throws(self._surrogate)

    @property
    def datasets(self) -> list[SupervisedDataset]:
        """The datasets, constructed on first access if necessary."""
        if self._datasets is None:
            self.set_surrogate_and_datasets()
        return none_throws(self._datasets)

    def get_noise_stds(self) -> dict[str, float]:
        """Return the noise standard deviation for each outcome name.

        A scalar `noise_stds` is broadcast to every name in `outcome_names`;
        a dictionary is returned unchanged.
        """
        noise_std = self.noise_stds
        # Accept ints as well as floats: `noise_stds=0` is a valid value for a
        # `float` annotation but would otherwise be returned as a bare int.
        if isinstance(noise_std, (int, float)):
            return {name: float(noise_std) for name in self.outcome_names}
        return noise_std

    # pyre-fixme[14]: Inconsistent override
    def get_Y_true(self, params: Mapping[str, float | int]) -> Tensor:
        """Return the surrogate's mean prediction for each outcome at `params`.

        Args:
            params: The parameterization to evaluate.

        Returns:
            A tensor holding one mean prediction per entry of `outcome_names`,
            in that order, on the surrogate's device and with its dtype.
        """
        # We're ignoring the uncertainty predictions of the surrogate model here and
        # use the mean predictions as the outcomes (before potentially adding noise)
        means, _ = self.surrogate.predict(
            # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
            observation_features=[ObservationFeatures(params)]
        )
        means = [means[name][0] for name in self.outcome_names]
        return torch.tensor(
            means,
            device=self.surrogate.device,
            dtype=self.surrogate.dtype,
        )

    def run(self, trial: BaseTrial) -> dict[str, Any]:
        """Run the trial by evaluating its parameterization(s) on the surrogate model.

        Note: This also records the status of the trial as COMPLETED in
        `self.statuses`.

        Args:
            trial: The trial to evaluate.

        Returns:
            A dictionary with the following keys:
                - outcome_names: The names of the metrics being evaluated.
                - Ys: A dict mapping arm names to lists of corresponding outcomes,
                    where the order of the outcomes is the same as in `outcome_names`.
                - Ystds: A dict mapping arm names to lists of corresponding outcome
                    noise standard deviations (possibly nan if the noise level is
                    unobserved), where the order of the outcomes is the same as in
                    `outcome_names`.
                - Ys_true: A dict mapping arm names to lists of corresponding ground
                    truth outcomes, where the order of the outcomes is the same as
                    in `outcome_names`.
        """
        self.statuses[trial.index] = TrialStatus.COMPLETED
        run_metadata = super().run(trial=trial)
        run_metadata["outcome_names"] = self.outcome_names
        return run_metadata

    @property
    def is_noiseless(self) -> bool:
        """Whether every configured noise standard deviation is exactly zero."""
        # Defensive: the annotation does not permit None, but treat it as "no noise".
        if self.noise_stds is None:
            return True
        # Accept ints as well as floats so that e.g. `noise_stds=0` does not fall
        # through to the dict branch and fail the instance assertion.
        if isinstance(self.noise_stds, (int, float)):
            return self.noise_stds == 0.0
        return all(
            std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values()
        )

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare runners by name, outcome names, noise levels, and digest."""
        if type(other) is not type(self):
            return False

        # Don't check surrogate, datasets, or callable
        return (
            (self.name == other.name)
            and (self.outcome_names == other.outcome_names)
            and (self.noise_stds == other.noise_stds)
            # pyre-fixme[16]: `SurrogateRunner` has no attribute `search_space_digest`.
            and (self.search_space_digest == other.search_space_digest)
        )
90 changes: 2 additions & 88 deletions ax/benchmark/tests/runners/test_surrogate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@
from unittest.mock import MagicMock, patch

import torch
from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.benchmark.runners.surrogate import SurrogateTestFunction
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import (
get_soo_surrogate_legacy,
get_soo_surrogate_test_function,
)
from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function


class TestSurrogateTestFunction(TestCase):
Expand Down Expand Up @@ -84,84 +79,3 @@ def _construct_test_function(name: str) -> SurrogateTestFunction:
self.assertEqual(runner_1, runner_1a)
self.assertNotEqual(runner_1, runner_2)
self.assertNotEqual(runner_1, 1)


class TestSurrogateRunner(TestCase):
    """Tests for `SurrogateRunner` construction, lazy loading, and equality."""

    def setUp(self) -> None:
        super().setUp()
        # Construct a search space with log-scale parameters.
        parameters = [
            RangeParameter("x", ParameterType.FLOAT, 0.0, 5.0),
            RangeParameter("y", ParameterType.FLOAT, 1.0, 10.0, log_scale=True),
            RangeParameter("z", ParameterType.INT, 1.0, 5.0, log_scale=True),
        ]
        self.search_space = SearchSpace(parameters=parameters)

    def test_surrogate_runner(self) -> None:
        # Cover a zero scalar, a non-zero scalar, and a per-outcome dict.
        noise_std_cases = (0.0, 0.1, {"dummy_metric": 0.2})
        for noise_std in noise_std_cases:
            with self.subTest(noise_std=noise_std):
                surrogate = MagicMock()
                mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
                surrogate.predict = MagicMock(return_value=(mock_mean, 0))
                surrogate.device = torch.device("cpu")
                surrogate.dtype = torch.double

                runner = SurrogateRunner(
                    name="test runner",
                    _surrogate=surrogate,
                    _datasets=[],
                    outcome_names=["dummy_metric"],
                    noise_stds=noise_std,
                )

                # The constructor arguments are exposed unchanged.
                self.assertEqual(runner.name, "test runner")
                self.assertIs(runner.surrogate, surrogate)
                self.assertEqual(runner.outcome_names, ["dummy_metric"])
                self.assertEqual(runner.noise_stds, noise_std)

    def test_lazy_instantiation(self) -> None:
        runner = get_soo_surrogate_legacy().runner

        # Nothing has been constructed yet.
        self.assertIsNone(runner._surrogate)
        self.assertIsNone(runner._datasets)

        # Accessing `surrogate` sets datasets and surrogate
        self.assertIsInstance(runner.surrogate, TorchModelBridge)
        self.assertIsInstance(runner._surrogate, TorchModelBridge)
        self.assertIsInstance(runner._datasets, list)

        # Accessing `datasets` also sets datasets and surrogate
        runner = get_soo_surrogate_legacy().runner
        self.assertIsInstance(runner.datasets, list)
        self.assertIsInstance(runner._surrogate, TorchModelBridge)
        self.assertIsInstance(runner._datasets, list)

        # Once populated, reading `surrogate` must not call the factory again.
        with patch.object(
            runner,
            "get_surrogate_and_datasets",
            wraps=runner.get_surrogate_and_datasets,
        ) as mock_get_surrogate_and_datasets:
            runner.surrogate
            mock_get_surrogate_and_datasets.assert_not_called()

    def test_instantiation_raises_with_missing_args(self) -> None:
        expected_message = "If `get_surrogate_and_datasets` is None, `_surrogate` and "
        with self.assertRaisesRegex(ValueError, expected_message):
            SurrogateRunner(name="test runner", outcome_names=[], noise_stds=0.0)

    def test_equality(self) -> None:
        def _make_runner(name: str) -> SurrogateRunner:
            # Each runner gets its own mock surrogate; only `name` varies.
            return SurrogateRunner(
                name=name,
                _surrogate=MagicMock(),
                _datasets=[],
                outcome_names=["dummy_metric"],
                noise_stds=0.0,
            )

        runner_1 = _make_runner("test 1")
        runner_2 = _make_runner("test 2")
        runner_1a = _make_runner("test 1")

        self.assertEqual(runner_1, runner_1a)
        self.assertNotEqual(runner_1, runner_2)
        self.assertNotEqual(runner_1, 1)
2 changes: 0 additions & 2 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,6 @@
"SumConstraint": SumConstraint,
"Surrogate": Surrogate,
"SurrogateMetric": BenchmarkMetric, # backward-compatiblity
# NOTE: SurrogateRunners -> SyntheticRunner on load due to complications
"SurrogateRunner": SyntheticRunner,
"SobolQMCNormalSampler": SobolQMCNormalSampler,
"SyntheticRunner": SyntheticRunner,
"SurrogateSpec": SurrogateSpec,
Expand Down
49 changes: 9 additions & 40 deletions ax/utils/testing/benchmark_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem
from ax.benchmark.runners.botorch_test import (
ParamBasedTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
from ax.benchmark.runners.surrogate import SurrogateTestFunction
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, Objective
from ax.core.optimization_config import (
Expand Down Expand Up @@ -129,41 +128,7 @@ def get_soo_surrogate() -> BenchmarkProblem:
)


def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem:
    """Build a single-objective `SurrogateBenchmarkProblem` test stub.

    The problem wraps a `SurrogateRunner` whose surrogate is supplied lazily
    through `get_surrogate_and_datasets`, with an empty list of datasets.
    """
    experiment = get_branin_experiment(with_completed_trial=True)
    surrogate = TorchModelBridge(
        experiment=experiment,
        search_space=experiment.search_space,
        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
        data=experiment.lookup_data(),
        transforms=[],
    )

    def _get_surrogate_and_datasets() -> tuple[TorchModelBridge, list]:
        # Hand back the already-built surrogate together with no datasets.
        return surrogate, []

    runner = SurrogateRunner(
        name="test",
        outcome_names=["branin"],
        get_surrogate_and_datasets=_get_surrogate_and_datasets,
    )

    # The same flag drives both the metric and the problem.
    observe_noise_sd = True
    metric = BenchmarkMetric(
        name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd
    )
    optimization_config = OptimizationConfig(objective=Objective(metric=metric))

    return SurrogateBenchmarkProblem(
        name="test",
        search_space=experiment.search_space,
        optimization_config=optimization_config,
        num_trials=6,
        observe_noise_stds=observe_noise_sd,
        optimal_value=0.0,
        runner=runner,
    )


def get_moo_surrogate() -> SurrogateBenchmarkProblem:
def get_moo_surrogate() -> BenchmarkProblem:
experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True)
surrogate = TorchModelBridge(
experiment=experiment,
Expand All @@ -173,11 +138,15 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem:
transforms=[],
)

runner = SurrogateRunner(
outcome_names = ["branin_a", "branin_b"]
test_function = SurrogateTestFunction(
name="test",
outcome_names=["branin_a", "branin_b"],
outcome_names=outcome_names,
get_surrogate_and_datasets=lambda: (surrogate, []),
)
runner = ParamBasedTestProblemRunner(
test_problem=test_function, outcome_names=outcome_names
)
observe_noise_sd = True
optimization_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(
Expand All @@ -199,7 +168,7 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem:
],
)
)
return SurrogateBenchmarkProblem(
return BenchmarkProblem(
name="test",
search_space=experiment.search_space,
optimization_config=optimization_config,
Expand Down