AbstractGaussNewton now supports reverse-autodiff for Jacobians. #51
In particular this is useful when the underlying function only supports reverse-mode autodifferentiation due to a `jax.custom_vjp`, see #50
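(For concreteness, a minimal sketch of the situation this targets: a residual function that can only be differentiated in reverse mode because of a `jax.custom_vjp`. The spelling `options=dict(jac="bwd")` is assumed here, by analogy with the `jac='fwd'` option that appears in the MWE further down this thread:)

```python
import jax
import jax.numpy as jnp
import optimistix

# jax.custom_vjp rules out JVPs (forward mode) through this function,
# so its Jacobian can only be computed in reverse mode.
@jax.custom_vjp
def damped(y):
    return jnp.exp(-y)

def damped_fwd(y):
    out = jnp.exp(-y)
    return out, out  # residual saved for the backward pass

def damped_bwd(out, g):
    return (-g * out,)

damped.defvjp(damped_fwd, damped_bwd)

def residuals(k, args):
    ts, data = args
    return damped(k * ts) - data

ts = jnp.linspace(0.0, 1.0, 20)
data = jnp.exp(-0.5 * ts)  # synthetic measurements with true k = 0.5

solver = optimistix.LevenbergMarquardt(rtol=1e-6, atol=1e-9)
sol = optimistix.least_squares(
    residuals, solver, 1.0, args=(ts, data), options=dict(jac="bwd")
)
print(sol.value)  # approximately 0.5
```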
Hi Patrick, I tried this branch, and it did not work for my use case (parameter estimation for an ODE). I tried both options, and it failed with both. I have an MWE if that would be useful for you. During the install, this is the commit it resolved to:

Hope you had a nice weekend! Johanna
Never mind, it works! I checked out #50 and realized that I need to pass the option discussed there.
Awesome, I'm glad to hear it! I've just merged this in, so this will appear in the next release of Optimistix. :)
Hi! Not sure where else to post this, but I wanted to note it somewhere: on my problem, I get a dramatic performance reduction when using reverse-mode autodiff inside the least-squares solve. Specifically:

The documentation does note that forward mode is usually more efficient, so this might just be further confirmation of that observation :) The solver I used is
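(As a general JAX point in support of that: `jax.jacfwd` builds a Jacobian with one JVP per input, while `jax.jacrev` needs one VJP per output. A least-squares residual typically maps a few parameters to many residuals, so forward mode gets the Jacobian in far fewer passes. A quick illustration:)

```python
import jax
import jax.numpy as jnp

def residuals(k):
    # 1 parameter in, 50 residuals out: a "tall" Jacobian.
    ts = jnp.linspace(0.0, 10.0, 50)
    return jnp.exp(-k * ts)

jac_fwd = jax.jacfwd(residuals)(0.5)  # one JVP pass
jac_rev = jax.jacrev(residuals)(0.5)  # fifty VJP passes
print(jnp.allclose(jac_fwd, jac_rev))  # True: same Jacobian, different cost
```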
Hey there! Ah, indeed this might have been the case. It was hard to guess which would be the greater overhead. FWIW, this is motivating me to consider adding a forward-mode-specific "adjoint". (I'd been holding off on this in the hopes that support could be added to
Okay, completely untested, but something like the following should probably work:

```python
import functools

import diffrax
import equinox as eqx

_is_none = lambda x: x is None


class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        # Run both levels of the solve loop as plain `lax.while_loop`s, which
        # support forward-mode autodiff (unlike the reverse-mode machinery
        # used by the default adjoint).
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax")
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state
```

Use this by passing it as the `adjoint` argument to `diffrax.diffeqsolve`. Let me know how this goes!
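(For concreteness, a hypothetical usage of the sketch above, untested like the original; the toy vector field is borrowed from the MWE later in the thread:)

```python
import diffrax
import jax

def dydt(t, y, k):
    return -k * y

term = diffrax.ODETerm(dydt)

def solve(k):
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=10.0,
        args=k,
        adjoint=ForwardMode(),  # the adjoint sketched above
    )
    return sol.ys

# Forward-mode differentiation through the whole solve:
print(jax.jacfwd(solve)(0.5))
```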
Awesome, thank you! I'm happy to give it a proper go tomorrow. So far I am getting this error:

```
.../site-packages/diffrax/_integrate.py, in loop()
    456     return new_state
--> 458 final_state = outer_while_loop(
    459     cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
    460 )
    462 def _save_t1(subsaveat, save_state):

.../contextlib.py, in inner()
     80 with self._recreate_cm():
---> 81     return func(*args, **kwds)

.../site-packages/equinox/internal/_loop/loop.py, in while_loop()
    102 del cond_fun, body_fun, init_val
--> 103 _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
    104 return final_val

JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for
lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using
fori_loop with static start/stop.
```

I'm a little unsure where the reverse-mode differentiation is coming from here (this is with

BTW, I am assuming that
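(For what it's worth, the ValueError itself is a general JAX restriction rather than anything diffrax-specific: a `lax.while_loop` whose trip count depends on the inputs supports forward-mode but not reverse-mode differentiation. A standalone illustration:)

```python
import jax
from jax import lax

def f(x):
    # Keep doubling x until it exceeds 100; the number of iterations
    # depends on x, i.e. the loop has a dynamic trip count.
    return lax.while_loop(lambda v: v < 100.0, lambda v: 2.0 * v, x)

print(jax.jacfwd(f)(3.0))  # forward mode: works
jax.grad(f)(3.0)  # reverse mode: raises the ValueError quoted above
```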
Ok, mini update: I haven't been able to dig into it any further, but I did get better benchmarks on my problem. It looks like forward-mode differentiation on

Indeed it does! (Granted, possibly not by that much. Depends on your problem.) Do you have a quick MWE demonstrating the above crash?

Ah even faster ODEs! Nice, that would be cool.
Here comes the MWE! The adjoint works fine in the forward solve itself; it is the least-squares fit that fails.

```python
import jax.numpy as jnp
import equinox as eqx
import diffrax
import optimistix
import functools

_is_none = lambda x: x is None  # This function is needed in the forward adjoint


class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax")
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state


class ToyModel(eqx.Module):
    """Toy model that provides a simple interface to generate an ODE solution,
    subject to its parameters.
    """

    _term: diffrax.ODETerm
    _t0: float
    _t1: float
    _dt0: float
    _y0: float
    _saveat: diffrax.SaveAt
    _solver: diffrax.AbstractERK
    _adjoint: diffrax.AbstractAdjoint

    def __init__(self, ode_model, initial_state, times, solver, adjoint):
        self._term = diffrax.ODETerm(ode_model)
        self._y0 = initial_state
        self._t0 = times[0]
        self._t1 = times[-1]
        self._dt0 = 0.01
        self._saveat = diffrax.SaveAt(ts=times)
        self._solver = solver
        self._adjoint = adjoint

    def __call__(self, param):
        sol = diffrax.diffeqsolve(
            self._term,
            self._solver,
            self._t0, self._t1, self._dt0, self._y0,
            args=param,
            saveat=self._saveat,
            adjoint=self._adjoint,
        )
        return sol.ys


def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac='fwd')):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optimistix.least_squares(
        residuals,
        solver,
        initial_guess,
        args=(model, data),
        options=solver_options,
    )
    return sol


# Create the model
def dydt(t, y, k):  # Toy ODE
    return -k * y


t = jnp.linspace(0, 10, 50)
y0 = 10.0
model = ToyModel(dydt, y0, t, diffrax.Tsit5(), ForwardMode())

# Solve ODE
k = 0.5  # True value
ode_solution = model(k)  # This runs without issue

# Now try solving for the parameters
k0 = 0.1
solver = optimistix.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
lm_solution = estimate_parameters(k0, model, ode_solution, solver)  # This fails
```
This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment). Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!
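(To give a flavor of the kind of change described, though not the actual diff: wherever only a Jacobian-vector product is needed, `jax.jvp` computes it purely in forward mode, whereas obtaining the same quantity via reverse mode means materialising the Jacobian with `jax.jacrev`, built from VJPs, which is exactly what the `lax.while_loop` above forbids. A sketch:)

```python
import jax
import jax.numpy as jnp

def f(y):
    return jnp.sin(y) * y

y = jnp.arange(1.0, 4.0)
v = jnp.ones_like(y)

# Forward mode: a single JVP, no reverse pass anywhere.
_, tangent = jax.jvp(f, (y,), (v,))

# The same quantity via reverse mode: materialise the Jacobian row by row
# (one VJP per output) and then contract it with v.
tangent_via_rev = jax.jacrev(f)(y) @ v

print(jnp.allclose(tangent, tangent_via_rev))  # True
```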
Can you try #61? Hopefully that should fix things :) (It's also revealed an unrelated bug with complex numbers, but I don't think that should affect you.)
Hi Patrick, amazing, thank you for the quick fix!
Whoops, my bad, I forgot to control for whether my laptop was plugged in :) Using

Still pretty nice! I'll keep using the ForwardMode above.