Added optimistix.compat.minimize as a replacement for `jax.scipy.optimize.minimize`
patrick-kidger committed Oct 13, 2023
1 parent ed36a3f commit 3131b45
Showing 8 changed files with 212 additions and 4 deletions.
13 changes: 13 additions & 0 deletions docs/api/compat.md
@@ -0,0 +1,13 @@
# Compatibility with `jax.scipy.optimize.minimize`

The JAX API available at `jax.scipy.optimize.minimize` is being deprecated, in favour of domain-specific packages like Optimistix. As such, Optimistix provides `optimistix.compat.minimize` as a drop-in replacement.
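
For example, migrating typically means changing only the call site. A minimal sketch (the Rosenbrock objective here is purely illustrative, not part of Optimistix):

```python
import jax.numpy as jnp
import optimistix as optx


def rosenbrock(x):
    # Classic test objective with a global minimum at x = (1, 1).
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)


x0 = jnp.zeros(2)
# Previously: jax.scipy.optimize.minimize(rosenbrock, x0, method="BFGS")
result = optx.compat.minimize(rosenbrock, x0, method="bfgs")
print(result.x, result.success)
```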


::: optimistix.compat.minimize

---

::: optimistix.compat.OptimizeResults
    selection:
        members:
            false
4 changes: 3 additions & 1 deletion docs/faq.md
@@ -24,7 +24,9 @@ Optimistix doesn't try to reinvent the wheel! The Optax library is excellent. As

#### ...`jax.scipy.optimize.minimize`?

This is an API which is likely to be removed from JAX at some point, in favour of Optimistix and JAXopt. Don't use it. (Note that the core JAX API only supports minimisation, and only supports the BFGS algorithm.)
This is an API which is being removed from JAX, in favour of Optimistix and JAXopt. Don't use it. (Note that the core JAX API only supports minimisation, and only supports the BFGS algorithm.)

Optimistix supports [`optimistix.compat.minimize`][] as a drop-in replacement for this API.

## How to debug a solver that is failing to converge, or producing an error?

4 changes: 3 additions & 1 deletion mkdocs.yml
@@ -100,7 +100,6 @@ nav:
    - 'index.md'
    - 'how-to-choose.md'
    - 'abstract.md'
    - 'faq.md'
    - Examples:
        - Root Finding: 'examples/root_find.ipynb'
        - Optimise an ODE: 'examples/optimise_diffeq.ipynb'
@@ -121,3 +120,6 @@
        - 'api/searches/searches.md'
        - 'api/searches/descents.md'
        - 'api/searches/function_info.md'
    - Misc:
        - 'faq.md'
        - 'api/compat.md'
2 changes: 1 addition & 1 deletion optimistix/__init__.py
@@ -14,7 +14,7 @@

import importlib.metadata

from . import internal as internal
from . import compat as compat, internal as internal
from ._adjoint import (
    AbstractAdjoint as AbstractAdjoint,
    ImplicitAdjoint as ImplicitAdjoint,
6 changes: 5 additions & 1 deletion optimistix/_solver/bfgs.py
@@ -22,7 +22,7 @@
import lineax as lx
from equinox import AbstractVar
from equinox.internal import ω
from jaxtyping import Array, Bool, PyTree, Scalar
from jaxtyping import Array, Bool, Int, PyTree, Scalar

from .._custom_types import Aux, DescentState, Fn, SearchState, Y
from .._minimise import AbstractMinimiser
@@ -152,6 +152,8 @@ class _BFGSState(eqx.Module, Generic[Y, Aux, SearchState, DescentState, _Hessian]):
    # Used for termination
    terminate: Bool[Array, ""]
    result: RESULTS
    # Used in compat.py
    num_accepted_steps: Int[Array, ""]


class AbstractBFGS(AbstractMinimiser[Y, Aux, _BFGSState], Generic[Y, Aux, _Hessian]):
@@ -199,6 +201,7 @@ def init(
            descent_state=self.descent.init(y, f_info_struct),
            terminate=jnp.array(False),
            result=RESULTS.successful,
            num_accepted_steps=jnp.array(0),
        )

    def step(
@@ -265,6 +268,7 @@ def rejected(descent_state):
            descent_state=descent_state,
            terminate=terminate,
            result=result,
            # `accept` is a boolean scalar, so this increments only on accepted steps.
            num_accepted_steps=state.num_accepted_steps + accept,
        )
        return y, state, aux

1 change: 1 addition & 0 deletions optimistix/compat/__init__.py
@@ -0,0 +1 @@
from ._impl import minimize as minimize, OptimizeResults as OptimizeResults
125 changes: 125 additions & 0 deletions optimistix/compat/_impl.py
@@ -0,0 +1,125 @@
from collections.abc import Callable, Mapping
from typing import Any, NamedTuple, Optional, Union

import equinox as eqx
import jax
import jax.numpy as jnp

from .._minimise import minimise
from .._misc import max_norm
from .._solution import RESULTS
from .._solver import BFGS


class OptimizeResults(NamedTuple):
    """Object holding optimization results.

    **Attributes:**

    - `x`: final solution.
    - `success`: ``True`` if optimization succeeded.
    - `status`: integer solver-specific return code. 0 means converged (nominal),
        1 means the maximum number of BFGS iterations was reached, 3 means some
        other failure.
    - `fun`: final function value.
    - `jac`: final Jacobian array.
    - `hess_inv`: final inverse Hessian estimate.
    - `nfev`: integer number of function calls used.
    - `njev`: integer number of gradient evaluations.
    - `nit`: integer number of iterations of the optimization algorithm.
    """

    x: jax.Array
    success: Union[bool, jax.Array]
    status: Union[int, jax.Array]
    fun: jax.Array
    jac: jax.Array
    hess_inv: Optional[jax.Array]
    nfev: Union[int, jax.Array]
    njev: Union[int, jax.Array]
    nit: Union[int, jax.Array]


def minimize(
    fun: Callable,
    x0: jax.Array,
    args: tuple = (),
    *,
    method: str,
    tol: Optional[float] = None,
    options: Optional[Mapping[str, Any]] = None,
) -> OptimizeResults:
    """Minimization of scalar function of one or more variables.

    !!! info

        This API is intended as a backward-compatibility drop-in for the
        now-deprecated `jax.scipy.optimize.minimize`. In line with that API, only
        `method="bfgs"` is supported.

        Whilst it's the same basic algorithm, the Optimistix implementation may do
        slightly different things under-the-hood. You may obtain slightly different
        (but still correct) results.

    **Arguments:**

    - `fun`: the objective function to be minimized, `fun(x, *args) -> float`,
        where `x` is a 1-D array with shape `(n,)` and `args` is a tuple of the
        fixed parameters needed to completely specify the function. `fun` must
        support differentiation.
    - `x0`: initial guess. Array of real elements of size `(n,)`, where `n` is the
        number of independent variables.
    - `args`: extra arguments passed to the objective function.
    - `method`: solver type. Currently only `"bfgs"` is supported.
    - `tol`: tolerance for termination.
    - `options`: a dictionary of solver options. The following options are
        supported:
        - `maxiter` (int): Maximum number of iterations to perform. Each iteration
            performs one function evaluation. Defaults to unlimited iterations.
        - `norm` (callable `x -> float`): the norm to use when calculating errors.
            Defaults to a max norm.

    **Returns:**

    An [`optimistix.compat.OptimizeResults`][] object.
    """
    if method.lower() != "bfgs":
        raise ValueError(f"Method {method} not recognized")
    if not eqx.is_array(x0) or x0.ndim != 1:
        raise ValueError("x0 must be a 1-dimensional array")
    if not isinstance(args, tuple):
        msg = "args argument to `optimistix.compat.minimize` must be a tuple, got {}"
        # TypeError, not ValueError, for compatibility with old
        # `jax.scipy.optimize.minimize`.
        raise TypeError(msg.format(args))
    if tol is None:
        tol = 1e-5
    if options is None:
        options = {}
    else:
        options = dict(options)
    max_steps = options.pop("maxiter", None)
    # Keep the popped value: the docstring promises that a user-supplied `norm`
    # is respected, defaulting to `max_norm`.
    norm = options.pop("norm", max_norm)
    if len(options) != 0:
        raise ValueError(f"Unsupported options: {set(options.keys())}")

    def wrapped_fn(y, args):
        return fun(y, *args)

    solver = BFGS(rtol=tol, atol=tol, norm=norm)
    sol = minimise(wrapped_fn, solver, x0, args, max_steps=max_steps, throw=False)
    # Map the Optimistix result onto the status codes documented above:
    # 0 = converged, 1 = `maxiter` reached, 3 = any other failure.
    status = jnp.where(
        sol.result == RESULTS.successful,
        0,
        jnp.where(sol.result == RESULTS.nonlinear_max_steps_reached, 1, 3),
    )
    return OptimizeResults(
        x=sol.value,
        success=sol.result == RESULTS.successful,
        status=status,
        fun=sol.state.f_info.f,
        jac=sol.state.f_info.grad,
        hess_inv=sol.state.f_info.hessian_inv.as_matrix(),
        nfev=sol.stats["num_steps"],
        njev=sol.state.num_accepted_steps,
        # Old JAX implementation counts each full line search as an iteration.
        nit=sol.state.num_accepted_steps,
    )
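
For illustration, a hypothetical usage sketch of the entry point defined above, exercising the `maxiter` option and the status codes documented in the docstring (the `quadratic` objective and its values are made up):

```python
import jax.numpy as jnp
import optimistix as optx


def quadratic(x, a):
    # Convex objective with its minimum at x = a.
    return jnp.sum((x - a) ** 2)


a = jnp.array([1.0, -2.0, 3.0])
result = optx.compat.minimize(
    quadratic,
    jnp.zeros(3),
    args=(a,),
    method="bfgs",
    tol=1e-6,
    options={"maxiter": 100},
)
# status: 0 = converged, 1 = `maxiter` reached, 3 = other failure.
print(result.x, result.status, result.nit)
```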
61 changes: 61 additions & 0 deletions tests/test_compat.py
@@ -0,0 +1,61 @@
import jax.numpy as jnp
import jax.scipy.optimize as jsp_optimize
import pytest

import optimistix as optx

from .helpers import beale, tree_allclose


def _setup():
    def fun(x, arg1, arg2, arg3):
        a, b = x
        return beale((a, b), (arg1, arg2, arg3))

    args = (jnp.array(1.5), jnp.array(2.25), jnp.array(2.625))
    x0 = jnp.array([2.0, 0.0])
    return fun, args, x0


@pytest.mark.parametrize("method", ("bfgs", "BFGS"))
def test_minimize(method):
    fun, args, x0 = _setup()
    result = optx.compat.minimize(fun, x0, args, method=method)
    assert tree_allclose(result.x, jnp.array([3.0, 0.5]))
    assert tree_allclose(fun(result.x, *args), jnp.array(0.0))


def test_errors():
    fun, args, x0 = _setup()
    # remove test-time beartype wrapping
    minimize = optx.compat.minimize.__wrapped__.__wrapped__
    with pytest.raises(ValueError):
        minimize(fun, [2.0, 0.0], args, method="bfgs")  # pyright: ignore

    with pytest.raises(ValueError):
        minimize(fun, x0, args, method="foobar")

    with pytest.raises(TypeError):
        minimize(fun, x0, None, method="bfgs")  # pyright: ignore


def test_maxiter():
    fun, args, x0 = _setup()
    out = optx.compat.minimize(fun, x0, args, method="bfgs", options=dict(maxiter=2))
    assert not out.success
    assert out.status == 1


def test_compare():
    fun, args, x0 = _setup()
    jax_out = jsp_optimize.minimize(fun, x0, args, method="bfgs")
    optx_out = optx.compat.minimize(fun, x0, args, method="bfgs")
    assert type(jax_out).__name__ == type(optx_out).__name__
    assert tree_allclose(jax_out.x, optx_out.x)
    assert tree_allclose(jax_out.success, optx_out.success)
    assert tree_allclose(jax_out.status, optx_out.status)
    assert tree_allclose(jax_out.fun, optx_out.fun)
    assert tree_allclose(jax_out.jac, optx_out.jac, atol=1e-5, rtol=1e-5)
    assert tree_allclose(jax_out.hess_inv, optx_out.hess_inv, atol=1e-2, rtol=1e-2)
    # Don't compare number of iterations -- these may differ between the two
    # implementations.
