Batch mode for A X = B with B an n x m matrix #115
So (a) I think you've made a few mistakes in the benchmarking, and (b) most Lineax/Optimistix/Diffrax routines finish with an option to throw a runtime error if things have gone wrong. This adds a measurable amount of overhead on microbenchmarks such as this one. The check can be disabled; the code below does so by passing `throw=False` to `lx.linear_solve`. Adjusting things a little, I get exactly comparable results between the two approaches:

```python
import timeit

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx


@jax.jit
def linalg_solve(A, B):
    # Native batch mode: B is an (N, m) matrix of right-hand sides.
    x = jnp.linalg.solve(A, B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error


@jax.jit
def lineax_solve(A, B):
    operator = lx.MatrixLinearOperator(A)

    def solve_single(b):
        # throw=False skips the runtime error check (and its overhead).
        x = lx.linear_solve(operator, b, throw=False).value
        return x

    # Solve for each column of B, vectorised over the columns with vmap.
    x = jax.vmap(solve_single, in_axes=1, out_axes=1)(B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error


def benchmark(method, func):
    times = timeit.repeat(func, number=1, repeat=10)
    _, error = func()
    print(f"{method} solve error: {error:2e}")
    print(f"{method} min time: {min(times)}\n")


N = 20
key = jr.PRNGKey(0)
A = jr.uniform(key, (N, N))
B = jnp.eye(N, N)

# Warm-up calls, so that JIT compilation is not included in the timings.
linalg_solve(A, B)
lineax_solve(A, B)

benchmark("linalg.solve", lambda: jax.block_until_ready(linalg_solve(A, B)))
benchmark("lineax", lambda: jax.block_until_ready(lineax_solve(A, B)))
# linalg.solve solve error: 7.080040e-06
# linalg.solve min time: 4.237500252202153e-05
#
# lineax solve error: 7.080040e-06
# lineax min time: 3.9375037886202335e-05
```

Notable changes here:
FWIW, I've also trimmed out the use of …
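Two general pitfalls in JAX microbenchmarks can distort this kind of measurement. The first is timing a call that still includes tracing and compilation. The second is timing only the asynchronous dispatch rather than the computation itself. Below is a minimal sketch that avoids both by compiling ahead of time with `jax.jit(...).lower(...).compile()` and waiting on each result with `jax.block_until_ready`. The helper names `solve_and_residual` and `time_compiled` are hypothetical, and this is one possible approach rather than the exact change made above.

```python
import timeit

import jax
import jax.numpy as jnp
import jax.random as jr


def solve_and_residual(A, B):
    # Same computation as `linalg_solve` above, left un-jitted here.
    x = jnp.linalg.solve(A, B)
    return x, jnp.linalg.norm(B - A @ x)


def time_compiled(fn, *args, repeat=10):
    # Trace and compile ahead of time, so compilation is excluded from timings.
    compiled = jax.jit(fn).lower(*args).compile()
    # JAX dispatches work asynchronously; block_until_ready makes each timed
    # call wait for the actual result rather than just the dispatch.
    times = timeit.repeat(
        lambda: jax.block_until_ready(compiled(*args)), number=1, repeat=repeat
    )
    return min(times)


N = 20
A = jr.uniform(jr.PRNGKey(0), (N, N))
B = jnp.eye(N)
best_time = time_compiled(solve_and_residual, A, B)
print(f"linalg.solve min time: {best_time:.2e} s")
```

Taking the minimum over repeats, as the `benchmark` helper above also does, reduces the influence of one-off noise such as other processes competing for the CPU.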
Excellent, thanks for the details!
I am surprised that the …
Of course, makes total sense. Thanks for the details!
Hey there,

Some native JAX solvers, such as `jnp.linalg.solve` and `jax.scipy.sparse.linalg.gmres`, nicely support batch mode, where the right-hand side of the system $A X = B$ is an $n \times m$ matrix. What is the best approach to efficiently reproduce this behaviour with `lineax`?

I made a benchmark using `vmap` and `lineax`, but this approach is 4x slower:
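Independently of that benchmark, batch mode for a dense square $A$ can be implemented by factorising $A$ once and reusing the factorisation for every column of $B$. Below is a minimal plain-JAX sketch of this idea, using `jax.scipy.linalg.lu_factor` and `lu_solve` as the factorisation. It illustrates the technique only and is not a claim about how Lineax or `jnp.linalg.solve` work internally. The function names `batched_solve` and `columnwise_solve` are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.linalg import lu_factor, lu_solve


@jax.jit
def batched_solve(A, B):
    # Factorise A once: (lu, piv) packs L and U plus the row pivots.
    lu_and_piv = lu_factor(A)
    # lu_solve accepts a matrix right-hand side, solving all m columns together.
    return lu_solve(lu_and_piv, B)


@jax.jit
def columnwise_solve(A, B):
    # The factorisation is computed outside vmap, so all columns share it.
    lu_and_piv = lu_factor(A)

    def solve_column(b):
        return lu_solve(lu_and_piv, b)

    # Map over the columns of B (axis 1) and stack the solutions as columns.
    return jax.vmap(solve_column, in_axes=1, out_axes=1)(B)


N, m = 20, 5
A = jr.uniform(jr.PRNGKey(0), (N, N))
B = jr.normal(jr.PRNGKey(1), (N, m))

X_batched = batched_solve(A, B)
X_columns = columnwise_solve(A, B)
print("residual:", jnp.linalg.norm(B - A @ X_batched))
print("max difference:", jnp.max(jnp.abs(X_batched - X_columns)))
```

Both functions compute the factorisation a single time. The `vmap` version is analogous in structure to the `lineax_solve` function earlier in the thread, which vmaps over the columns of `B` with the operator held fixed.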