Change SurrogateBenchmarkProblem to TorchModelBridge (#2591)
Summary:
Pull Request resolved: #2591

This diff changes our surrogate benchmark problems to use a `TorchModelBridge` instead of a `Surrogate`.

**Why this change?**
Our current surrogate benchmark problems are constructed in the transformed space and thus don't apply any Ax transforms. Constructing the surrogates in the transformed space doesn't play well with things like hierarchical search spaces (HSS) and also prevents us from fully utilizing methods that target discrete and mixed search spaces.

This setup also aligns much better with how we normally fit models in Ax, since, e.g., cross-validation is done on the modelbridge. Going forward, constructing a surrogate benchmark problem will end up looking something like this (a minimal sketch follows the list):
1. Load data
2. Fit an Ax modelbridge
3. Cross-validate the modelbridge to make sure the model fit quality looks good
4. Save the components needed to reconstruct the modelbridge (e.g., state dict)
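
As an illustration, here is a minimal sketch of that workflow, with a Branin stub experiment standing in for real offline data; the state-dict attribute path at the end is an assumption for illustration, not a stable API:

from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

# 1. Load data: a stub experiment with one completed trial stands in for
#    real offline evaluation data.
experiment = get_branin_experiment(with_completed_trial=True)

# 2. Fit an Ax modelbridge on the experiment's data.
model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=experiment.lookup_data(),
)

# 3. Cross-validate the modelbridge and check that fit quality looks good.
diagnostics = compute_diagnostics(cross_validate(model=model_bridge))

# 4. Save the components needed to reconstruct the modelbridge, e.g. the
#    underlying BoTorch model's state dict.
state_dict = model_bridge.model.surrogate.model.state_dict()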

Reviewed By: saitcakmak

Differential Revision: D59883218

fbshipit-source-id: cd9032549e17a2e0926f7fa7d4c2f795d79e2b28
David Eriksson authored and facebook-github-bot committed Jul 23, 2024
1 parent be8e1f7 commit c31d5fd
Showing 6 changed files with 63 additions and 105 deletions.
8 changes: 4 additions & 4 deletions ax/benchmark/problems/surrogate.py
@@ -16,7 +16,7 @@
)
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.typeutils import checked_cast, not_none
@@ -42,7 +42,7 @@ def __init__(
observe_noise_stds: Union[bool, Dict[str, bool]] = False,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
-Callable[[], Tuple[Surrogate, List[SupervisedDataset]]]
+Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
_runner: Optional[Runner] = None,
@@ -163,7 +163,7 @@ def __init__(
observe_noise_stds: Union[bool, Dict[str, bool]] = False,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
-Callable[[], Tuple[Surrogate, List[SupervisedDataset]]]
+Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
_runner: Optional[Runner] = None,
@@ -210,7 +210,7 @@ def __init__(
observe_noise_stds: Union[bool, Dict[str, bool]] = False,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
-Callable[[], Tuple[Surrogate, List[SupervisedDataset]]]
+Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
_runner: Optional[Runner] = None,
54 changes: 11 additions & 43 deletions ax/benchmark/runners/surrogate.py
@@ -6,21 +6,16 @@
# pyre-strict

import warnings
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Union

import torch
from ax.benchmark.runners.base import BenchmarkRunner
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.observation import ObservationFeatures
-from ax.core.parameter import RangeParameter
from ax.core.search_space import SearchSpace
-from ax.core.types import TParameterization
-from ax.modelbridge.transforms.int_to_float import IntToFloat
-from ax.modelbridge.transforms.log import Log
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
-from ax.utils.common.typeutils import not_none
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor

@@ -29,7 +24,7 @@ class SurrogateRunner(BenchmarkRunner):
def __init__(
self,
name: str,
-surrogate: Surrogate,
+surrogate: TorchModelBridge,
datasets: List[SupervisedDataset],
search_space: SearchSpace,
outcome_names: List[str],
@@ -59,52 +54,25 @@ def __init__(
self.noise_stds = noise_stds
self.statuses: Dict[int, TrialStatus] = {}

-# If there are log scale parameters, these need to be transformed.
-if any(
-isinstance(p, RangeParameter) and p.log_scale
-for p in search_space.parameters.values()
-):
-int_to_float_tf = IntToFloat(search_space=search_space)
-log_tf = Log(
-search_space=int_to_float_tf.transform_search_space(
-search_space.clone()
-)
-)
-self.transforms: Optional[Tuple[IntToFloat, Log]] = (
-int_to_float_tf,
-log_tf,
-)
-else:
-self.transforms = None

@property
def outcome_names(self) -> List[str]:
return self._outcome_names

-def _get_transformed_parameters(
-self, parameters: TParameterization
-) -> TParameterization:
-if self.transforms is None:
-return parameters
-
-obs_ft = ObservationFeatures(parameters=parameters)
-for t in not_none(self.transforms):
-obs_ft = t.transform_observation_features([obs_ft])[0]
-return obs_ft.parameters

def get_noise_stds(self) -> Union[None, float, Dict[str, float]]:
return self.noise_stds

def get_Y_true(self, arm: Arm) -> Tensor:
-X = torch.tensor(
-[list(self._get_transformed_parameters(arm.parameters).values())],
+# We're ignoring the uncertainty predictions of the surrogate model here and
+# use the mean predictions as the outcomes (before potentially adding noise)
+means, _ = self.surrogate.predict(
+observation_features=[ObservationFeatures(arm.parameters)]
+)
+means = [means[name][0] for name in self.outcome_names]
+return torch.tensor(
+means,
device=self.surrogate.device,
dtype=self.surrogate.dtype,
)
-# We're ignoring the uncertainty predictions of the surrogate model here and
-# use the mean predictions as the outcomes (before potentially adding noise)
-means, _ = self.surrogate.predict(X=X)
-return means.squeeze(0)

def run(self, trial: BaseTrial) -> Dict[str, Any]:
"""Run the trial by evaluating its parameterization(s) on the surrogate model.
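For reference, the modelbridge's predict method returns a tuple of mean and covariance dicts keyed by metric name, which is what the new get_Y_true unpacks. A small sketch, again on a Branin stub experiment (the metric name "branin" comes from that stub):

from ax.core.observation import ObservationFeatures
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment(with_completed_trial=True)
model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment, data=experiment.lookup_data()
)

# predict() returns ({metric: [means]}, {metric: {metric: [covariances]}}).
means, covariances = model_bridge.predict(
    observation_features=[ObservationFeatures(parameters={"x1": 0.0, "x2": 0.0})]
)
mean_value = means["branin"][0]  # mean prediction, used as the true outcome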
40 changes: 0 additions & 40 deletions ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -9,14 +9,9 @@

import torch
from ax.benchmark.problems.surrogate import SurrogateRunner
-from ax.core.arm import Arm
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
-from ax.core.trial import Trial
-from ax.modelbridge.transforms.int_to_float import IntToFloat
-from ax.modelbridge.transforms.log import Log
from ax.utils.common.testutils import TestCase
-from ax.utils.common.typeutils import checked_cast, not_none


class TestSurrogateRunner(TestCase):
@@ -48,38 +43,3 @@ def test_surrogate_runner(self) -> None:
self.assertIs(runner.surrogate, surrogate)
self.assertEqual(runner.outcome_names, ["dummy_metric"])
self.assertEqual(runner.noise_stds, noise_std)

-# Check that the transforms are set up correctly.
-transforms = not_none(runner.transforms)
-self.assertEqual(len(transforms), 2)
-self.assertIsInstance(transforms[0], IntToFloat)
-self.assertIsInstance(transforms[1], Log)
-self.assertEqual(
-checked_cast(IntToFloat, transforms[0]).transform_parameters, {"z"}
-)
-self.assertEqual(
-checked_cast(Log, transforms[1]).transform_parameters, {"y", "z"}
-)
-# Check that evaluation works correctly with the transformed parameters.
-trial = Trial(experiment=MagicMock())
-trial.add_arm(Arm({"x": 2.5, "y": 10.0, "z": 1.0}, name="0_0"))
-run_output = runner.run(trial)
-self.assertEqual(run_output["outcome_names"], ["dummy_metric"])
-self.assertEqual(run_output["Ys_true"]["0_0"], [0.1234])
-self.assertEqual(
-run_output["Ystds"]["0_0"],
-[
-(
-noise_std
-if not isinstance(noise_std, dict)
-else noise_std["dummy_metric"]
-)
-],
-)
-surrogate.predict.assert_called_once()
-X = surrogate.predict.call_args[1]["X"]
-self.assertTrue(
-torch.allclose(
-X, torch.tensor([[2.5, 1.0, 0.0]], dtype=torch.double)
-)
-)
10 changes: 1 addition & 9 deletions ax/benchmark/tests/test_benchmark.py
@@ -29,7 +29,6 @@
from ax.benchmark.problems.registry import get_problem
from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
-from ax.modelbridge.modelbridge_utils import extract_search_space_digest
from ax.modelbridge.registry import Models
from ax.service.utils.scheduler_options import SchedulerOptions
from ax.storage.json_store.load import load_experiment
@@ -44,7 +43,7 @@
get_sobol_benchmark_method,
get_soo_surrogate,
)
-from ax.utils.testing.core_stubs import get_dataset, get_experiment
+from ax.utils.testing.core_stubs import get_experiment
from ax.utils.testing.mock import fast_botorch_optimize
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.multi_objective.monte_carlo import (
@@ -302,13 +301,6 @@ def test_replication_sobol_surrogate(self) -> None:
]:
with self.subTest(name, problem=problem):
surrogate, datasets = not_none(problem.get_surrogate_and_datasets)()
-surrogate.fit(
-[get_dataset()],
-search_space_digest=extract_search_space_digest(
-problem.search_space,
-param_names=[*problem.search_space.parameters.keys()],
-),
-)
res = benchmark_replication(problem=problem, method=method, seed=0)

self.assertEqual(
42 changes: 33 additions & 9 deletions ax/utils/testing/benchmark_stubs.py
@@ -21,15 +21,21 @@
SOOSurrogateBenchmarkProblem,
)
from ax.core.experiment import Experiment
+from ax.core.optimization_config import (
+MultiObjectiveOptimizationConfig,
+OptimizationConfig,
+)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
+from ax.modelbridge.torch import TorchModelBridge
+from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.scheduler import SchedulerOptions
from ax.utils.common.constants import Keys
+from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import (
-get_branin_multi_objective_optimization_config,
-get_branin_optimization_config,
-get_branin_search_space,
+get_branin_experiment,
+get_branin_experiment_with_multi_objective,
)
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.models.gp_regression import SingleTaskGP
@@ -96,11 +102,20 @@ def get_sobol_benchmark_method() -> BenchmarkMethod:


def get_soo_surrogate() -> SOOSurrogateBenchmarkProblem:
-surrogate = Surrogate(botorch_model_class=SingleTaskGP)
+experiment = get_branin_experiment(with_completed_trial=True)
+surrogate = TorchModelBridge(
+experiment=experiment,
+search_space=experiment.search_space,
+model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
+data=experiment.lookup_data(),
+transforms=[],
+)
return SOOSurrogateBenchmarkProblem(
name="test",
-search_space=get_branin_search_space(),
-optimization_config=get_branin_optimization_config(),
+search_space=experiment.search_space,
+optimization_config=checked_cast(
+OptimizationConfig, experiment.optimization_config
+),
num_trials=6,
outcome_names=["branin"],
observe_noise_stds=True,
@@ -110,11 +125,20 @@ def get_soo_surrogate() -> SOOSurrogateBenchmarkProblem:


def get_moo_surrogate() -> MOOSurrogateBenchmarkProblem:
-surrogate = Surrogate(botorch_model_class=SingleTaskGP)
+experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True)
+surrogate = TorchModelBridge(
+experiment=experiment,
+search_space=experiment.search_space,
+model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
+data=experiment.lookup_data(),
+transforms=[],
+)
return MOOSurrogateBenchmarkProblem(
name="test",
-search_space=get_branin_search_space(),
-optimization_config=get_branin_multi_objective_optimization_config(),
+search_space=experiment.search_space,
+optimization_config=checked_cast(
+MultiObjectiveOptimizationConfig, experiment.optimization_config
+),
num_trials=10,
outcome_names=["branin_a", "branin_b"],
observe_noise_stds=True,
14 changes: 14 additions & 0 deletions ax/utils/testing/core_stubs.py
@@ -611,6 +611,8 @@ def get_branin_experiment_with_multi_objective(
with_status_quo: bool = False,
with_fidelity_parameter: bool = False,
num_objectives: int = 2,
+with_trial: bool = False,
+with_completed_trial: bool = False,
) -> Experiment:
exp = Experiment(
name="branin_test_experiment",
@@ -640,6 +642,18 @@ def get_branin_experiment_with_multi_objective(
sobol_run
)

+if with_trial or with_completed_trial:
+sobol_generator = get_sobol(search_space=exp.search_space)
+sobol_run = sobol_generator.gen(n=1)
+trial = exp.new_trial(generator_run=sobol_run)
+
+if with_completed_trial:
+trial.mark_running(no_runner_required=True)
+exp.attach_data(
+get_branin_data_multi_objective(trial_indices=[trial.index])
+)  # Add data for one trial
+trial.mark_completed()

return exp


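A quick usage sketch of the new stub flag; the assertions just restate the behavior added above:

from ax.utils.testing.core_stubs import get_branin_experiment_with_multi_objective

# With with_completed_trial=True, the stub attaches one Sobol trial, adds
# data for it, and marks it completed, so a modelbridge can be fit on the
# experiment directly.
experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True)
assert len(experiment.trials) == 1
assert not experiment.lookup_data().df.empty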
