Query: How does Lineax write the JVP rule for linear solvers? #134

Open
SNMS95 opened this issue Feb 14, 2025 · 2 comments
Labels
question User queries

Comments

@SNMS95

SNMS95 commented Feb 14, 2025

Hey guys,

Thanks for providing a nice ecosystem (Equinox, Optimistix, Lineax, ...) for complex scientific computing.
I am learning to write custom AD rules for linear solvers, specifically for solvers from external libraries.
I know that a JVP rule is simpler to write and, at the moment, more useful, since JAX does not expose control over transpose rules (jax-ml/jax#9129, jax-ml/jax#17840). A JVP rule gives both forward- and reverse-mode AD, because JAX derives reverse mode by transposing the linear tangent computation.

The JVP rule would look like this:

```python
import jax

# `some_solver(A, b)` is a placeholder for the external solver being wrapped.

@jax.custom_jvp
def solve(A, b):
    # Solve A x = b
    x = some_solver(A, b)
    return x

@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    x = solve(A, b)
    # Differentiating A x = b: A_dot x + A x_dot = b_dot
    # => x_dot = A^{-1} (b_dot - A_dot x)
    x_dot = some_solver(A, b_dot - A_dot @ x)
    return x, x_dot
```

The problem is that this fails when the solver is external, even if it is wrapped in `pure_callback`.
I discussed this in JAX (jax-ml/jax#25528), but the conclusion was that it would require a full-blown primitive.
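One way to see why this breaks, at least for reverse mode, is the minimal sketch below. It assumes NumPy's `np.linalg.solve` behind `jax.pure_callback` as a stand-in for an external library; `make_solve`, `_numpy_solve` and `callback_solver` are hypothetical names. The same JVP rule is wrapped around a traceable JAX solver and around a callback. Forward mode works for the callback, but reverse mode needs JAX to transpose the tangent computation, and it cannot transpose through an opaque callback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_solve(some_solver):
    # Wrap `some_solver(A, b)` with the JVP rule from the question.
    @jax.custom_jvp
    def solve(A, b):
        return some_solver(A, b)

    @solve.defjvp
    def solve_jvp(primals, tangents):
        A, b = primals
        A_dot, b_dot = tangents
        x = solve(A, b)
        x_dot = some_solver(A, b_dot - A_dot @ x)
        return x, x_dot

    return solve


def _numpy_solve(A, b):
    # Runs outside of JAX; hypothetical stand-in for an external library.
    A, b = np.asarray(A), np.asarray(b)
    return np.linalg.solve(A, b).astype(b.dtype)


def callback_solver(A, b):
    result_shape = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(_numpy_solve, result_shape, A, b)


solve_traceable = make_solve(jnp.linalg.solve)  # solver written in JAX
solve_external = make_solve(callback_solver)  # solver hidden behind a callback

A = jnp.array([[4.0, 1.0, 0.0], [2.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])
ones = jnp.ones_like(b)

# Traceable solver: reverse mode comes from transposing the JVP rule.
grad_b = jax.grad(lambda b: solve_traceable(A, b).sum())(b)  # = A^{-T} @ ones

# External solver, forward mode: the callback runs on concrete tangents.
_, x_dot = jax.jvp(solve_external, (A, b), (jnp.zeros_like(A), ones))

# External solver, reverse mode: the tangent computation contains a
# pure_callback, which JAX cannot transpose, so this raises.
try:
    jax.grad(lambda b: solve_external(A, b).sum())(b)
    external_reverse_ok = True
except Exception:  # exact exception type depends on the JAX version
    external_reverse_ok = False
```

The callback is only ever *evaluated* in forward mode, which is why `jax.jvp` succeeds while `jax.grad` does not.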

But if my understanding is correct, Lineax does this somehow.
Can you give some insight into how this was achieved?

P.S. My end goal is to also add sparsity into the mix.

@patrick-kidger
Owner

Lineax's JVP rule is defined here. We actually do define a custom primitive, rather than using jax.custom_jvp, because we need to define a custom transposition rule. :)
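Roughly, a transposition rule has to implement the transpose of the linear tangent map from the question, x_dot = A^{-1}(b_dot - A_dot x). Pairing it with an output cotangent x_bar gives b_bar = A^{-T} x_bar and A_bar = -b_bar x^T, so reverse mode needs a solve with A^T. With `jax.custom_jvp`, JAX has to derive this transpose itself from the traced tangent computation, which it cannot do through an opaque external solver. The sketch below checks these formulas numerically against `jax.vjp` of `jnp.linalg.solve`; `solve_transpose` is a hypothetical helper, not Lineax's actual code.

```python
import jax
import jax.numpy as jnp


def solve_transpose(A, x, x_bar, some_solver=jnp.linalg.solve):
    # Transpose of the tangent map (A_dot, b_dot) -> A^{-1} (b_dot - A_dot x).
    # `some_solver` is a placeholder for the wrapped solver; note that it is
    # applied to A.T, i.e. reverse mode needs a *transposed* solve.
    b_bar = some_solver(A.T, x_bar)  # b_bar = A^{-T} x_bar
    A_bar = -jnp.outer(b_bar, x)  # A_bar = -b_bar x^T
    return A_bar, b_bar


A = jnp.array([[4.0, 1.0, 0.0], [2.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])
x_bar = jnp.array([0.5, -1.0, 2.0])  # an arbitrary output cotangent

# Reference: JAX's own VJP of a dense solve.
x, vjp_fn = jax.vjp(jnp.linalg.solve, A, b)
A_bar_ref, b_bar_ref = vjp_fn(x_bar)

# Hand-written transpose rule.
A_bar, b_bar = solve_transpose(A, x, x_bar)
```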

Fortunately for your case, this has all been handled for you in a solver-agnostic way. You should be able to just implement a lineax.AbstractLinearSolver and then things will work for you from there.
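Implementing a `lineax.AbstractLinearSolver` is the route suggested here. For comparison only, core JAX has a similar split in `jax.lax.custom_linear_solve`: you supply a forward `solve` and a `transpose_solve`, and JAX builds the JVP and transpose rules around them, so neither solver is ever differentiated. The minimal sketch below assumes NumPy behind `jax.pure_callback` as the external solver (`_numpy_solve`, `_external_solve` and `solve_with_external` are hypothetical names). It is not how Lineax is implemented, and it only handles a dense `A`.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def _numpy_solve(A, b, transpose):
    # Hypothetical stand-in for an external solver; runs outside of JAX.
    A, b = np.asarray(A), np.asarray(b)
    M = A.T if transpose else A
    return np.linalg.solve(M, b).astype(b.dtype)


def _external_solve(A, b, transpose=False):
    result_shape = jax.ShapeDtypeStruct(b.shape, b.dtype)
    callback = partial(_numpy_solve, transpose=transpose)
    return jax.pure_callback(callback, result_shape, A, b)


def solve_with_external(A, b):
    # The callbacks only see A as a constant; derivatives w.r.t. A flow
    # through the traceable `matvec` instead.
    A_const = jax.lax.stop_gradient(A)
    return jax.lax.custom_linear_solve(
        lambda x: A @ x,  # matvec
        b,
        solve=lambda _matvec, rhs: _external_solve(A_const, rhs),
        transpose_solve=lambda _vecmat, rhs: _external_solve(
            A_const, rhs, transpose=True
        ),
    )


def loss(A, b, solver):
    return jnp.sum(solver(A, b) ** 2)


A = jnp.array([[4.0, 1.0, 0.0], [2.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

# Reverse mode through the external solver vs. a pure-JAX reference.
grads = jax.grad(partial(loss, solver=solve_with_external), argnums=(0, 1))(A, b)
grads_ref = jax.grad(partial(loss, solver=jnp.linalg.solve), argnums=(0, 1))(A, b)

# Forward mode works as well.
tangents = (jnp.ones_like(A), jnp.ones_like(b))
_, x_dot = jax.jvp(solve_with_external, (A, b), tangents)
_, x_dot_ref = jax.jvp(jnp.linalg.solve, (A, b), tangents)
```

The design choice mirrors `jnp.linalg.solve`, which also applies `stop_gradient` to the matrix it factorizes and lets `matvec` carry the dependence on `A`.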

patrick-kidger added the question (User queries) label Feb 15, 2025
@SNMS95
Author

SNMS95 commented Feb 17, 2025

Thanks Patrick.
I will attempt that in this thread (with an MWE) and then close the issue, so that it can be useful for others as well.
