2 changes: 1 addition & 1 deletion botorch/acquisition/knowledge_gradient.py
@@ -226,7 +226,7 @@ def evaluate(self, X: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
kwargs: Additional keyword arguments. This includes the options for
optimization of the inner problem, i.e. `num_restarts`, `raw_samples`,
an `options` dictionary to be passed on to the optimization helpers, and
a `scipy_options` dictionary to be passed to `scipy.optimize.minimize`.
a `scipy_options` dictionary to be passed to `scipy.minimize`.

Returns:
A Tensor of shape `b`. For t-batch b, the q-KG value of the design
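For context on the `scipy_options` entry touched above, a minimal sketch of forwarding these keyword arguments to `qKnowledgeGradient.evaluate` (the toy model, data, and option values are illustrative assumptions, not part of this diff):

```python
import torch
from botorch.acquisition import qKnowledgeGradient
from botorch.models import SingleTaskGP

# Toy single-output model; any BoTorch model would do here.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

qKG = qKnowledgeGradient(model, num_fantasies=16)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
X = torch.rand(2, 1, 2, dtype=torch.double)  # b=2 t-batches of q=1 designs

# kwargs documented above: inner-loop restarts/samples plus scipy_options,
# which are forwarded to scipy.optimize.minimize for the inner problem.
values = qKG.evaluate(
    X,
    bounds=bounds,
    num_restarts=4,
    raw_samples=64,
    scipy_options={"maxiter": 50},
)
```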
12 changes: 4 additions & 8 deletions botorch/generation/gen.py
@@ -72,10 +72,6 @@ def gen_candidates_scipy(

Optimizes an acquisition function starting from a set of initial candidates
using `scipy.optimize.minimize` via a numpy converter.
We use SLSQP, if constraints are present, and LBFGS-B otherwise.
As `scipy.optimize.minimize` does not support optimizating a batch of problems, we
treat optimizing a set of candidates as a single optimization problem by
summing together their acquisition values.

Args:
initial_conditions: Starting points for optimization, with shape
@@ -102,7 +98,7 @@ def gen_candidates_scipy(
`optimize_acqf()`. The constraints will later be passed to the scipy
solver.
options: Options used to control the optimization including "method"
and "maxiter". Select method for `scipy.optimize.minimize` using the
and "maxiter". Select method for `scipy.minimize` using the
"method" key. By default uses L-BFGS-B for box-constrained problems
and SLSQP if inequality or equality constraints are present. If
`with_grad=False`, then we use a two-point finite difference estimate
@@ -664,13 +660,13 @@ def _process_scipy_result(res: OptimizeResult, options: dict[str, Any]) -> None:
or "Iteration limit reached" in res.message
):
logger.info(
"`scipy.optimize.minimize` exited by reaching the iteration limit of "
"`scipy.minimize` exited by reaching the iteration limit of "
f"`maxiter: {options.get('maxiter')}`."
)
elif "EVALUATIONS EXCEEDS LIMIT" in res.message:
logger.info(
"`scipy.optimize.minimize` exited by reaching the function evaluation "
f"limit of `maxfun: {options.get('maxfun')}`."
"`scipy.minimize` exited by reaching the function evaluation limit of "
f"`maxfun: {options.get('maxfun')}`."
)
elif "Optimization timed out after" in res.message:
logger.info(res.message)
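As a reference point for the docstring edits above, a small sketch of calling `gen_candidates_scipy` directly with the documented `options` keys (the toy model and data are assumptions for illustration; with only box constraints the solver defaults to L-BFGS-B unless "method" is set):

```python
import torch
from botorch.acquisition import ExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.generation.gen import gen_candidates_scipy
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = ExpectedImprovement(model, best_f=train_Y.max())
ics = torch.rand(3, 1, 2, dtype=torch.double)  # 3 restarts, q=1, d=2

batch_candidates, batch_acq_values = gen_candidates_scipy(
    initial_conditions=ics,
    acquisition_function=acqf,
    lower_bounds=torch.zeros(2, dtype=torch.double),
    upper_bounds=torch.ones(2, dtype=torch.double),
    options={"maxiter": 100},  # "method" left at its box-constrained default
)
```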
2 changes: 1 addition & 1 deletion botorch/models/approximate_gp.py
@@ -113,7 +113,7 @@ def __init__(
super().__init__()

self.model = (
_SingleTaskVariationalGP(num_outputs=num_outputs, *args, **kwargs)
_SingleTaskVariationalGP(*args, num_outputs=num_outputs, **kwargs)
if model is None
else model
)
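The approximate_gp.py change only reorders the call arguments. One plausible reading is that this is a readability/convention fix: passing a keyword argument before `*args` is legal Python but reads as if it bound first. The toy function below (purely illustrative, unrelated to the BoTorch API) shows the two call forms bind identically:

```python
def make_model(train_X, train_Y, num_outputs=1):
    return (train_X, train_Y, num_outputs)

args = ("X", "Y")

# Old style: keyword before *args is legal, but easy to misread.
old = make_model(num_outputs=2, *args)
# New style: unpack positionals first, then keywords, in the conventional order.
new = make_model(*args, num_outputs=2)

assert old == new == ("X", "Y", 2)
```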
19 changes: 14 additions & 5 deletions botorch/models/deterministic.py
@@ -283,10 +283,15 @@ def __init__(
self.model = model

# Validate model compatibility
if isinstance(model, ModelList) and len(model.models) != model.num_outputs:
raise UnsupportedError(
"A model-list of multi-output models is not supported."
)
if isinstance(model, ModelList):
# Check if any model in the list is multi-output
# Use _num_outputs which doesn't include batch dimensions
for m in model.models:
num_outs = getattr(m, "_num_outputs", getattr(m, "num_outputs", 1))
if num_outs > 1:
raise UnsupportedError(
"A model-list of multi-output models is not supported."
)

# Initialize path generation parameters
self.sample_shape = Size() if sample_shape is None else sample_shape
@@ -322,7 +327,11 @@ def forward(self, X: Tensor) -> Tensor:
return self._path(X).unsqueeze(-1)
elif isinstance(self.model, ModelList):
# For model list, stack the path outputs
return torch.stack(self._path(X), dim=-1)
path_outputs = self._path(X)
if len(path_outputs) == 0:
# Handle empty model list
return torch.empty(X.shape[0], 0, device=X.device, dtype=X.dtype)
return torch.stack(path_outputs, dim=-1)
else:
# For multi-output models
return self._path(X.unsqueeze(-3)).transpose(-1, -2)
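The new empty-list branch above guards against `torch.stack` being called with no tensors, which raises rather than returning an empty tensor. A tiny illustration (not part of the diff):

```python
import torch

try:
    torch.stack([], dim=-1)
except RuntimeError as err:
    # stack requires a non-empty list of tensors
    print(f"stacking an empty list fails: {err}")

# The guarded branch returns an `n x 0` tensor instead:
X = torch.rand(5, 2)
print(torch.empty(X.shape[0], 0, device=X.device, dtype=X.dtype).shape)  # (5, 0)
```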
2 changes: 2 additions & 0 deletions botorch/models/gpytorch.py
@@ -935,6 +935,7 @@ def _apply_noise(
self,
X: Tensor,
mvn: MultivariateNormal,
num_outputs: int,
observation_noise: bool | Tensor,
) -> MultivariateNormal:
"""Adds the observation noise to the posterior.
@@ -1066,6 +1067,7 @@ def posterior(
mvn = self._apply_noise(
X=X_full,
mvn=mvn,
num_outputs=num_outputs,
observation_noise=observation_noise,
)
# If single-output, return the posterior of a single-output model
15 changes: 15 additions & 0 deletions botorch/models/utils/__init__.py
@@ -29,6 +29,8 @@
"check_min_max_scaling",
"check_standardization",
"fantasize",
"get_train_inputs",
"get_train_targets",
"gpt_posterior_settings",
"multioutput_to_batch_mode_transform",
"mod_batch_shape",
@@ -38,3 +40,16 @@
"extract_targets_and_noise_single_output",
"restore_targets_and_noise_single_output",
]


# Lazy import to avoid circular dependencies
def __getattr__(name):
if name == "get_train_inputs":
from botorch.models.utils.helpers import get_train_inputs

return get_train_inputs
elif name == "get_train_targets":
from botorch.models.utils.helpers import get_train_targets

return get_train_targets
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
166 changes: 166 additions & 0 deletions botorch/models/utils/helpers.py
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Any, List, overload, Tuple, TYPE_CHECKING

import torch
from botorch.utils.dispatcher import Dispatcher
from torch import Tensor

if TYPE_CHECKING:
from botorch.models.model import Model, ModelList

GetTrainInputs = Dispatcher("get_train_inputs")
GetTrainTargets = Dispatcher("get_train_targets")


@overload
def get_train_inputs(model: Model, transformed: bool = False) -> Tuple[Tensor, ...]:
pass # pragma: no cover


@overload
def get_train_inputs(model: ModelList, transformed: bool = False) -> List[...]:
pass # pragma: no cover


def get_train_inputs(model: Any, transformed: bool = False):
"""Get training inputs from a model, with optional transformation handling.

Args:
model: A BoTorch Model or ModelList.
transformed: If True, return the transformed inputs. If False, return the
original (untransformed) inputs.

Returns:
A tuple of training input tensors for Model, or a list of tuples for ModelList.
"""
# Lazy import to avoid circular dependencies
_register_get_train_inputs()
return GetTrainInputs(model, transformed=transformed)


def _register_get_train_inputs():
"""Register dispatcher implementations for get_train_inputs (lazy)."""
# Only register once
if hasattr(_register_get_train_inputs, "_registered"):
return
_register_get_train_inputs._registered = True

from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.models.model import Model, ModelList

@GetTrainInputs.register(Model)
def _get_train_inputs_Model(
model: Model, transformed: bool = False
) -> Tuple[Tensor]:
if not transformed:
original_train_input = getattr(model, "_original_train_inputs", None)
if torch.is_tensor(original_train_input):
return (original_train_input,)

(X,) = model.train_inputs
transform = getattr(model, "input_transform", None)
if transform is None:
return (X,)

if model.training:
return (transform.forward(X) if transformed else X,)
return (X if transformed else transform.untransform(X),)

@GetTrainInputs.register(SingleTaskVariationalGP)
def _get_train_inputs_SingleTaskVariationalGP(
model: SingleTaskVariationalGP, transformed: bool = False
) -> Tuple[Tensor]:
(X,) = model.model.train_inputs
if model.training != transformed:
return (X,)

transform = getattr(model, "input_transform", None)
if transform is None:
return (X,)

return (transform.forward(X) if model.training else transform.untransform(X),)

@GetTrainInputs.register(ModelList)
def _get_train_inputs_ModelList(
model: ModelList, transformed: bool = False
) -> List[...]:
return [get_train_inputs(m, transformed=transformed) for m in model.models]


@overload
def get_train_targets(model: Model, transformed: bool = False) -> Tensor:
pass # pragma: no cover


@overload
def get_train_targets(model: ModelList, transformed: bool = False) -> List[...]:
pass # pragma: no cover


def get_train_targets(model: Any, transformed: bool = False):
"""Get training targets from a model, with optional transformation handling.

Args:
model: A BoTorch Model or ModelList.
transformed: If True, return the transformed targets. If False, return the
original (untransformed) targets.

Returns:
Training target tensors for Model, or a list of tensors for ModelList.
"""
# Lazy import to avoid circular dependencies
_register_get_train_targets()
return GetTrainTargets(model, transformed=transformed)


def _register_get_train_targets():
"""Register dispatcher implementations for get_train_targets (lazy)."""
# Only register once
if hasattr(_register_get_train_targets, "_registered"):
return
_register_get_train_targets._registered = True

from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.models.model import Model, ModelList

@GetTrainTargets.register(Model)
def _get_train_targets_Model(model: Model, transformed: bool = False) -> Tensor:
Y = model.train_targets

# Note: Avoid using `get_output_transform` here since it creates a Module
transform = getattr(model, "outcome_transform", None)
if transformed or transform is None:
return Y

if model.num_outputs == 1:
return transform.untransform(Y.unsqueeze(-1))[0].squeeze(-1)
return transform.untransform(Y.transpose(-2, -1))[0].transpose(-2, -1)

@GetTrainTargets.register(SingleTaskVariationalGP)
def _get_train_targets_SingleTaskVariationalGP(
model: Model, transformed: bool = False
) -> Tensor:
Y = model.model.train_targets
transform = getattr(model, "outcome_transform", None)
if transformed or transform is None:
return Y

if model.num_outputs == 1:
return transform.untransform(Y.unsqueeze(-1))[0].squeeze(-1)

# SingleTaskVariationalGP.__init__ doesn't bring the
# multioutput dimension inside
return transform.untransform(Y)[0]

@GetTrainTargets.register(ModelList)
def _get_train_targets_ModelList(
model: ModelList, transformed: bool = False
) -> List[...]:
return [get_train_targets(m, transformed=transformed) for m in model.models]
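A minimal usage sketch for the new helpers (the SingleTaskGP, Normalize, and Standardize setup below is an assumption for illustration; per the docstrings above, the default returns data on the original scale, while `transformed=True` returns the model-internal representation):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.utils import get_train_inputs, get_train_targets

train_X = 10.0 * torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(
    train_X,
    train_Y,
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
).eval()

(X_orig,) = get_train_inputs(model)                   # original scale
(X_tf,) = get_train_inputs(model, transformed=True)   # normalized inputs
Y_orig = get_train_targets(model)                      # original scale
Y_tf = get_train_targets(model, transformed=True)      # standardized targets
```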
4 changes: 2 additions & 2 deletions botorch/optim/core.py
@@ -79,8 +79,8 @@ def scipy_minimize(
bounds: A dictionary mapping parameter names to lower and upper bounds.
callback: A callable taking `parameters` and an OptimizationResult as arguments.
x0: An optional initialization vector passed to scipy.optimize.minimize.
method: Solver type, passed along to scipy.optimize.minimize.
options: Dictionary of solver options, passed along to scipy.optimize.minimize.
method: Solver type, passed along to scipy.minimize.
options: Dictionary of solver options, passed along to scipy.minimize.
timeout_sec: Timeout in seconds to wait before aborting the optimization loop
if not converged (will return the best found solution thus far).

4 changes: 2 additions & 2 deletions botorch/optim/fit.py
@@ -69,8 +69,8 @@ def fit_gpytorch_mll_scipy(
Responsible for setting the `grad` attributes of `parameters`. If no closure
is provided, one will be obtained by calling `get_loss_closure_with_grads`.
closure_kwargs: Keyword arguments passed to `closure`.
method: Solver type, passed along to scipy.optimize.minimize.
options: Dictionary of solver options, passed along to scipy.optimize.minimize.
method: Solver type, passed along to scipy.minimize.
options: Dictionary of solver options, passed along to scipy.minimize.
callback: Optional callback taking `parameters` and an OptimizationResult as its
sole arguments.
timeout_sec: Timeout in seconds after which to terminate the fitting loop
36 changes: 5 additions & 31 deletions botorch/optim/optimize.py
@@ -603,29 +603,7 @@ def optimize_acqf(
retry_on_optimization_warning: bool = True,
**ic_gen_kwargs: Any,
) -> tuple[Tensor, Tensor]:
r"""Optimize the acquisition function for a single or multiple joint candidates.

A high-level description (missing exceptions for special setups):

This function optimizes the acquisition function `acq_function` in two steps:

i) It will sample `raw_samples` random points using Sobol sampling in the bounds
`bounds` and pass on the "best" `num_restarts` many.
The default way to find these "best" is via `gen_batch_initial_conditions`
(deviating for some acq functions, see `get_ic_generator`),
which by default performs Boltzmann sampling on the acquisition function value
(The behavior of step (i) can be further controlled by specifying `ic_generator`
or `batch_initial_conditions`.)

ii) A batch of the `num_restarts` points (or joint sets of points)
with the highest acquisition values in the previous step are then further
optimized. This is by default done by LBFGS-B optimization, if no constraints are
present, and SLSQP, if constraints are present (can be changed to
other optmizers via `gen_candidates`).

While the optimization procedure runs on CPU by default for this function,
the acq_function can be implemented on GPU and simply move the inputs
to GPU internally.
r"""Generate a set of candidates via multi-start optimization.

Args:
acq_function: An AcquisitionFunction.
@@ -634,13 +612,10 @@
+inf, respectively).
q: The number of candidates.
num_restarts: The number of starting points for multistart acquisition
function optimization. Even though the name suggests this happens
sequentually, it is done in parallel (using batched evaluations)
for up to `options.batch_limit` candidates (by default completely parallel).
function optimization.
raw_samples: The number of samples for initialization. This is required
if `batch_initial_conditions` is not specified.
options: Options for both optimization, passed to `gen_candidates`,
and initialization, passed to the `ic_generator` via the `options` kwarg.
options: Options for candidate generation.
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
@@ -685,9 +660,8 @@
acquisition values) given a tensor of initial conditions and an
acquisition function. Other common inputs include lower and upper bounds
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g., botorch.optim.optimize.gen_candidates_scipy
and botorch.generation.gen.gen_candidates_torch) for method-specific
inputs. Default: `gen_candidates_scipy`
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
for method-specific inputs. Default: `gen_candidates_scipy`
sequential: If False, uses joint optimization, otherwise uses sequential
optimization for optimizing multiple joint candidates (q > 1).
acq_function_sequence: A list of acquisition functions to be optimized
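Finally, a short end-to-end sketch of the `optimize_acqf` entry point whose docstring is trimmed above (standard BoTorch usage; the toy objective and option values are illustrative assumptions):

```python
import torch
from botorch.acquisition import LogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = 1.0 - (train_X - 0.5).norm(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = LogExpectedImprovement(model, best_f=train_Y.max())
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=128,
    options={"maxiter": 200},  # forwarded to the scipy-based candidate generator
)
print(candidate.shape)  # torch.Size([1, 2])
```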