Skip to content

Commit 64a8119

Browse files
Daniel Cohenfacebook-github-bot
Daniel Cohen
authored andcommitted
Combine experiment and gs fields into AnalysisBase (#3137)
Summary: AnalysisBase will have optional `_experiment` and `_generation_strategy` fields, with getter and setter properties, as well as the `standard_generation_strategy` prop from `Scheduler`. `Scheduler` and `AxClient` will inherit these from it. The naming is less than optimal with AnalysisBase holding the experiment and GS. This is otherwise a no-op change designed to reduce pyre errors. Differential Revision: D66712036
1 parent effed6d commit 64a8119

File tree

5 files changed

+100
-80
lines changed

5 files changed

+100
-80
lines changed

ax/modelbridge/tests/test_prediction_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_predict_at_point(self) -> None:
2828

2929
observation_features = ObservationFeatures(parameters={"x1": 0.3, "x2": 0.5})
3030
y_hat, se_hat = predict_at_point(
31-
model=none_throws(ax_client.generation_strategy.model),
31+
model=none_throws(ax_client.standard_generation_strategy.model),
3232
obsf=observation_features,
3333
metric_names={"test_metric1"},
3434
)
@@ -37,7 +37,7 @@ def test_predict_at_point(self) -> None:
3737
self.assertEqual(len(se_hat), 1)
3838

3939
y_hat, se_hat = predict_at_point(
40-
model=none_throws(ax_client.generation_strategy.model),
40+
model=none_throws(ax_client.standard_generation_strategy.model),
4141
obsf=observation_features,
4242
metric_names={"test_metric1", "test_metric2", "test_metric:agg"},
4343
scalarized_metric_config=[
@@ -51,7 +51,7 @@ def test_predict_at_point(self) -> None:
5151
self.assertEqual(len(se_hat), 3)
5252

5353
y_hat, se_hat = predict_at_point(
54-
model=none_throws(ax_client.generation_strategy.model),
54+
model=none_throws(ax_client.standard_generation_strategy.model),
5555
obsf=observation_features,
5656
metric_names={"test_metric1"},
5757
scalarized_metric_config=[
@@ -75,7 +75,7 @@ def test_predict_by_features(self) -> None:
7575
20: ObservationFeatures(parameters={"x1": 0.8, "x2": 0.5}),
7676
}
7777
predictions_map = predict_by_features(
78-
model=none_throws(ax_client.generation_strategy.model),
78+
model=none_throws(ax_client.standard_generation_strategy.model),
7979
label_to_feature_dict=observation_features_dict,
8080
metric_names={"test_metric1"},
8181
)

ax/service/ax_client.py

+19-37
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,8 @@ def get_next_trial(
599599
# TODO[T79183560]: Ensure correct handling of generator run when using
600600
# foreign keys.
601601
self._update_generation_strategy_in_db_if_possible(
602-
generation_strategy=self.generation_strategy,
603-
new_generator_runs=[self.generation_strategy._generator_runs[-1]],
602+
generation_strategy=self.standard_generation_strategy,
603+
new_generator_runs=[self.standard_generation_strategy._generator_runs[-1]],
604604
)
605605
return none_throws(trial.arm).parameters, trial.index
606606

@@ -625,7 +625,7 @@ def get_current_trial_generation_limit(self) -> tuple[int, bool]:
625625
if self.generation_strategy._experiment is None:
626626
self.generation_strategy.experiment = self.experiment
627627

628-
return self.generation_strategy.current_generator_run_limit()
628+
return self.standard_generation_strategy.current_generator_run_limit()
629629

630630
def get_next_trials(
631631
self,
@@ -950,7 +950,7 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
950950
Mapping of form {num_trials -> max_parallelism_setting}.
951951
"""
952952
parallelism_settings = []
953-
for step in self.generation_strategy._steps:
953+
for step in self.standard_generation_strategy._steps:
954954
parallelism_settings.append(
955955
(step.num_trials, step.max_parallelism or step.num_trials)
956956
)
@@ -1071,15 +1071,15 @@ def get_contour_plot(
10711071
raise ValueError(
10721072
f'Metric "{metric_name}" is not associated with this optimization.'
10731073
)
1074-
if self.generation_strategy.model is not None:
1074+
if self.standard_generation_strategy.model is not None:
10751075
try:
10761076
logger.info(
10771077
f"Retrieving contour plot with parameter '{param_x}' on X-axis "
10781078
f"and '{param_y}' on Y-axis, for metric '{metric_name}'. "
10791079
"Remaining parameters are affixed to the middle of their range."
10801080
)
10811081
return plot_contour(
1082-
model=none_throws(self.generation_strategy.model),
1082+
model=none_throws(self.standard_generation_strategy.model),
10831083
param_x=param_x,
10841084
param_y=param_y,
10851085
metric_name=metric_name,
@@ -1089,8 +1089,8 @@ def get_contour_plot(
10891089
# Some models don't implement '_predict', which is needed
10901090
# for the contour plots.
10911091
logger.info(
1092-
f"Model {self.generation_strategy.model} does not implement "
1093-
"`predict`, so it cannot be used to generate a response "
1092+
f"Model {self.standard_generation_strategy.model} does not "
1093+
"implement `predict`, so it cannot be used to generate a response "
10941094
"surface plot."
10951095
)
10961096
raise UnsupportedPlotError(
@@ -1112,14 +1112,14 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
11121112
"""
11131113
if not self.experiment.trials:
11141114
raise ValueError("Cannot generate plot as there are no trials.")
1115-
cur_model = self.generation_strategy.model
1115+
cur_model = self.standard_generation_strategy.model
11161116
if cur_model is not None:
11171117
try:
11181118
return plot_feature_importance_by_feature(cur_model, relative=relative)
11191119
except NotImplementedError:
11201120
logger.info(
1121-
f"Model {self.generation_strategy.model} does not implement "
1122-
"`feature_importances`, so it cannot be used to generate "
1121+
f"Model {self.standard_generation_strategy.model} does not "
1122+
"implement `feature_importances`, so it cannot be used to generate "
11231123
"this plot. Only certain models, implement feature importances."
11241124
)
11251125

@@ -1247,7 +1247,8 @@ def get_model_predictions(
12471247
else set(none_throws(self.experiment.metrics).keys())
12481248
)
12491249
model = none_throws(
1250-
self.generation_strategy.model, "No model has been instantiated yet."
1250+
self.standard_generation_strategy.model,
1251+
"No model has been instantiated yet.",
12511252
)
12521253

12531254
# Construct a dictionary that maps from a label to an
@@ -1306,8 +1307,8 @@ def fit_model(self) -> None:
13061307
"At least one trial must be completed with data to fit a model."
13071308
)
13081309
# Check if we should transition before generating the next candidate.
1309-
self.generation_strategy._maybe_transition_to_next_node()
1310-
self.generation_strategy._fit_current_model(data=None)
1310+
self.standard_generation_strategy._maybe_transition_to_next_node()
1311+
self.standard_generation_strategy._fit_current_model(data=None)
13111312

13121313
def verify_trial_parameterization(
13131314
self, trial_index: int, parameterization: TParameterization
@@ -1496,29 +1497,10 @@ def from_json_snapshot(
14961497

14971498
# ---------------------- Private helper methods. ---------------------
14981499

1499-
@property
1500-
def experiment(self) -> Experiment:
1501-
"""Returns the experiment set on this Ax client."""
1502-
return none_throws(
1503-
self._experiment,
1504-
(
1505-
"Experiment not set on Ax client. Must first "
1506-
"call load_experiment or create_experiment to use handler functions."
1507-
),
1508-
)
1509-
15101500
def get_trial(self, trial_index: int) -> Trial:
15111501
"""Return a trial on experiment cast as Trial"""
15121502
return checked_cast(Trial, self.experiment.trials[trial_index])
15131503

1514-
@property
1515-
def generation_strategy(self) -> GenerationStrategy:
1516-
"""Returns the generation strategy, set on this experiment."""
1517-
return none_throws(
1518-
self._generation_strategy,
1519-
"No generation strategy has been set on this optimization yet.",
1520-
)
1521-
15221504
@property
15231505
def objective(self) -> Objective:
15241506
return none_throws(self.experiment.optimization_config).objective
@@ -1586,7 +1568,7 @@ def get_best_trial(
15861568
) -> tuple[int, TParameterization, TModelPredictArm | None] | None:
15871569
return self._get_best_trial(
15881570
experiment=self.experiment,
1589-
generation_strategy=self.generation_strategy,
1571+
generation_strategy=self.standard_generation_strategy,
15901572
trial_indices=trial_indices,
15911573
use_model_predictions=use_model_predictions,
15921574
)
@@ -1600,7 +1582,7 @@ def get_pareto_optimal_parameters(
16001582
) -> dict[int, tuple[TParameterization, TModelPredictArm]]:
16011583
return self._get_pareto_optimal_parameters(
16021584
experiment=self.experiment,
1603-
generation_strategy=self.generation_strategy,
1585+
generation_strategy=self.standard_generation_strategy,
16041586
trial_indices=trial_indices,
16051587
use_model_predictions=use_model_predictions,
16061588
)
@@ -1614,7 +1596,7 @@ def get_hypervolume(
16141596
) -> float:
16151597
return BestPointMixin._get_hypervolume(
16161598
experiment=self.experiment,
1617-
generation_strategy=self.generation_strategy,
1599+
generation_strategy=self.standard_generation_strategy,
16181600
optimization_config=optimization_config,
16191601
trial_indices=trial_indices,
16201602
use_model_predictions=use_model_predictions,
@@ -1817,7 +1799,7 @@ def _gen_new_generator_run(
18171799
else None
18181800
)
18191801
with with_rng_seed(seed=self._random_seed):
1820-
return none_throws(self.generation_strategy).gen(
1802+
return none_throws(self.standard_generation_strategy).gen(
18211803
experiment=self.experiment,
18221804
n=n,
18231805
pending_observations=self._get_pending_observation_features(

ax/service/scheduler.py

-17
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,6 @@ class Scheduler(AnalysisBase, BestPointMixin):
162162
been saved, as otherwise experiment state could get corrupted.**
163163
"""
164164

165-
experiment: Experiment
166-
generation_strategy: GenerationStrategyInterface
167165
# pyre-fixme[24]: Generic type `LoggerAdapter` expects 1 type parameter.
168166
logger: LoggerAdapter
169167
# Mapping of form {short string identifier -> message to show in reported
@@ -491,21 +489,6 @@ def runner(self) -> Runner:
491489
)
492490
return runner
493491

494-
@property
495-
def standard_generation_strategy(self) -> GenerationStrategy:
496-
"""Used for operations in the scheduler that can only be done with
497-
and instance of ``GenerationStrategy``.
498-
"""
499-
gs = self.generation_strategy
500-
if not isinstance(gs, GenerationStrategy):
501-
raise NotImplementedError(
502-
"This functionality is only supported with instances of "
503-
"`GenerationStrategy` (one that uses `GenerationStrategy` "
504-
"class) and not yet with other types of "
505-
"`GenerationStrategyInterface`."
506-
)
507-
return gs
508-
509492
def __repr__(self) -> str:
510493
"""Short user-friendly string representation."""
511494
if not hasattr(self, "experiment"):

ax/service/tests/test_ax_client.py

+25-20
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,10 @@ def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None:
504504
"""
505505
ax_client = get_branin_optimization()
506506
self.assertEqual(
507-
[s.model for s in none_throws(ax_client.generation_strategy)._steps],
507+
[
508+
s.model
509+
for s in none_throws(ax_client.standard_generation_strategy)._steps
510+
],
508511
[Models.SOBOL, Models.BOTORCH_MODULAR],
509512
)
510513
with self.assertRaisesRegex(ValueError, ".* no trials"):
@@ -719,7 +722,7 @@ def test_default_generation_strategy_continuous_for_moo(
719722
},
720723
)
721724
self.assertEqual(
722-
[s.model for s in none_throws(ax_client.generation_strategy)._steps],
725+
[s.model for s in ax_client.standard_generation_strategy._steps],
723726
[Models.SOBOL, Models.BOTORCH_MODULAR],
724727
)
725728
with self.assertRaisesRegex(ValueError, ".* no trials"):
@@ -782,7 +785,7 @@ def test_create_experiment(self) -> None:
782785
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
783786
)
784787
)
785-
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
788+
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
786789
ax_client.experiment
787790
ax_client.create_experiment(
788791
name="test_experiment",
@@ -1019,7 +1022,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None:
10191022
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
10201023
)
10211024
)
1022-
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
1025+
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
10231026
ax_client.experiment
10241027
ax_client.create_experiment(
10251028
name="test_experiment",
@@ -1080,7 +1083,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None:
10801083
def test_create_experiment_with_metric_definitions(self) -> None:
10811084
"""Test basic experiment creation."""
10821085
ax_client = AxClient()
1083-
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
1086+
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
10841087
ax_client.experiment
10851088

10861089
metric_definitions = {
@@ -1347,7 +1350,7 @@ def test_create_moo_experiment(self) -> None:
13471350
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
13481351
)
13491352
)
1350-
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
1353+
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
13511354
ax_client.experiment
13521355
ax_client.create_experiment(
13531356
name="test_experiment",
@@ -1581,10 +1584,9 @@ def test_keep_generating_without_data(self) -> None:
15811584
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
15821585
],
15831586
)
1584-
self.assertFalse(
1585-
ax_client.generation_strategy._steps[0].enforce_num_trials, False
1586-
)
1587-
self.assertFalse(ax_client.generation_strategy._steps[1].max_parallelism, None)
1587+
gs = ax_client.standard_generation_strategy
1588+
self.assertFalse(gs._steps[0].enforce_num_trials, False)
1589+
self.assertFalse(gs._steps[1].max_parallelism, None)
15881590
for _ in range(10):
15891591
parameterization, trial_index = ax_client.get_next_trial()
15901592

@@ -2100,14 +2102,14 @@ def test_sqa_storage(self) -> None:
21002102
# pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U...
21012103
raw_data=branin(*parameters.values()),
21022104
)
2103-
gs = ax_client.generation_strategy
2105+
gs = ax_client.standard_generation_strategy
21042106
ax_client = AxClient(db_settings=db_settings)
21052107
ax_client.load_experiment_from_database("test_experiment")
21062108
# Some fields of the reloaded GS are not expected to be set (both will be
21072109
# set during next model fitting call), so we unset them on the original GS as
21082110
# well.
21092111
gs._unset_non_persistent_state_fields()
2110-
ax_client.generation_strategy._unset_non_persistent_state_fields()
2112+
ax_client.standard_generation_strategy._unset_non_persistent_state_fields()
21112113
self.assertEqual(gs, ax_client.generation_strategy)
21122114
with self.assertRaises(ValueError):
21132115
# Overwriting existing experiment.
@@ -2461,8 +2463,9 @@ def helper_test_get_pareto_optimal_points(
24612463
num_trials=20, outcome_constraints=outcome_constraints
24622464
)
24632465
ax_client.fit_model()
2466+
gs = ax_client.standard_generation_strategy
24642467
self.assertEqual(
2465-
ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key,
2468+
gs._curr.model_spec_to_gen_from.model_key,
24662469
"BoTorch",
24672470
)
24682471

@@ -2487,7 +2490,7 @@ def helper_test_get_pareto_optimal_points(
24872490
# This overwrites the `predict` call to return the original observations,
24882491
# while testing the rest of the code as if we're using predictions.
24892492
# pyre-fixme[16]: `Optional` has no attribute `model`.
2490-
model = ax_client.generation_strategy.model.model
2493+
model = ax_client.standard_generation_strategy.model.model
24912494
ys = model.surrogate.training_data[0].Y
24922495
with patch.object(
24932496
model, "predict", return_value=(ys, torch.zeros(*ys.shape, ys.shape[-1]))
@@ -2531,8 +2534,9 @@ def helper_test_get_pareto_optimal_points_from_sobol_step(
25312534
ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
25322535
num_trials=20, minimize=minimize, outcome_constraints=outcome_constraints
25332536
)
2537+
gs = ax_client.standard_generation_strategy
25342538
self.assertEqual(
2535-
ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key,
2539+
gs._curr.model_spec_to_gen_from.model_key,
25362540
"Sobol",
25372541
)
25382542

@@ -2643,8 +2647,8 @@ def test_get_pareto_optimal_points_objective_threshold_inference(
26432647
ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
26442648
num_trials=20, include_objective_thresholds=False
26452649
)
2646-
ax_client.generation_strategy._maybe_transition_to_next_node()
2647-
ax_client.generation_strategy._fit_current_model(
2650+
ax_client.standard_generation_strategy._maybe_transition_to_next_node()
2651+
ax_client.standard_generation_strategy._fit_current_model(
26482652
data=ax_client.experiment.lookup_data()
26492653
)
26502654

@@ -2855,7 +2859,8 @@ def test_with_hss(self) -> None:
28552859
# Make sure we actually tried a Botorch iteration and all the transforms it
28562860
# applies.
28572861
self.assertEqual(
2858-
ax_client.generation_strategy._generator_runs[-1]._model_key, "BoTorch"
2862+
ax_client.standard_generation_strategy._generator_runs[-1]._model_key,
2863+
"BoTorch",
28592864
)
28602865
self.assertEqual(len(ax_client.experiment.trials), 6)
28612866
ax_client.attach_trial(
@@ -2970,7 +2975,7 @@ def test_torch_device(self) -> None:
29702975
torch_device=device,
29712976
)
29722977
ax_client = get_branin_optimization(torch_device=device)
2973-
gpei_step_kwargs = ax_client.generation_strategy._steps[1].model_kwargs
2978+
gpei_step_kwargs = ax_client.standard_generation_strategy._steps[1].model_kwargs
29742979
self.assertEqual(gpei_step_kwargs["torch_device"], device)
29752980

29762981
def test_repr_function(
@@ -2999,7 +3004,7 @@ def test_gen_fixed_features(self) -> None:
29993004
name="fixed_features",
30003005
)
30013006
with mock.patch.object(
3002-
GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen
3007+
GenerationStrategy, "gen", wraps=ax_client.standard_generation_strategy.gen
30033008
) as mock_gen:
30043009
with self.subTest("fixed_features is None"):
30053010
params, idx = ax_client.get_next_trial()

0 commit comments

Comments
 (0)