Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AbstractGaussNewton now supports reverse-autodiff for Jacobians. #51

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions optimistix/_solver/dogleg.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,13 @@ class Dogleg(AbstractGaussNewton[Y, Out, Aux], strict=True):
The distinguishing feature of this algorithm is the "dog leg" shape of its descent
path, in which it begins by moving in the steepest descent direction, and then
switches to moving in the Newton direction.

Supports the following `options`:

- `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
usually more efficient. Changing this can be useful when the target function has
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: float
Expand Down
46 changes: 40 additions & 6 deletions optimistix/_solver/gauss_newton.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any, Generic, Optional, Union
from typing import Any, Generic, Literal, Optional, Union

import equinox as eqx
import jax
Expand Down Expand Up @@ -164,10 +164,28 @@ class _GaussNewtonState(


def _make_f_info(
fn: Callable[[Y, Args], tuple[Any, Aux]], y: Y, args: Args, tags: frozenset
fn: Callable[[Y, Args], tuple[Any, Aux]],
y: Y,
args: Args,
tags: frozenset,
jac: Literal["fwd", "bwd"],
) -> tuple[FunctionInfo.ResidualJac, Aux]:
f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: fn(_y, args), y, has_aux=True)
jac_eval = lx.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: y), tags)
if jac == "fwd":
f_eval, lin_fn, aux_eval = jax.linearize(
lambda _y: fn(_y, args), y, has_aux=True
)
jac_eval = lx.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: y), tags)
elif jac == "bwd":
# Materialise the Jacobian in this case.
def _for_jacrev(_y):
f_eval, aux_eval = fn(_y, args)
return f_eval, (f_eval, aux_eval)

jac_pytree, (f_eval, aux_eval) = jax.jacrev(_for_jacrev, has_aux=True)(y)
output_structure = jax.eval_shape(lambda: f_eval)
jac_eval = lx.PyTreeLinearOperator(jac_pytree, output_structure, tags)
else:
raise ValueError("Only `jac='fwd'` or `jac='bwd'` are valid.")
return FunctionInfo.ResidualJac(f_eval, jac_eval), aux_eval


Expand All @@ -187,6 +205,13 @@ class AbstractGaussNewton(
- `descent`: `AbstractDescent`
- `search`: `AbstractSearch`
- `verbose`: `frozenset[str]`

Supports the following `options`:

- `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
usually more efficient. Changing this can be useful when the target function has
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: AbstractVar[float]
Expand All @@ -208,7 +233,8 @@ def init(
aux_struct: PyTree[jax.ShapeDtypeStruct],
tags: frozenset[object],
) -> _GaussNewtonState:
f_info_struct, _ = eqx.filter_eval_shape(_make_f_info, fn, y, args, tags)
jac = options.get("jac", "fwd")
f_info_struct, _ = eqx.filter_eval_shape(_make_f_info, fn, y, args, tags, jac)
f_info = tree_full_like(f_info_struct, 0, allow_static=True)
return _GaussNewtonState(
first_step=jnp.array(True),
Expand All @@ -233,7 +259,8 @@ def step(
state: _GaussNewtonState,
tags: frozenset[object],
) -> tuple[Y, _GaussNewtonState, Aux]:
f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags)
jac = options.get("jac", "fwd")
f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags, jac)
# We have a jaxpr in `f_info.jac`, which are compared by identity. Here we
# arrange to use the same one so that downstream equality checks (e.g. in the
# `filter_cond` below)
Expand Down Expand Up @@ -360,6 +387,13 @@ class GaussNewton(AbstractGaussNewton[Y, Out, Aux], strict=True):

Note that regularised approaches like [`optimistix.LevenbergMarquardt`][] are
usually preferred instead.

Supports the following `options`:

- `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
usually more efficient. Changing this can be useful when the target function has
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: float
Expand Down
14 changes: 14 additions & 0 deletions optimistix/_solver/levenberg_marquardt.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ class LevenbergMarquardt(AbstractGaussNewton[Y, Out, Aux], strict=True):
region around the current point.

This is a good algorithm for many least squares problems.

Supports the following `options`:

- `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
usually more efficient. Changing this can be useful when the target function has
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: float
Expand Down Expand Up @@ -316,6 +323,13 @@ class IndirectLevenbergMarquardt(AbstractGaussNewton[Y, Out, Aux], strict=True):
Generally speaking [`optimistix.LevenbergMarquardt`][] is preferred, as it performs
nearly the same algorithm, without the computational overhead of an extra (scalar)
nonlinear solve.

Supports the following `options`:

- `jac`: whether to use forward- or reverse-mode autodifferentiation to compute the
Jacobian. Can be either `"fwd"` or `"bwd"`. Defaults to `"fwd"`, which is
usually more efficient. Changing this can be useful when the target function has
a `jax.custom_vjp`, and so does not support forward-mode autodifferentiation.
"""

rtol: float
Expand Down
22 changes: 22 additions & 0 deletions tests/test_least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,25 @@ def least_squares(x, dynamic_args, *, adjoint):
# assert tree_allclose(out2, expected_out, atol=atol, rtol=rtol)
# assert tree_allclose(t_expected_out2, t_expected_out, atol=atol, rtol=rtol)
# assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)


def test_gauss_newton_jacrev():
@jax.custom_vjp
def f(y, _):
return dict(bar=y["foo"] ** 2)

def f_fwd(y, _):
return f(y, None), jnp.sign(y["foo"])

def f_bwd(sign, g):
return dict(foo=sign * g["bar"]), None

f.defvjp(f_fwd, f_bwd)

solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
y0 = dict(foo=jnp.arange(3.0))
out = optx.least_squares(f, solver, y0, options=dict(jac="bwd"), max_steps=512)
assert tree_allclose(out.value, dict(foo=jnp.zeros(3)), rtol=1e-3, atol=1e-2)

with pytest.raises(TypeError, match="forward-mode autodiff"):
optx.least_squares(f, solver, y0, options=dict(jac="fwd"), max_steps=512)
Loading