diff --git a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py index d7c2e8d706a..104145a9978 100644 --- a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py +++ b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py @@ -68,11 +68,14 @@ def setUp(self) -> None: t.mark_completed() self.data = self.exp.fetch_data() + self._refresh_modelbridge() + + def _refresh_modelbridge(self) -> None: self.modelbridge = ModelBridge( search_space=self.exp.search_space, model=Model(), experiment=self.exp, - data=self.data, + data=self.exp.lookup_data(), status_quo_name="status_quo", ) @@ -139,14 +142,16 @@ def test_single_trial_is_not_transformed(self) -> None: obs2 = tf.transform_observations(obs) self.assertEqual(obs, obs2) - def test_taget_trial_index(self) -> None: + def test_target_trial_index(self) -> None: sobol = get_sobol(search_space=self.exp.search_space) - self.exp.new_batch_trial(generator_run=sobol.gen(2)) + self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True) t = self.exp.trials[1] t = checked_cast(BatchTrial, t) t.mark_running(no_runner_required=True) self.exp.attach_data(get_branin_data_batch(batch=checked_cast(BatchTrial, t))) + self._refresh_modelbridge() + observations = observations_from_data( experiment=self.exp, data=self.exp.lookup_data(), @@ -157,17 +162,4 @@ def test_taget_trial_index(self) -> None: observations=observations, modelbridge=self.modelbridge, ) - self.assertEqual(t.default_trial_idx, 1) - - with mock.patch( - "ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index", - return_value=10, - ): - t = TransformToNewSQ( - search_space=self.exp.search_space, - observations=observations, - modelbridge=self.modelbridge, - ) - - self.assertEqual(t.default_trial_idx, 10) diff --git a/ax/modelbridge/transforms/transform_to_new_sq.py b/ax/modelbridge/transforms/transform_to_new_sq.py index fb4f11e9e58..a4880095bf1 100644 --- a/ax/modelbridge/transforms/transform_to_new_sq.py +++ b/ax/modelbridge/transforms/transform_to_new_sq.py @@ -19,7 +19,6 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import OutcomeConstraint from ax.core.search_space import SearchSpace -from ax.core.utils import get_target_trial_index from ax.modelbridge.transforms.relativize import BaseRelativize, get_metric_index from ax.models.types import TConfig from ax.utils.common.typeutils import checked_cast @@ -71,9 +70,7 @@ def __init__( and modelbridge is not None and modelbridge._experiment is not None ): - target_trial_index = get_target_trial_index( - experiment=modelbridge._experiment - ) + target_trial_index = max(self.status_quo_data_by_trial.keys()) if target_trial_index is not None: self.default_trial_idx: int = checked_cast(int, target_trial_index)