Thanks for providing a nice ecosystem (equinox, optimistix, lineax, ...) that enables complex scientific computations.
I am in the process of learning to write custom AD rules for linear solvers, specifically for solvers from external libraries.
I know that a JVP rule is simpler to write and currently more useful, since it gives both forward- and reverse-mode AD (JAX has not exposed control over transpose rules; see jax-ml/jax#9129 and jax-ml/jax#17840).
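The JVP rule would look like this (`some_solver` is a placeholder for the external solver):

```python
import jax
import jax.numpy as jnp

def some_solver(A, b):
    # Placeholder solver; jnp.linalg.solve is used only so the snippet runs.
    return jnp.linalg.solve(A, b)

@jax.custom_jvp
def solve(A, b):
    # Solve Ax = b
    x = some_solver(A, b)
    return x

@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    x = solve(A, b)
    # Differentiating Ax = b gives A_dot x + A x_dot = b_dot, so
    # x_dot = A^{-1} (b_dot - A_dot x)
    x_dot = some_solver(A, b_dot - A_dot @ x)
    return x, x_dot
```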
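Written out, the JVP from the comments above, and the reverse-mode rule obtained by transposing it for a cotangent $\bar x$, are as follows. The reverse rule needs a second solve, with $A^\top$, which is exactly the operation JAX has to derive by transposing the tangent computation:

```latex
% Forward mode: differentiate A x = b
\dot A\, x + A\, \dot x = \dot b
\quad\Longrightarrow\quad
\dot x = A^{-1}\left(\dot b - \dot A\, x\right)

% Reverse mode: transpose the linear map (\dot A, \dot b) \mapsto \dot x
\bar b = A^{-\top}\, \bar x,
\qquad
\bar A = -\,\bar b\, x^{\top}
```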
The issue is that this fails when the solver is external, even if we wrap it in `pure_callback`.
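A minimal reproduction of this failure, assuming NumPy's `np.linalg.solve` stands in for the external library (`external_solve` and `some_solver` are illustrative names, not part of any library). Forward mode only evaluates the JVP rule, so it works; reverse mode requires JAX to transpose the tangent computation, and `pure_callback` has no transpose rule:

```python
import jax
import jax.numpy as jnp
import numpy as np

def external_solve(A, b):
    # Hypothetical stand-in for a solver from an external (non-JAX) library.
    return np.asarray(np.linalg.solve(A, b), dtype=b.dtype)

def some_solver(A, b):
    # Call the external solver from JAX; the result has b's shape and dtype.
    out_type = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(external_solve, out_type, A, b)

@jax.custom_jvp
def solve(A, b):
    return some_solver(A, b)

@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    x = solve(A, b)
    # The tangent is produced by another opaque callback applied to tangent values.
    x_dot = some_solver(A, b_dot - A_dot @ x)
    return x, x_dot

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 0.0])

# Forward mode works: JAX simply evaluates the JVP rule.
x, x_dot = jax.jvp(solve, (A, b), (jnp.zeros_like(A), jnp.ones_like(b)))

# Reverse mode fails: transposing x_dot's computation hits pure_callback,
# which has no transpose rule.
try:
    jax.grad(lambda A, b: solve(A, b).sum(), argnums=(0, 1))(A, b)
    reverse_ok = True
except Exception as err:
    reverse_ok = False
    print(type(err).__name__, err)
```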
I discussed this in the JAX repo (jax-ml/jax#25528), and the conclusion was that it would require a full-blown primitive.
If I understand correctly, though, Lineax manages this somehow.
Could you give some insight into how this was achieved?
P.S. My end goal is to also add sparsity into the mix.
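For comparison, one route that avoids writing a primitive yourself (this is a swapped-in technique, *not* what Lineax does): JAX's `jax.lax.custom_linear_solve` is already a primitive with JVP and transpose rules, and it only ever *evaluates* the `solve` and `transpose_solve` functions it is given, never differentiates them. The sketch below assumes the external library can solve both $Ax=b$ and $A^\top y=c$, again uses NumPy as a stand-in, and uses hypothetical helper names:

```python
import jax
import jax.numpy as jnp
import numpy as np

def external_solve(A, b):
    # Hypothetical stand-in for the external library's solver (NumPy here).
    return np.asarray(np.linalg.solve(A, b), dtype=b.dtype)

def callback_solve(A, b):
    # Opaque call out of JAX; JAX never needs to differentiate through it.
    out_type = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(external_solve, out_type, A, b)

def solve(A, b):
    # matvec is plain JAX code: differentiating it w.r.t. A supplies the
    # -A_dot @ x term of the JVP and the matching cotangent for A.
    def matvec(x):
        return A @ x

    # Evaluated to obtain x and x_dot; never differentiated itself.
    def solve_fn(_matvec, rhs):
        return callback_solve(A, rhs)

    # Evaluated by the transpose rule in reverse mode: solves A^T y = rhs.
    def transpose_solve_fn(_vecmat, rhs):
        return callback_solve(A.T, rhs)

    return jax.lax.custom_linear_solve(matvec, b, solve_fn, transpose_solve_fn)

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 0.0])
A_dot = jnp.array([[0.1, 0.0], [0.0, -0.2]])
b_dot = jnp.array([0.5, 1.0])

# Forward mode, compared against jnp.linalg.solve.
x, x_dot = jax.jvp(solve, (A, b), (A_dot, b_dot))
_, x_dot_ref = jax.jvp(jnp.linalg.solve, (A, b), (A_dot, b_dot))

# Reverse mode, compared against jnp.linalg.solve.
loss = lambda A, b: jnp.sum(solve(A, b) ** 2)
ref_loss = lambda A, b: jnp.sum(jnp.linalg.solve(A, b) ** 2)
gA, gb = jax.grad(loss, argnums=(0, 1))(A, b)
rA, rb = jax.grad(ref_loss, argnums=(0, 1))(A, b)
```

The design trade-off is that `matvec` must be expressible in JAX so that its derivative with respect to `A` is available; the external library only has to provide the two solves.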
Lineax's JVP rule is defined here. We actually do define a custom primitive, rather than using jax.custom_jvp, because we need to define a custom transposition rule. :)
Fortunately for your case, this has all been handled for you in a solver-agnostic way. You should be able to just implement a lineax.AbstractLinearSolver and then things will work for you from there.