-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added
optimistix.compat.minimize
as a replacement for `jax.scipy.op…
…timize.minimize`
- Loading branch information
1 parent
ed36a3f
commit 3131b45
Showing
8 changed files
with
212 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Compatibility with `jax.scipy.optimize.minimize` | ||
|
||
The JAX API available at `jax.scipy.optimize.minimize` is being deprecated, in favour of domain-specific packages like Optimistix. As such Optimistix provides `optimistix.compat.minimize` as a drop in replacement. | ||
|
||
|
||
::: optimistix.compat.minimize | ||
|
||
--- | ||
|
||
::: optimistix.compat.OptimizeResults | ||
selection: | ||
members: | ||
false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._impl import minimize as minimize, OptimizeResults as OptimizeResults |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from collections.abc import Callable, Mapping | ||
from typing import Any, NamedTuple, Optional, Union | ||
|
||
import equinox as eqx | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from .._minimise import minimise | ||
from .._misc import max_norm | ||
from .._solution import RESULTS | ||
from .._solver import BFGS | ||
|
||
|
||
class OptimizeResults(NamedTuple): | ||
"""Object holding optimization results. | ||
**Attributes:** | ||
- `x`: final solution. | ||
- `success`: ``True`` if optimization succeeded. | ||
- `status`: integer solver specific return code. 0 means converged (nominal), | ||
1=max BFGS iters reached, 3=other failure. | ||
- `fun`: final function value. | ||
- `jac`: final jacobian array. | ||
- `hess_inv`: final inverse Hessian estimate. | ||
- `nfev`: integer number of function calls used. | ||
- `njev`: integer number of gradient evaluations. | ||
- `nit`: integer number of iterations of the optimization algorithm. | ||
""" | ||
|
||
x: jax.Array | ||
success: Union[bool, jax.Array] | ||
status: Union[int, jax.Array] | ||
fun: jax.Array | ||
jac: jax.Array | ||
hess_inv: Optional[jax.Array] | ||
nfev: Union[int, jax.Array] | ||
njev: Union[int, jax.Array] | ||
nit: Union[int, jax.Array] | ||
|
||
|
||
def minimize( | ||
fun: Callable, | ||
x0: jax.Array, | ||
args: tuple = (), | ||
*, | ||
method: str, | ||
tol: Optional[float] = None, | ||
options: Optional[Mapping[str, Any]] = None, | ||
) -> OptimizeResults: | ||
"""Minimization of scalar function of one or more variables. | ||
!!! info | ||
This API is intended as a backward-compatibility drop-in for the now-deprecated | ||
`jax.scipy.optimize.minimize`. In line with that API, only `method="bfgs"` is | ||
supported. | ||
Whilst it's the same basic algorithm, the Optimistix implementation may do | ||
slightly different things under-the-hood. You may obtain slightly different | ||
(but still correct) results. | ||
**Arguments:** | ||
- `fun`: the objective function to be minimized, `fun(x, *args) -> float`, | ||
where `x` is a 1-D array with shape `(n,)` and `args` is a tuple | ||
of the fixed parameters needed to completely specify the function. | ||
`fun` must support differentiation. | ||
- `x0`: initial guess. Array of real elements of size `(n,)`, where `n` is | ||
the number of independent variables. | ||
- `args`: extra arguments passed to the objective function. | ||
- `method`: solver type. Currently only `"bfgs"` is supported. | ||
- `tol`: tolerance for termination. | ||
- `options`: a dictionary of solver options. The following options are supported: | ||
- `maxiter` (int): Maximum number of iterations to perform. Each iteration | ||
performs one function evaluation. Defaults to unlimited iterations. | ||
- `norm`: (callable `x -> float`): the norm to use when calculating errors. | ||
Defaults to a max norm. | ||
**Returns:** | ||
An [`optimistix.compat.OptimizeResults`][] object. | ||
""" | ||
if method.lower() != "bfgs": | ||
raise ValueError(f"Method {method} not recognized") | ||
if not eqx.is_array(x0) or x0.ndim != 1: | ||
raise ValueError("x0 must be a 1-dimensional array") | ||
if not isinstance(args, tuple): | ||
msg = "args argument to `optimistix.compat.minimize` must be a tuple, got {}" | ||
# TypeError, not ValueError, for compatibility with old | ||
# `jax.scipy.optimize.minimize`. | ||
raise TypeError(msg.format(args)) | ||
if tol is None: | ||
tol = 1e-5 | ||
if options is None: | ||
options = {} | ||
else: | ||
options = dict(options) | ||
max_steps = options.pop("maxiter", None) | ||
options.pop("norm", max_norm) | ||
if len(options) != 0: | ||
raise ValueError(f"Unsupported options: {set(options.keys())}") | ||
|
||
def wrapped_fn(y, args): | ||
return fun(y, *args) | ||
|
||
solver = BFGS(rtol=tol, atol=tol, norm=max_norm) | ||
sol = minimise(wrapped_fn, solver, x0, args, max_steps=max_steps, throw=False) | ||
status = jnp.where( | ||
sol.result == RESULTS.successful, | ||
0, | ||
jnp.where(sol.result == RESULTS.nonlinear_max_steps_reached, 1, 3), | ||
) | ||
return OptimizeResults( | ||
x=sol.value, | ||
success=sol.result == RESULTS.successful, | ||
status=status, | ||
fun=sol.state.f_info.f, | ||
jac=sol.state.f_info.grad, | ||
hess_inv=sol.state.f_info.hessian_inv.as_matrix(), | ||
nfev=sol.stats["num_steps"], | ||
njev=sol.state.num_accepted_steps, | ||
# Old JAX implementation counts each full line search as an iteration. | ||
nit=sol.state.num_accepted_steps, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import jax.numpy as jnp | ||
import jax.scipy.optimize as jsp_optimize | ||
import pytest | ||
|
||
import optimistix as optx | ||
|
||
from .helpers import beale, tree_allclose | ||
|
||
|
||
def _setup(): | ||
def fun(x, arg1, arg2, arg3): | ||
a, b = x | ||
return beale((a, b), (arg1, arg2, arg3)) | ||
|
||
args = (jnp.array(1.5), jnp.array(2.25), jnp.array(2.625)) | ||
x0 = jnp.array([2.0, 0.0]) | ||
return fun, args, x0 | ||
|
||
|
||
@pytest.mark.parametrize("method", ("bfgs", "BFGS")) | ||
def test_minimize(method): | ||
fun, args, x0 = _setup() | ||
result = optx.compat.minimize(fun, x0, args, method=method) | ||
assert tree_allclose(result.x, jnp.array([3.0, 0.5])) | ||
assert tree_allclose(fun(result.x, *args), jnp.array(0.0)) | ||
|
||
|
||
def test_errors(): | ||
fun, args, x0 = _setup() | ||
# remove test-time beartype wrapping | ||
minimize = optx.compat.minimize.__wrapped__.__wrapped__ | ||
with pytest.raises(ValueError): | ||
minimize(fun, [2.0, 0.0], args, method="bfgs") # pyright: ignore | ||
|
||
with pytest.raises(ValueError): | ||
minimize(fun, x0, args, method="foobar") | ||
|
||
with pytest.raises(TypeError): | ||
minimize(fun, x0, None, method="bfgs") # pyright: ignore | ||
|
||
|
||
def test_maxiter(): | ||
fun, args, x0 = _setup() | ||
out = optx.compat.minimize(fun, x0, args, method="bfgs", options=dict(maxiter=2)) | ||
assert not out.success | ||
assert out.status == 1 | ||
|
||
|
||
def test_compare(): | ||
fun, args, x0 = _setup() | ||
jax_out = jsp_optimize.minimize(fun, x0, args, method="bfgs") | ||
optx_out = optx.compat.minimize(fun, x0, args, method="bfgs") | ||
assert type(jax_out).__name__ == type(optx_out).__name__ | ||
assert tree_allclose(jax_out.x, optx_out.x) | ||
assert tree_allclose(jax_out.success, optx_out.success) | ||
assert tree_allclose(jax_out.status, optx_out.status) | ||
assert tree_allclose(jax_out.fun, optx_out.fun) | ||
assert tree_allclose(jax_out.jac, optx_out.jac, atol=1e-5, rtol=1e-5) | ||
assert tree_allclose(jax_out.hess_inv, optx_out.hess_inv, atol=1e-2, rtol=1e-2) | ||
# Don't compare number of iterations -- these may different between the two | ||
# implementations. |