From c31d5fd75f115c9423e66aed687adc697328a4c7 Mon Sep 17 00:00:00 2001
From: David Eriksson
Date: Mon, 22 Jul 2024 17:40:44 -0700
Subject: [PATCH] Change SurrogateBenchmarkProblem to TorchModelBridge (#2591)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2591

This diff changes our surrogate benchmark problems to use a `TorchModelBridge` instead of a `Surrogate`.

**Why this change?**
Our current surrogate benchmark problems are constructed in the transformed space and thus don't use any transforms. Constructing the surrogates in the transformed space doesn't play well with things like hierarchical search spaces (HSS), and it also prevents us from fully utilizing methods that target discrete and mixed search spaces. Fitting the surrogate through a model bridge also matches how we normally fit models in Ax, where, e.g., cross-validation is done on the model bridge.

Going forward, constructing a surrogate benchmark problem will look something like this (a rough sketch follows the list):
1. Load data
2. Fit an Ax model bridge
3. Cross-validate the model bridge to make sure the model fit quality looks good
4. Save the components needed to reconstruct the model bridge (e.g., the state dict)
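As a rough illustration (not part of this diff), the workflow above might look as follows. The stub experiment and registry entry are real Ax utilities; the attribute path used to pull out the state dict in step 4 is an assumption and may differ in practice:

```python
import torch

from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

# 1. Load data -- here we just use a stub experiment with one completed trial.
experiment = get_branin_experiment(with_completed_trial=True)
data = experiment.lookup_data()

# 2. Fit an Ax model bridge (a TorchModelBridge wrapping a modular BoTorch model).
model_bridge = Models.BOTORCH_MODULAR(experiment=experiment, data=data)

# 3. Cross-validate the model bridge to make sure the model fit quality looks good.
cv_results = cross_validate(model=model_bridge)
diagnostics = compute_diagnostics(cv_results)

# 4. Save the components needed to reconstruct the model bridge, e.g. the state
#    dict of the underlying BoTorch model (attribute path assumed here).
state_dict = model_bridge.model.surrogate.model.state_dict()
torch.save(state_dict, "surrogate_state_dict.pt")
```

Because cross-validation happens on the model bridge, it exercises the same transforms that will be applied when the benchmark runs.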
Reviewed By: saitcakmak

Differential Revision: D59883218

fbshipit-source-id: cd9032549e17a2e0926f7fa7d4c2f795d79e2b28
---
 ax/benchmark/problems/surrogate.py          |  8 +--
 ax/benchmark/runners/surrogate.py           | 54 ++++---------------
 .../tests/runners/test_surrogate_runner.py  | 40 --------------
 ax/benchmark/tests/test_benchmark.py        | 10 +---
 ax/utils/testing/benchmark_stubs.py         | 42 +++++++++++----
 ax/utils/testing/core_stubs.py              | 14 +++++
 6 files changed, 63 insertions(+), 105 deletions(-)

diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py
index daf841901c9..3b07a57d81f 100644
--- a/ax/benchmark/problems/surrogate.py
+++ b/ax/benchmark/problems/surrogate.py
@@ -16,7 +16,7 @@
 )
 from ax.core.runner import Runner
 from ax.core.search_space import SearchSpace
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.base import Base
 from ax.utils.common.equality import equality_typechecker
 from ax.utils.common.typeutils import checked_cast, not_none
@@ -42,7 +42,7 @@ def __init__(
         observe_noise_stds: Union[bool, Dict[str, bool]] = False,
         noise_stds: Union[float, Dict[str, float]] = 0.0,
         get_surrogate_and_datasets: Optional[
-            Callable[[], Tuple[Surrogate, List[SupervisedDataset]]]
+            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
         ] = None,
         tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
         _runner: Optional[Runner] = None,
@@ -163,7 +163,7 @@ def __init__(
         observe_noise_stds: Union[bool, Dict[str, bool]] = False,
         noise_stds: Union[float, Dict[str, float]] = 0.0,
         get_surrogate_and_datasets: Optional[
-            Callable[[], Tuple[Surrogate, List[SupervisedDataset]]]
+            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
         ] = None,
         tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
         _runner: Optional[Runner] = None,
@@ -210,7 +210,7 @@ def __init__(
         observe_noise_stds: Union[bool, Dict[str, bool]] = False,
         noise_stds: Union[float, Dict[str, float]] = 0.0,
         get_surrogate_and_datasets: Optional[
-            Callable[[], Tuple[Surrogate, List[SupervisedDataset]]]
+            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
         ] = None,
         tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
         _runner: Optional[Runner] = None,
diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py
index 8f0303a0a39..f64a5d1dd15 100644
--- a/ax/benchmark/runners/surrogate.py
+++ b/ax/benchmark/runners/surrogate.py
@@ -6,21 +6,16 @@
 # pyre-strict

 import warnings
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Union

 import torch
 from ax.benchmark.runners.base import BenchmarkRunner
 from ax.core.arm import Arm
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.observation import ObservationFeatures
-from ax.core.parameter import RangeParameter
 from ax.core.search_space import SearchSpace
-from ax.core.types import TParameterization
-from ax.modelbridge.transforms.int_to_float import IntToFloat
-from ax.modelbridge.transforms.log import Log
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
-from ax.utils.common.typeutils import not_none
 from botorch.utils.datasets import SupervisedDataset
 from torch import Tensor

@@ -29,7 +24,7 @@ class SurrogateRunner(BenchmarkRunner):
     def __init__(
         self,
         name: str,
-        surrogate: Surrogate,
+        surrogate: TorchModelBridge,
         datasets: List[SupervisedDataset],
         search_space: SearchSpace,
         outcome_names: List[str],
@@ -59,52 +54,25 @@ def __init__(
         self.noise_stds = noise_stds
         self.statuses: Dict[int, TrialStatus] = {}

-        # If there are log scale parameters, these need to be transformed.
-        if any(
-            isinstance(p, RangeParameter) and p.log_scale
-            for p in search_space.parameters.values()
-        ):
-            int_to_float_tf = IntToFloat(search_space=search_space)
-            log_tf = Log(
-                search_space=int_to_float_tf.transform_search_space(
-                    search_space.clone()
-                )
-            )
-            self.transforms: Optional[Tuple[IntToFloat, Log]] = (
-                int_to_float_tf,
-                log_tf,
-            )
-        else:
-            self.transforms = None
-
     @property
     def outcome_names(self) -> List[str]:
         return self._outcome_names

-    def _get_transformed_parameters(
-        self, parameters: TParameterization
-    ) -> TParameterization:
-        if self.transforms is None:
-            return parameters
-
-        obs_ft = ObservationFeatures(parameters=parameters)
-        for t in not_none(self.transforms):
-            obs_ft = t.transform_observation_features([obs_ft])[0]
-        return obs_ft.parameters
-
     def get_noise_stds(self) -> Union[None, float, Dict[str, float]]:
         return self.noise_stds

     def get_Y_true(self, arm: Arm) -> Tensor:
-        X = torch.tensor(
-            [list(self._get_transformed_parameters(arm.parameters).values())],
+        # We're ignoring the uncertainty predictions of the surrogate model here and
+        # use the mean predictions as the outcomes (before potentially adding noise)
+        means, _ = self.surrogate.predict(
+            observation_features=[ObservationFeatures(arm.parameters)]
+        )
+        means = [means[name][0] for name in self.outcome_names]
+        return torch.tensor(
+            means,
             device=self.surrogate.device,
             dtype=self.surrogate.dtype,
         )
-        # We're ignoring the uncertainty predictions of the surrogate model here and
-        # use the mean predictions as the outcomes (before potentially adding noise)
-        means, _ = self.surrogate.predict(X=X)
-        return means.squeeze(0)

     def run(self, trial: BaseTrial) -> Dict[str, Any]:
         """Run the trial by evaluating its parameterization(s) on the surrogate model.
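For context on the new `get_Y_true` above: `ModelBridge.predict` returns a tuple of mean and covariance dictionaries keyed by metric name, which is why the runner indexes `means[name][0]`. A minimal sketch of that access pattern, assuming a fitted `model_bridge` as in the earlier example (the metric name and value here are made up):

```python
from ax.core.observation import ObservationFeatures

# predict() returns (means, covariances); each is keyed by metric name and holds
# one entry per ObservationFeatures passed in.
means, _ = model_bridge.predict(
    observation_features=[ObservationFeatures(parameters={"x1": 0.5, "x2": 0.5})]
)
# e.g. means == {"branin": [4.2]} -> the runner takes means[metric_name][0].
```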
diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py
index 55f0036c391..0fdf4e65154 100644
--- a/ax/benchmark/tests/runners/test_surrogate_runner.py
+++ b/ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -9,14 +9,9 @@

 import torch
 from ax.benchmark.problems.surrogate import SurrogateRunner
-from ax.core.arm import Arm
 from ax.core.parameter import ParameterType, RangeParameter
 from ax.core.search_space import SearchSpace
-from ax.core.trial import Trial
-from ax.modelbridge.transforms.int_to_float import IntToFloat
-from ax.modelbridge.transforms.log import Log
 from ax.utils.common.testutils import TestCase
-from ax.utils.common.typeutils import checked_cast, not_none


 class TestSurrogateRunner(TestCase):
@@ -48,38 +43,3 @@ def test_surrogate_runner(self) -> None:
         self.assertIs(runner.surrogate, surrogate)
         self.assertEqual(runner.outcome_names, ["dummy_metric"])
         self.assertEqual(runner.noise_stds, noise_std)
-
-        # Check that the transforms are set up correctly.
-        transforms = not_none(runner.transforms)
-        self.assertEqual(len(transforms), 2)
-        self.assertIsInstance(transforms[0], IntToFloat)
-        self.assertIsInstance(transforms[1], Log)
-        self.assertEqual(
-            checked_cast(IntToFloat, transforms[0]).transform_parameters, {"z"}
-        )
-        self.assertEqual(
-            checked_cast(Log, transforms[1]).transform_parameters, {"y", "z"}
-        )
-        # Check that evaluation works correctly with the transformed parameters.
-        trial = Trial(experiment=MagicMock())
-        trial.add_arm(Arm({"x": 2.5, "y": 10.0, "z": 1.0}, name="0_0"))
-        run_output = runner.run(trial)
-        self.assertEqual(run_output["outcome_names"], ["dummy_metric"])
-        self.assertEqual(run_output["Ys_true"]["0_0"], [0.1234])
-        self.assertEqual(
-            run_output["Ystds"]["0_0"],
-            [
-                (
-                    noise_std
-                    if not isinstance(noise_std, dict)
-                    else noise_std["dummy_metric"]
-                )
-            ],
-        )
-        surrogate.predict.assert_called_once()
-        X = surrogate.predict.call_args[1]["X"]
-        self.assertTrue(
-            torch.allclose(
-                X, torch.tensor([[2.5, 1.0, 0.0]], dtype=torch.double)
-            )
-        )
diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py
index 7a1a8a1e202..dc24816effb 100644
--- a/ax/benchmark/tests/test_benchmark.py
+++ b/ax/benchmark/tests/test_benchmark.py
@@ -29,7 +29,6 @@
 from ax.benchmark.problems.registry import get_problem
 from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
 from ax.modelbridge.model_spec import ModelSpec
-from ax.modelbridge.modelbridge_utils import extract_search_space_digest
 from ax.modelbridge.registry import Models
 from ax.service.utils.scheduler_options import SchedulerOptions
 from ax.storage.json_store.load import load_experiment
@@ -44,7 +43,7 @@
     get_sobol_benchmark_method,
     get_soo_surrogate,
 )
-from ax.utils.testing.core_stubs import get_dataset, get_experiment
+from ax.utils.testing.core_stubs import get_experiment
 from ax.utils.testing.mock import fast_botorch_optimize
 from botorch.acquisition.logei import qLogNoisyExpectedImprovement
 from botorch.acquisition.multi_objective.monte_carlo import (
@@ -302,13 +301,6 @@ def test_replication_sobol_surrogate(self) -> None:
         ]:
             with self.subTest(name, problem=problem):
                 surrogate, datasets = not_none(problem.get_surrogate_and_datasets)()
-                surrogate.fit(
-                    [get_dataset()],
-                    search_space_digest=extract_search_space_digest(
-                        problem.search_space,
-                        param_names=[*problem.search_space.parameters.keys()],
-                    ),
-                )
                 res = benchmark_replication(problem=problem, method=method, seed=0)
                 self.assertEqual(
diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py
index fea681fccde..8aa787a596e 100644
--- a/ax/utils/testing/benchmark_stubs.py
+++ b/ax/utils/testing/benchmark_stubs.py
@@ -21,15 +21,21 @@
     SOOSurrogateBenchmarkProblem,
 )
 from ax.core.experiment import Experiment
+from ax.core.optimization_config import (
+    MultiObjectiveOptimizationConfig,
+    OptimizationConfig,
+)
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
 from ax.modelbridge.registry import Models
+from ax.modelbridge.torch import TorchModelBridge
+from ax.models.torch.botorch_modular.model import BoTorchModel
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.service.scheduler import SchedulerOptions
 from ax.utils.common.constants import Keys
+from ax.utils.common.typeutils import checked_cast
 from ax.utils.testing.core_stubs import (
-    get_branin_multi_objective_optimization_config,
-    get_branin_optimization_config,
-    get_branin_search_space,
+    get_branin_experiment,
+    get_branin_experiment_with_multi_objective,
 )
 from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
 from botorch.models.gp_regression import SingleTaskGP
@@ -96,11 +102,20 @@ def get_sobol_benchmark_method() -> BenchmarkMethod:


 def get_soo_surrogate() -> SOOSurrogateBenchmarkProblem:
-    surrogate = Surrogate(botorch_model_class=SingleTaskGP)
+    experiment = get_branin_experiment(with_completed_trial=True)
+    surrogate = TorchModelBridge(
+        experiment=experiment,
+        search_space=experiment.search_space,
+        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
+        data=experiment.lookup_data(),
+        transforms=[],
+    )
     return SOOSurrogateBenchmarkProblem(
         name="test",
-        search_space=get_branin_search_space(),
-        optimization_config=get_branin_optimization_config(),
+        search_space=experiment.search_space,
+        optimization_config=checked_cast(
+            OptimizationConfig, experiment.optimization_config
+        ),
         num_trials=6,
         outcome_names=["branin"],
         observe_noise_stds=True,
@@ -110,11 +125,20 @@ def get_soo_surrogate() -> SOOSurrogateBenchmarkProblem:


 def get_moo_surrogate() -> MOOSurrogateBenchmarkProblem:
-    surrogate = Surrogate(botorch_model_class=SingleTaskGP)
+    experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True)
+    surrogate = TorchModelBridge(
+        experiment=experiment,
+        search_space=experiment.search_space,
+        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
+        data=experiment.lookup_data(),
+        transforms=[],
+    )
     return MOOSurrogateBenchmarkProblem(
         name="test",
-        search_space=get_branin_search_space(),
-        optimization_config=get_branin_multi_objective_optimization_config(),
+        search_space=experiment.search_space,
+        optimization_config=checked_cast(
+            MultiObjectiveOptimizationConfig, experiment.optimization_config
+        ),
         num_trials=10,
         outcome_names=["branin_a", "branin_b"],
         observe_noise_stds=True,
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 7277cb2c38e..192bb62e94f 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -611,6 +611,8 @@ def get_branin_experiment_with_multi_objective(
     with_status_quo: bool = False,
     with_fidelity_parameter: bool = False,
     num_objectives: int = 2,
+    with_trial: bool = False,
+    with_completed_trial: bool = False,
 ) -> Experiment:
     exp = Experiment(
         name="branin_test_experiment",
@@ -640,6 +642,18 @@
             sobol_run
         )

+    if with_trial or with_completed_trial:
+        sobol_generator = get_sobol(search_space=exp.search_space)
+        sobol_run = sobol_generator.gen(n=1)
+        trial = exp.new_trial(generator_run=sobol_run)
+
+        if with_completed_trial:
+            trial.mark_running(no_runner_required=True)
+            exp.attach_data(
+                get_branin_data_multi_objective(trial_indices=[trial.index])
+            )  # Add data for one trial
+            trial.mark_completed()
+
     return exp