Skip to content

Commit d006021

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Update MTGP transforms to use new MBM_X_trans
Summary: Follow up to D66724547 to propagate the new transforms to MTGP models. Differential Revision: D66726992
1 parent 3847400 commit d006021

File tree

4 files changed

+23
-20
lines changed

4 files changed

+23
-20
lines changed

ax/modelbridge/registry.py

+4-14
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
ChoiceToNumericChoice,
3737
OrderedChoiceToIntegerRange,
3838
)
39-
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
4039
from ax.modelbridge.transforms.derelativize import Derelativize
4140
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
4241
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
@@ -131,15 +130,6 @@
131130
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
132131
TS_trans: list[type[Transform]] = Y_trans + [SearchSpaceToChoice]
133132

134-
# Multi-type MTGP transforms
135-
MT_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
136-
Derelativize,
137-
ConvertMetricNames,
138-
TrialAsTask,
139-
StratifiedStandardizeY,
140-
TaskChoiceToIntTaskChoice,
141-
]
142-
143133
# Single-type MTGP transforms
144134
ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
145135
Derelativize,
@@ -148,9 +138,9 @@
148138
TaskChoiceToIntTaskChoice,
149139
]
150140

151-
# Single-type MTGP transforms
152-
Specified_Task_ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
141+
MBM_MTGP_trans: list[type[Transform]] = MBM_X_trans + [
153142
Derelativize,
143+
TrialAsTask,
154144
StratifiedStandardizeY,
155145
TaskChoiceToIntTaskChoice,
156146
]
@@ -218,7 +208,7 @@ class ModelSetup(NamedTuple):
218208
"ST_MTGP": ModelSetup(
219209
bridge_class=TorchModelBridge,
220210
model_class=ModularBoTorchModel,
221-
transforms=ST_MTGP_trans,
211+
transforms=MBM_MTGP_trans,
222212
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
223213
),
224214
"BO_MIXED": ModelSetup(
@@ -241,7 +231,7 @@ class ModelSetup(NamedTuple):
241231
"SAAS_MTGP": ModelSetup(
242232
bridge_class=TorchModelBridge,
243233
model_class=ModularBoTorchModel,
244-
transforms=ST_MTGP_trans,
234+
transforms=MBM_MTGP_trans,
245235
default_model_kwargs={
246236
"surrogate_spec": SurrogateSpec(
247237
botorch_model_class=SaasFullyBayesianMultiTaskGP

ax/modelbridge/tests/test_generation_strategy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
from ax.modelbridge.registry import (
4444
_extract_model_state_after_gen,
4545
Cont_X_trans,
46+
MBM_MTGP_trans,
4647
MODEL_KEY_TO_MODEL_SETUP,
4748
Models,
48-
ST_MTGP_trans,
4949
)
5050
from ax.modelbridge.torch import TorchModelBridge
5151
from ax.modelbridge.transition_criterion import (
@@ -1106,7 +1106,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
11061106
model_kwargs={
11071107
# this will cause an error if the model
11081108
# doesn't get fixed features
1109-
"transforms": ST_MTGP_trans,
1109+
"transforms": MBM_MTGP_trans,
11101110
**self.step_model_kwargs,
11111111
},
11121112
num_trials=1,

ax/service/tests/scheduler_test_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
4545
from ax.modelbridge.dispatch_utils import choose_generation_strategy
4646
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
47-
from ax.modelbridge.registry import Models, ST_MTGP_trans
47+
from ax.modelbridge.registry import MBM_MTGP_trans, Models
4848
from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin
4949
from ax.runners.synthetic import SyntheticRunner
5050
from ax.service.scheduler import (
@@ -2391,7 +2391,7 @@ def test_it_works_with_multitask_models(
23912391
model_kwargs={
23922392
# this will cause and error if the model
23932393
# doesn't get fixed features
2394-
"transforms": ST_MTGP_trans,
2394+
"transforms": MBM_MTGP_trans,
23952395
"transform_configs": {
23962396
"TrialAsTask": {
23972397
"trial_level_map": {

tutorials/multi_task.ipynb

+15-2
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,28 @@
5656
"from ax.core.search_space import SearchSpace\n",
5757
"from ax.metrics.hartmann6 import Hartmann6Metric\n",
5858
"from ax.modelbridge.factory import get_sobol\n",
59-
"from ax.modelbridge.registry import Models, MT_MTGP_trans, ST_MTGP_trans\n",
59+
"from ax.modelbridge.registry import Models, MBM_X_trans, ST_MTGP_trans\n",
6060
"from ax.modelbridge.torch import TorchModelBridge\n",
6161
"from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment\n",
62+
"from ax.modelbridge.transforms.derelativize import Derelativize\n",
63+
"from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames\n",
64+
"from ax.modelbridge.transforms.trial_as_task import TrialAsTask\n",
65+
"from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY\n",
66+
"from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice\n",
6267
"from ax.plot.diagnostic import interact_batch_comparison\n",
6368
"from ax.runners.synthetic import SyntheticRunner\n",
6469
"from ax.utils.common.typeutils import checked_cast\n",
6570
"from ax.utils.notebook.plotting import init_notebook_plotting, render\n",
6671
"\n",
67-
"init_notebook_plotting()"
72+
"init_notebook_plotting()\n",
73+
"\n",
74+
"MT_MTGP_trans = MBM_X_trans + [\n",
75+
" Derelativize,\n",
76+
" ConvertMetricNames,\n",
77+
" TrialAsTask,\n",
78+
" StratifiedStandardizeY,\n",
79+
" TaskChoiceToIntTaskChoice,\n",
80+
"]"
6881
]
6982
},
7083
{

0 commit comments

Comments
 (0)