From 3131b45d724c5b66b4976c44da22a6f4bada281f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 13 Oct 2023 10:26:26 -0700 Subject: [PATCH] Added `optimistix.compat.minimize` as a replacement for `jax.scipy.optimize.minimize` --- docs/api/compat.md | 13 ++++ docs/faq.md | 4 +- mkdocs.yml | 4 +- optimistix/__init__.py | 2 +- optimistix/_solver/bfgs.py | 6 +- optimistix/compat/__init__.py | 1 + optimistix/compat/_impl.py | 125 ++++++++++++++++++++++++++++++++++ tests/test_compat.py | 61 +++++++++++++++++ 8 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 docs/api/compat.md create mode 100644 optimistix/compat/__init__.py create mode 100644 optimistix/compat/_impl.py create mode 100644 tests/test_compat.py diff --git a/docs/api/compat.md b/docs/api/compat.md new file mode 100644 index 0000000..39a7cdc --- /dev/null +++ b/docs/api/compat.md @@ -0,0 +1,13 @@ +# Compatibility with `jax.scipy.optimize.minimize` + +The JAX API available at `jax.scipy.optimize.minimize` is being deprecated, in favour of domain-specific packages like Optimistix. As such Optimistix provides `optimistix.compat.minimize` as a drop in replacement. + + +::: optimistix.compat.minimize + +--- + +::: optimistix.compat.OptimizeResults + selection: + members: + false diff --git a/docs/faq.md b/docs/faq.md index 6cacc97..6d0c84c 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -24,7 +24,9 @@ Optimistix doesn't try to reinvent the wheel! The Optax library is excellent. As #### ...`jax.scipy.optimize.minimize`? -This is an API which is likely to be removed from JAX at some point, in favour of Optimistix and JAXopt. Don't use it. (Note that the core JAX API only supports minimisation, and only supports the BFGS algorithm.) +This is an API which is being removed from JAX, in favour of Optimistix and JAXopt. Don't use it. (Note that the core JAX API only supports minimisation, and only supports the BFGS algorithm.) + +Optimistix supports [`optimistix.compat.minimize`][] as a drop-in replacement for this API. ## How to debug a solver that is failing to converge, or producing an error? diff --git a/mkdocs.yml b/mkdocs.yml index be7fd98..0843cb0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -100,7 +100,6 @@ nav: - 'index.md' - 'how-to-choose.md' - 'abstract.md' - - 'faq.md' - Examples: - Root Finding: 'examples/root_find.ipynb' - Optimise an ODE: 'examples/optimise_diffeq.ipynb' @@ -121,3 +120,6 @@ nav: - 'api/searches/searches.md' - 'api/searches/descents.md' - 'api/searches/function_info.md' + - Misc: + - 'faq.md' + - 'api/compat.md' diff --git a/optimistix/__init__.py b/optimistix/__init__.py index 417b463..da0f468 100644 --- a/optimistix/__init__.py +++ b/optimistix/__init__.py @@ -14,7 +14,7 @@ import importlib.metadata -from . import internal as internal +from . import compat as compat, internal as internal from ._adjoint import ( AbstractAdjoint as AbstractAdjoint, ImplicitAdjoint as ImplicitAdjoint, diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index d41840c..26a766c 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -22,7 +22,7 @@ import lineax as lx from equinox import AbstractVar from equinox.internal import ω -from jaxtyping import Array, Bool, PyTree, Scalar +from jaxtyping import Array, Bool, Int, PyTree, Scalar from .._custom_types import Aux, DescentState, Fn, SearchState, Y from .._minimise import AbstractMinimiser @@ -152,6 +152,8 @@ class _BFGSState(eqx.Module, Generic[Y, Aux, SearchState, DescentState, _Hessian # Used for termination terminate: Bool[Array, ""] result: RESULTS + # Used in compat.py + num_accepted_steps: Int[Array, ""] class AbstractBFGS(AbstractMinimiser[Y, Aux, _BFGSState], Generic[Y, Aux, _Hessian]): @@ -199,6 +201,7 @@ def init( descent_state=self.descent.init(y, f_info_struct), terminate=jnp.array(False), result=RESULTS.successful, + num_accepted_steps=jnp.array(0), ) def step( @@ -265,6 +268,7 @@ def rejected(descent_state): descent_state=descent_state, terminate=terminate, result=result, + num_accepted_steps=state.num_accepted_steps + accept, ) return y, state, aux diff --git a/optimistix/compat/__init__.py b/optimistix/compat/__init__.py new file mode 100644 index 0000000..b106823 --- /dev/null +++ b/optimistix/compat/__init__.py @@ -0,0 +1 @@ +from ._impl import minimize as minimize, OptimizeResults as OptimizeResults diff --git a/optimistix/compat/_impl.py b/optimistix/compat/_impl.py new file mode 100644 index 0000000..45a0352 --- /dev/null +++ b/optimistix/compat/_impl.py @@ -0,0 +1,125 @@ +from collections.abc import Callable, Mapping +from typing import Any, NamedTuple, Optional, Union + +import equinox as eqx +import jax +import jax.numpy as jnp + +from .._minimise import minimise +from .._misc import max_norm +from .._solution import RESULTS +from .._solver import BFGS + + +class OptimizeResults(NamedTuple): + """Object holding optimization results. + + **Attributes:** + + - `x`: final solution. + - `success`: ``True`` if optimization succeeded. + - `status`: integer solver specific return code. 0 means converged (nominal), + 1=max BFGS iters reached, 3=other failure. + - `fun`: final function value. + - `jac`: final jacobian array. + - `hess_inv`: final inverse Hessian estimate. + - `nfev`: integer number of function calls used. + - `njev`: integer number of gradient evaluations. + - `nit`: integer number of iterations of the optimization algorithm. + """ + + x: jax.Array + success: Union[bool, jax.Array] + status: Union[int, jax.Array] + fun: jax.Array + jac: jax.Array + hess_inv: Optional[jax.Array] + nfev: Union[int, jax.Array] + njev: Union[int, jax.Array] + nit: Union[int, jax.Array] + + +def minimize( + fun: Callable, + x0: jax.Array, + args: tuple = (), + *, + method: str, + tol: Optional[float] = None, + options: Optional[Mapping[str, Any]] = None, +) -> OptimizeResults: + """Minimization of scalar function of one or more variables. + + !!! info + + This API is intended as a backward-compatibility drop-in for the now-deprecated + `jax.scipy.optimize.minimize`. In line with that API, only `method="bfgs"` is + supported. + + Whilst it's the same basic algorithm, the Optimistix implementation may do + slightly different things under-the-hood. You may obtain slightly different + (but still correct) results. + + **Arguments:** + + - `fun`: the objective function to be minimized, `fun(x, *args) -> float`, + where `x` is a 1-D array with shape `(n,)` and `args` is a tuple + of the fixed parameters needed to completely specify the function. + `fun` must support differentiation. + - `x0`: initial guess. Array of real elements of size `(n,)`, where `n` is + the number of independent variables. + - `args`: extra arguments passed to the objective function. + - `method`: solver type. Currently only `"bfgs"` is supported. + - `tol`: tolerance for termination. + - `options`: a dictionary of solver options. The following options are supported: + - `maxiter` (int): Maximum number of iterations to perform. Each iteration + performs one function evaluation. Defaults to unlimited iterations. + - `norm`: (callable `x -> float`): the norm to use when calculating errors. + Defaults to a max norm. + + **Returns:** + + An [`optimistix.compat.OptimizeResults`][] object. + """ + if method.lower() != "bfgs": + raise ValueError(f"Method {method} not recognized") + if not eqx.is_array(x0) or x0.ndim != 1: + raise ValueError("x0 must be a 1-dimensional array") + if not isinstance(args, tuple): + msg = "args argument to `optimistix.compat.minimize` must be a tuple, got {}" + # TypeError, not ValueError, for compatibility with old + # `jax.scipy.optimize.minimize`. + raise TypeError(msg.format(args)) + if tol is None: + tol = 1e-5 + if options is None: + options = {} + else: + options = dict(options) + max_steps = options.pop("maxiter", None) + options.pop("norm", max_norm) + if len(options) != 0: + raise ValueError(f"Unsupported options: {set(options.keys())}") + + def wrapped_fn(y, args): + return fun(y, *args) + + solver = BFGS(rtol=tol, atol=tol, norm=max_norm) + sol = minimise(wrapped_fn, solver, x0, args, max_steps=max_steps, throw=False) + status = jnp.where( + sol.result == RESULTS.successful, + 0, + jnp.where(sol.result == RESULTS.nonlinear_max_steps_reached, 1, 3), + ) + return OptimizeResults( + x=sol.value, + success=sol.result == RESULTS.successful, + status=status, + fun=sol.state.f_info.f, + jac=sol.state.f_info.grad, + hess_inv=sol.state.f_info.hessian_inv.as_matrix(), + nfev=sol.stats["num_steps"], + njev=sol.state.num_accepted_steps, + # Old JAX implementation counts each full line search as an iteration. + nit=sol.state.num_accepted_steps, + ) diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 0000000..e9495f5 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,61 @@ +import jax.numpy as jnp +import jax.scipy.optimize as jsp_optimize +import pytest + +import optimistix as optx + +from .helpers import beale, tree_allclose + + +def _setup(): + def fun(x, arg1, arg2, arg3): + a, b = x + return beale((a, b), (arg1, arg2, arg3)) + + args = (jnp.array(1.5), jnp.array(2.25), jnp.array(2.625)) + x0 = jnp.array([2.0, 0.0]) + return fun, args, x0 + + +@pytest.mark.parametrize("method", ("bfgs", "BFGS")) +def test_minimize(method): + fun, args, x0 = _setup() + result = optx.compat.minimize(fun, x0, args, method=method) + assert tree_allclose(result.x, jnp.array([3.0, 0.5])) + assert tree_allclose(fun(result.x, *args), jnp.array(0.0)) + + +def test_errors(): + fun, args, x0 = _setup() + # remove test-time beartype wrapping + minimize = optx.compat.minimize.__wrapped__.__wrapped__ + with pytest.raises(ValueError): + minimize(fun, [2.0, 0.0], args, method="bfgs") # pyright: ignore + + with pytest.raises(ValueError): + minimize(fun, x0, args, method="foobar") + + with pytest.raises(TypeError): + minimize(fun, x0, None, method="bfgs") # pyright: ignore + + +def test_maxiter(): + fun, args, x0 = _setup() + out = optx.compat.minimize(fun, x0, args, method="bfgs", options=dict(maxiter=2)) + assert not out.success + assert out.status == 1 + + +def test_compare(): + fun, args, x0 = _setup() + jax_out = jsp_optimize.minimize(fun, x0, args, method="bfgs") + optx_out = optx.compat.minimize(fun, x0, args, method="bfgs") + assert type(jax_out).__name__ == type(optx_out).__name__ + assert tree_allclose(jax_out.x, optx_out.x) + assert tree_allclose(jax_out.success, optx_out.success) + assert tree_allclose(jax_out.status, optx_out.status) + assert tree_allclose(jax_out.fun, optx_out.fun) + assert tree_allclose(jax_out.jac, optx_out.jac, atol=1e-5, rtol=1e-5) + assert tree_allclose(jax_out.hess_inv, optx_out.hess_inv, atol=1e-2, rtol=1e-2) + # Don't compare number of iterations -- these may different between the two + # implementations.