Skip to content

Commit 909f9fd

Browse files
paschaifacebook-github-bot
authored andcommitted
Move get_improvement_over_baseline to the BestPointMixin (facebook#3156)
Summary: Pull Request resolved: facebook#3156 Differential Revision: D66472613
1 parent d1f5686 commit 909f9fd

File tree

8 files changed

+289
-225
lines changed

8 files changed

+289
-225
lines changed

ax/service/scheduler.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,71 +1021,6 @@ def summarize_final_result(self) -> OptimizationResult:
10211021
"""
10221022
return OptimizationResult()
10231023

1024-
def get_improvement_over_baseline(
1025-
self,
1026-
baseline_arm_name: str | None = None,
1027-
) -> float:
1028-
"""Returns the scalarized improvement over baseline, if applicable.
1029-
1030-
Returns:
1031-
For Single Objective cases, returns % improvement of objective.
1032-
Positive indicates improvement over baseline. Negative indicates regression.
1033-
For Multi Objective cases, throws NotImplementedError
1034-
"""
1035-
if self.experiment.is_moo_problem:
1036-
raise NotImplementedError(
1037-
"`get_improvement_over_baseline` not yet implemented"
1038-
+ " for multi-objective problems."
1039-
)
1040-
if not baseline_arm_name:
1041-
raise UserInputError(
1042-
"`get_improvement_over_baseline` missing required parameter: "
1043-
+ f"{baseline_arm_name=}, "
1044-
)
1045-
1046-
optimization_config = self.experiment.optimization_config
1047-
if not optimization_config:
1048-
raise ValueError("No optimization config found.")
1049-
1050-
objective_metric_name = optimization_config.objective.metric.name
1051-
1052-
# get the baseline trial
1053-
data = self.experiment.lookup_data().df
1054-
data = data[data["arm_name"] == baseline_arm_name]
1055-
if len(data) == 0:
1056-
raise UserInputError(
1057-
"`get_improvement_over_baseline`"
1058-
" could not find baseline arm"
1059-
f" `{baseline_arm_name}` in the experiment data."
1060-
)
1061-
data = data[data["metric_name"] == objective_metric_name]
1062-
baseline_value = data.iloc[0]["mean"]
1063-
1064-
# Find objective value of the best trial
1065-
idx, param, best_arm = none_throws(
1066-
self.get_best_trial(
1067-
optimization_config=optimization_config, use_model_predictions=False
1068-
)
1069-
)
1070-
best_arm = none_throws(best_arm)
1071-
best_obj_value = best_arm[0][objective_metric_name]
1072-
1073-
def percent_change(x: float, y: float, minimize: bool) -> float:
1074-
if x == 0:
1075-
raise ZeroDivisionError(
1076-
"Cannot compute percent improvement when denom is zero"
1077-
)
1078-
percent_change = (y - x) / abs(x) * 100
1079-
if minimize:
1080-
percent_change = -percent_change
1081-
return percent_change
1082-
1083-
return percent_change(
1084-
x=baseline_value,
1085-
y=best_obj_value,
1086-
minimize=optimization_config.objective.minimize,
1087-
)
1088-
10891024
def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool:
10901025
"""Checks if the failure rate (set in scheduler options) has been exceeded at
10911026
any point during the optimization.

ax/service/tests/scheduler_test_utils.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,6 +2197,8 @@ def test_get_improvement_over_baseline(self) -> None:
21972197
scheduler.experiment.trials[0].lookup_data().df["arm_name"].iloc[0]
21982198
)
21992199
percent_improvement = scheduler.get_improvement_over_baseline(
2200+
experiment=scheduler.experiment,
2201+
generation_strategy=scheduler.standard_generation_strategy,
22002202
baseline_arm_name=first_trial_name,
22012203
)
22022204

@@ -2209,11 +2211,7 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:
22092211
self.branin_experiment.optimization_config = (
22102212
get_branin_multi_objective_optimization_config()
22112213
)
2212-
2213-
gs = self._get_generation_strategy_strategy_for_test(
2214-
experiment=self.branin_experiment,
2215-
generation_strategy=self.sobol_MBM_GS,
2216-
)
2214+
gs = self.sobol_MBM_GS
22172215

22182216
scheduler = Scheduler(
22192217
experiment=self.branin_experiment,
@@ -2227,6 +2225,8 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:
22272225

22282226
with self.assertRaises(NotImplementedError):
22292227
scheduler.get_improvement_over_baseline(
2228+
experiment=scheduler.experiment,
2229+
generation_strategy=scheduler.standard_generation_strategy,
22302230
baseline_arm_name=None,
22312231
)
22322232

@@ -2236,10 +2236,7 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
22362236
experiment.name = f"{self.branin_experiment.name}_but_moo"
22372237
experiment.runner = self.runner
22382238

2239-
gs = self._get_generation_strategy_strategy_for_test(
2240-
experiment=experiment,
2241-
generation_strategy=self.two_sobol_steps_GS,
2242-
)
2239+
gs = self.two_sobol_steps_GS
22432240
scheduler = Scheduler(
22442241
experiment=self.branin_experiment, # Has runner and metrics.
22452242
generation_strategy=gs,
@@ -2253,6 +2250,8 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
22532250

22542251
with self.assertRaises(UserInputError):
22552252
scheduler.get_improvement_over_baseline(
2253+
experiment=scheduler.experiment,
2254+
generation_strategy=scheduler.standard_generation_strategy,
22562255
baseline_arm_name=None,
22572256
)
22582257

@@ -2267,19 +2266,20 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
22672266
scheduler.experiment = exp_copy
22682267

22692268
with self.assertRaises(ValueError):
2270-
scheduler.get_improvement_over_baseline(baseline_arm_name="baseline")
2269+
scheduler.get_improvement_over_baseline(
2270+
experiment=scheduler.experiment,
2271+
generation_strategy=scheduler.standard_generation_strategy,
2272+
baseline_arm_name="baseline",
2273+
)
22712274

22722275
def test_get_improvement_over_baseline_no_baseline(self) -> None:
22732276
"""Test that get_improvement_over_baseline returns UserInputError when
22742277
baseline is not found in data."""
22752278
n_total_trials = 8
2276-
gs = self._get_generation_strategy_strategy_for_test(
2277-
experiment=self.branin_experiment,
2278-
generation_strategy=self.two_sobol_steps_GS,
2279-
)
2280-
2279+
experiment = self.branin_experiment
2280+
gs = self.two_sobol_steps_GS
22812281
scheduler = Scheduler(
2282-
experiment=self.branin_experiment, # Has runner and metrics.
2282+
experiment=experiment, # Has runner and metrics.
22832283
generation_strategy=gs,
22842284
options=SchedulerOptions(
22852285
total_trials=n_total_trials,
@@ -2293,6 +2293,8 @@ def test_get_improvement_over_baseline_no_baseline(self) -> None:
22932293

22942294
with self.assertRaises(UserInputError):
22952295
scheduler.get_improvement_over_baseline(
2296+
experiment=experiment,
2297+
generation_strategy=gs,
22962298
baseline_arm_name="baseline_arm_not_in_data",
22972299
)
22982300

ax/service/tests/test_best_point_utils.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66

77
# pyre-strict
88

9+
import copy
910
import random
10-
from unittest.mock import MagicMock, patch
11+
from unittest.mock import MagicMock, patch, PropertyMock
1112

1213
import pandas as pd
1314
import torch
1415
from ax.core.arm import Arm
1516
from ax.core.batch_trial import BatchTrial
1617
from ax.core.data import Data
18+
from ax.core.experiment import Experiment
1719
from ax.core.generator_run import GeneratorRun
20+
from ax.core.metric import Metric
1821
from ax.core.objective import ScalarizedObjective
1922
from ax.core.optimization_config import OptimizationConfig
2023
from ax.core.outcome_constraint import OutcomeConstraint
@@ -32,10 +35,12 @@
3235
get_best_raw_objective_point,
3336
logger as best_point_logger,
3437
)
38+
from ax.service.utils.best_point_utils import select_baseline_arm
3539
from ax.utils.common.testutils import TestCase
3640
from ax.utils.testing.core_stubs import (
3741
get_branin_experiment,
3842
get_branin_metric,
43+
get_branin_search_space,
3944
get_experiment_with_observations,
4045
get_sobol,
4146
)
@@ -556,6 +561,75 @@ def test_is_row_feasible(self) -> None:
556561
df.index, feasible_series.index, check_names=False
557562
)
558563

564+
def test_compare_to_baseline_select_baseline_arm(self) -> None:
565+
OBJECTIVE_METRIC = "objective"
566+
true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True)
567+
experiment = Experiment(
568+
search_space=get_branin_search_space(),
569+
tracking_metrics=[true_obj_metric],
570+
)
571+
572+
with patch.object(
573+
Experiment, "arms_by_name", new_callable=PropertyMock
574+
) as mock_arms_by_name:
575+
mock_arms_by_name.return_value = {"arm1": "value1", "arm2": "value2"}
576+
self.assertEqual(
577+
select_baseline_arm(
578+
experiment=experiment,
579+
baseline_arm_name="arm1",
580+
),
581+
("arm1", False),
582+
)
583+
584+
# specified baseline arm not in trial
585+
wrong_baseline_name = "wrong_baseline_name"
586+
with self.assertRaisesRegex(
587+
ValueError,
588+
"select_baseline_arm: baseline row: .*" + " not found in arms",
589+
):
590+
select_baseline_arm(
591+
experiment=experiment,
592+
baseline_arm_name=wrong_baseline_name,
593+
)
594+
595+
# status quo baseline arm
596+
experiment_with_status_quo = copy.deepcopy(experiment)
597+
experiment_with_status_quo.status_quo = Arm(
598+
name="status_quo",
599+
parameters={"x1": 0, "x2": 0},
600+
)
601+
self.assertEqual(
602+
select_baseline_arm(
603+
experiment=experiment_with_status_quo,
604+
baseline_arm_name=None,
605+
),
606+
("status_quo", False),
607+
)
608+
# first arm from trials
609+
custom_arm = Arm(name="m_0", parameters={"x1": 0.1, "x2": 0.2})
610+
experiment.new_trial().add_arm(custom_arm)
611+
self.assertEqual(
612+
select_baseline_arm(
613+
experiment=experiment,
614+
baseline_arm_name=None,
615+
),
616+
("m_0", True),
617+
)
618+
619+
# none selected
620+
experiment_with_no_valid_baseline = Experiment(
621+
search_space=get_branin_search_space(),
622+
tracking_metrics=[true_obj_metric],
623+
)
624+
625+
with self.assertRaisesRegex(
626+
ValueError, "select_baseline_arm: could not find valid baseline arm"
627+
):
628+
select_baseline_arm(
629+
experiment=experiment_with_no_valid_baseline,
630+
baseline_arm_name=None,
631+
)
632+
559633

560634
def _repeat_elements(list_to_replicate: list[bool], n_repeats: int) -> pd.Series:
561635
return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)])

0 commit comments

Comments
 (0)