From a1743c14d834f7a180bd1715b667a33d384f7d54 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Mon, 18 Mar 2024 21:42:10 +0100
Subject: [PATCH] AbstractGaussNewton now supports reverse-autodiff for
 Jacobians.

In particular this is useful when the underlying function only supports
reverse-mode autodifferentiation due to a `jax.custom_vjp`, see
https://github.com/patrick-kidger/optimistix/issues/50
---
 optimistix/_solver/dogleg.py              |  7 ++++
 optimistix/_solver/gauss_newton.py        | 46 ++++++++++++++++++++---
 optimistix/_solver/levenberg_marquardt.py | 14 +++++++
 tests/test_least_squares.py               | 22 +++++++++++
 4 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py
index cec8247..ff92d2f 100644
--- a/optimistix/_solver/dogleg.py
+++ b/optimistix/_solver/dogleg.py
@@ -224,6 +224,13 @@ class Dogleg(AbstractGaussNewton[Y, Out, Aux], strict=True):
     The distinguishing feature of this algorithm is the "dog leg" shape of its descent
     path, in which it begins by moving in the steepest descent direction, and then
     switches to moving in the Newton direction.
+
+    Supports the following `options`:
+
+    - `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
+        Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
+        usually more efficient. Changing this can be useful when the target function has
+        a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
     """

     rtol: float
diff --git a/optimistix/_solver/gauss_newton.py b/optimistix/_solver/gauss_newton.py
index c4e8cdc..17d8a06 100644
--- a/optimistix/_solver/gauss_newton.py
+++ b/optimistix/_solver/gauss_newton.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from typing import Any, Generic, Optional, Union
+from typing import Any, Generic, Literal, Optional, Union

 import equinox as eqx
 import jax
@@ -164,10 +164,28 @@ class _GaussNewtonState(


 def _make_f_info(
-    fn: Callable[[Y, Args], tuple[Any, Aux]], y: Y, args: Args, tags: frozenset
+    fn: Callable[[Y, Args], tuple[Any, Aux]],
+    y: Y,
+    args: Args,
+    tags: frozenset,
+    jac: Literal["fwd", "bwd"],
 ) -> tuple[FunctionInfo.ResidualJac, Aux]:
-    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: fn(_y, args), y, has_aux=True)
-    jac_eval = lx.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: y), tags)
+    if jac == "fwd":
+        f_eval, lin_fn, aux_eval = jax.linearize(
+            lambda _y: fn(_y, args), y, has_aux=True
+        )
+        jac_eval = lx.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: y), tags)
+    elif jac == "bwd":
+        # Materialise the Jacobian in this case.
+        def _for_jacrev(_y):
+            f_eval, aux_eval = fn(_y, args)
+            return f_eval, (f_eval, aux_eval)
+
+        jac_pytree, (f_eval, aux_eval) = jax.jacrev(_for_jacrev, has_aux=True)(y)
+        output_structure = jax.eval_shape(lambda: f_eval)
+        jac_eval = lx.PyTreeLinearOperator(jac_pytree, output_structure, tags)
+    else:
+        raise ValueError("Only `jac='fwd'` or `jac='bwd'` are valid.")
     return FunctionInfo.ResidualJac(f_eval, jac_eval), aux_eval


@@ -187,6 +205,13 @@ class AbstractGaussNewton(
     - `descent`: `AbstractDescent`
     - `search`: `AbstractSearch`
     - `verbose`: `frozenset[str]`
+
+    Supports the following `options`:
+
+    - `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
+        Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
+        usually more efficient. Changing this can be useful when the target function has
+        a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
     """

     rtol: AbstractVar[float]
@@ -208,7 +233,8 @@ def init(
         aux_struct: PyTree[jax.ShapeDtypeStruct],
         tags: frozenset[object],
     ) -> _GaussNewtonState:
-        f_info_struct, _ = eqx.filter_eval_shape(_make_f_info, fn, y, args, tags)
+        jac = options.get("jac", "fwd")
+        f_info_struct, _ = eqx.filter_eval_shape(_make_f_info, fn, y, args, tags, jac)
         f_info = tree_full_like(f_info_struct, 0, allow_static=True)
         return _GaussNewtonState(
             first_step=jnp.array(True),
@@ -233,7 +259,8 @@ def step(
         state: _GaussNewtonState,
         tags: frozenset[object],
     ) -> tuple[Y, _GaussNewtonState, Aux]:
-        f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags)
+        jac = options.get("jac", "fwd")
+        f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags, jac)
         # We have a jaxpr in `f_info.jac`, which are compared by identity. Here we
         # arrange to use the same one so that downstream equality checks (e.g. in the
         # `filter_cond` below)
@@ -360,6 +387,13 @@ class GaussNewton(AbstractGaussNewton[Y, Out, Aux], strict=True):

     Note that regularised approaches like [`optimistix.LevenbergMarquardt`][] are
     usually preferred instead.
+
+    Supports the following `options`:
+
+    - `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
+        Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
+        usually more efficient. Changing this can be useful when the target function has
+        a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
     """

     rtol: float
diff --git a/optimistix/_solver/levenberg_marquardt.py b/optimistix/_solver/levenberg_marquardt.py
index 031a943..9e27c56 100644
--- a/optimistix/_solver/levenberg_marquardt.py
+++ b/optimistix/_solver/levenberg_marquardt.py
@@ -264,6 +264,13 @@ class LevenbergMarquardt(AbstractGaussNewton[Y, Out, Aux], strict=True):
     region around the current point.

     This is a good algorithm for many least squares problems.
+
+    Supports the following `options`:
+
+    - `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
+        Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
+        usually more efficient. Changing this can be useful when the target function has
+        a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
     """

     rtol: float
@@ -316,6 +323,13 @@ class IndirectLevenbergMarquardt(AbstractGaussNewton[Y, Out, Aux], strict=True):
     Generally speaking [`optimistix.LevenbergMarquardt`][] is preferred, as it performs
     nearly the same algorithm, without the computational overhead of an extra (scalar)
     nonlinear solve.
+
+    Supports the following `options`:
+
+    - `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
+        Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
+        usually more efficient. Changing this can be useful when the target function has
+        a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
     """

     rtol: float
diff --git a/tests/test_least_squares.py b/tests/test_least_squares.py
index d51de02..5a5ac0b 100644
--- a/tests/test_least_squares.py
+++ b/tests/test_least_squares.py
@@ -127,3 +127,25 @@ def least_squares(x, dynamic_args, *, adjoint):
 #     assert tree_allclose(out2, expected_out, atol=atol, rtol=rtol)
 #     assert tree_allclose(t_expected_out2, t_expected_out, atol=atol, rtol=rtol)
 #     assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)
+
+
+def test_gauss_newton_jacrev():
+    @jax.custom_vjp
+    def f(y, _):
+        return dict(bar=y["foo"] ** 2)
+
+    def f_fwd(y, _):
+        return f(y, None), jnp.sign(y["foo"])
+
+    def f_bwd(sign, g):
+        return dict(foo=sign * g["bar"]), None
+
+    f.defvjp(f_fwd, f_bwd)
+
+    solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
+    y0 = dict(foo=jnp.arange(3.0))
+    out = optx.least_squares(f, solver, y0, options=dict(jac="bwd"), max_steps=512)
+    assert tree_allclose(out.value, dict(foo=jnp.zeros(3)), rtol=1e-3, atol=1e-2)
+
+    with pytest.raises(TypeError, match="forward-mode autodiff"):
+        optx.least_squares(f, solver, y0, options=dict(jac="fwd"), max_steps=512)