Skip to content

Commit 7f880e7

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Update MTGP transforms to use new MBM_X_trans
Summary: Follow up to D66724547 to propagate the new transforms to MTGP models. Differential Revision: D66726992
1 parent 4b2c0b5 commit 7f880e7

File tree

4 files changed

+604
-600
lines changed

4 files changed

+604
-600
lines changed

ax/modelbridge/registry.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
ChoiceToNumericChoice,
3737
OrderedChoiceToIntegerRange,
3838
)
39-
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
4039
from ax.modelbridge.transforms.derelativize import Derelativize
4140
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
4241
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
@@ -131,15 +130,6 @@
131130
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
132131
TS_trans: list[type[Transform]] = Y_trans + [SearchSpaceToChoice]
133132

134-
# Multi-type MTGP transforms
135-
MT_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
136-
Derelativize,
137-
ConvertMetricNames,
138-
TrialAsTask,
139-
StratifiedStandardizeY,
140-
TaskChoiceToIntTaskChoice,
141-
]
142-
143133
# Single-type MTGP transforms
144134
ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
145135
Derelativize,
@@ -148,9 +138,9 @@
148138
TaskChoiceToIntTaskChoice,
149139
]
150140

151-
# Single-type MTGP transforms
152-
Specified_Task_ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
141+
MBM_MTGP_trans: list[type[Transform]] = MBM_X_trans + [
153142
Derelativize,
143+
TrialAsTask,
154144
StratifiedStandardizeY,
155145
TaskChoiceToIntTaskChoice,
156146
]
@@ -218,7 +208,7 @@ class ModelSetup(NamedTuple):
218208
"ST_MTGP": ModelSetup(
219209
bridge_class=TorchModelBridge,
220210
model_class=ModularBoTorchModel,
221-
transforms=ST_MTGP_trans,
211+
transforms=MBM_MTGP_trans,
222212
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
223213
),
224214
"BO_MIXED": ModelSetup(
@@ -241,7 +231,7 @@ class ModelSetup(NamedTuple):
241231
"SAAS_MTGP": ModelSetup(
242232
bridge_class=TorchModelBridge,
243233
model_class=ModularBoTorchModel,
244-
transforms=ST_MTGP_trans,
234+
transforms=MBM_MTGP_trans,
245235
default_model_kwargs={
246236
"surrogate_spec": SurrogateSpec(
247237
botorch_model_class=SaasFullyBayesianMultiTaskGP

ax/modelbridge/tests/test_generation_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
from ax.modelbridge.registry import (
4444
_extract_model_state_after_gen,
4545
Cont_X_trans,
46+
MBM_MTGP_trans,
4647
MODEL_KEY_TO_MODEL_SETUP,
4748
Models,
48-
ST_MTGP_trans,
4949
)
5050
from ax.modelbridge.torch import TorchModelBridge
5151
from ax.modelbridge.transition_criterion import (
@@ -1106,7 +1106,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
11061106
model_kwargs={
11071107
# this will cause an error if the model
11081108
# doesn't get fixed features
1109-
"transforms": ST_MTGP_trans,
1109+
"transforms": MBM_MTGP_trans,
11101110
**self.step_model_kwargs,
11111111
},
11121112
num_trials=1,

ax/service/tests/scheduler_test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
4545
from ax.modelbridge.dispatch_utils import choose_generation_strategy
4646
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
47-
from ax.modelbridge.registry import Models, ST_MTGP_trans
47+
from ax.modelbridge.registry import MBM_MTGP_trans, Models
4848
from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin
4949
from ax.runners.synthetic import SyntheticRunner
5050
from ax.service.scheduler import (
@@ -2391,7 +2391,7 @@ def test_it_works_with_multitask_models(
23912391
model_kwargs={
23922392
# this will cause and error if the model
23932393
# doesn't get fixed features
2394-
"transforms": ST_MTGP_trans,
2394+
"transforms": MBM_MTGP_trans,
23952395
"transform_configs": {
23962396
"TrialAsTask": {
23972397
"trial_level_map": {

0 commit comments

Comments
 (0)