Batch mode for A X = B with B a n x m matrix #115

Open
vboussange opened this issue Nov 19, 2024 · 4 comments
Labels
question User queries

Comments

@vboussange

vboussange commented Nov 19, 2024

Hey there,
Some native JAX solvers such as jnp.linalg.solve and jax.scipy.sparse.linalg.gmres nicely support batch mode, where the right-hand side of the system $A X = B$ is an $n \times m$ matrix. What is the best approach to efficiently reproduce this behaviour with lineax?

I made a benchmark using vmap and lineax, but this approach is 4x slower:

import jax.numpy as jnp
import jax.random as jr
from jax import vmap, jit
import lineax as lx
import timeit


N = 20
key = jr.PRNGKey(0)
A = jr.uniform(key, (N, N))
B = jnp.eye(N, N)

@jit
def linalg_solve():
    x = jnp.linalg.solve(A, B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

def lineax_solve(solver):
    operator = lx.MatrixLinearOperator(A)
    state = solver.init(operator, options={})
    def solve_single(b):
        x = lx.linear_solve(operator, b, solver=solver, state=state).value
        return x
    x = vmap(solve_single, in_axes=1, out_axes=1)(B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

def benchmark(method, func):
    time_taken = timeit.timeit(func, number=10) / 10
    _, error = func()
    print(f"{method} solve error: {error:2e}")
    print(f"{method} average time: {time_taken * 1e3:.2f} ms\n")

benchmark("linalg.solve", linalg_solve)
# linalg.solve solve error: 6.581411e-06
# linalg.solve average time: 0.03 ms

myfun = jit(lambda: lineax_solve(lx.LU()))
benchmark("lineax", myfun)
# lineax solve error: 6.581411e-06
# lineax average time: 0.13 ms

@patrick-kidger
Owner

So (a) I think you've made a few mistakes in the benchmarking, and (b) most Lineax/Optimistix/Diffrax routines finish with an optional check that throws a runtime error if things have gone wrong, and this adds a measurable amount of overhead on microbenchmarks such as this. The check can be disabled with throw=False.

So adjusting things a little, I get exactly comparable results between the two approaches.

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit

@jax.jit
def linalg_solve(A, B):
    x = jnp.linalg.solve(A, B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

@jax.jit
def lineax_solve(A, B):
    operator = lx.MatrixLinearOperator(A)
    def solve_single(b):
        x = lx.linear_solve(operator, b, throw=False).value
        return x
    x = jax.vmap(solve_single, in_axes=1, out_axes=1)(B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

def benchmark(method, func):
    times = timeit.repeat(func, number=1, repeat=10)
    _, error = func()
    print(f"{method} solve error: {error:2e}")
    print(f"{method} min time: {min(times)}\n")

N = 20
key = jr.PRNGKey(0)
A = jr.uniform(key, (N, N))
B = jnp.eye(N, N)

linalg_solve(A, B)
lineax_solve(A, B)

benchmark("linalg.solve", lambda: jax.block_until_ready(linalg_solve(A, B)))
benchmark("lineax", lambda: jax.block_until_ready(lineax_solve(A, B)))

# linalg.solve solve error: 7.080040e-06
# linalg.solve min time: 4.237500252202153e-05
#
# lineax solve error: 7.080040e-06
# lineax min time: 3.9375037886202335e-05

Notable changes here:

  • Using throw=False to disable Lineax's checking for success (and just silently returning NaNs if things go wrong).
  • Using jax.block_until_ready.
  • Compiling prior to evaluating, so that we don't measure differences in compilation speed.
  • Using min with repeat=10, rather than the mean, over the evaluation times. Benchmarking noise is one-sided (it only ever makes a run slower), so the minimum is usually the correct aggregation method for microbenchmarks.
  • Actually passing in inputs to the JIT'd region. What you've written here could in principle be entirely constant-folded by the compiler.
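
The measurement points in the list above can be folded into a small reusable helper. This is only a sketch in plain JAX: the name `min_time` is hypothetical (it is not part of JAX or Lineax), and `jnp.linalg.solve` stands in for whichever jitted function is being timed.

```python
import timeit

import jax
import jax.numpy as jnp
import jax.random as jr


def min_time(fn, *args, repeat=10):
    """Hypothetical helper: best-of-`repeat` wall time of fn(*args), in seconds."""
    # Warm-up call, so compilation is not part of the measurement.
    jax.block_until_ready(fn(*args))
    # Block on the result each time, so the asynchronously dispatched work is timed.
    times = timeit.repeat(
        lambda: jax.block_until_ready(fn(*args)), number=1, repeat=repeat
    )
    # Noise only ever adds time, so take the minimum rather than the mean.
    return min(times)


@jax.jit
def solve(A, B):
    # A and B are arguments of the jitted function, so they cannot be constant-folded.
    return jnp.linalg.solve(A, B)


A = jr.uniform(jr.PRNGKey(0), (20, 20))
B = jnp.eye(20)
best = min_time(solve, A, B)
print(f"min time: {best:.3e} s")
```

Passing the arrays through `*args` rather than closing over them keeps them as real inputs to the compiled function.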

FWIW I've also trimmed out the use of state and the explicit lineax.LU() solver, as the former is done already inside the solve and the latter is the default.

patrick-kidger added the question (User queries) label Nov 20, 2024

@vboussange
Author

vboussange commented Nov 21, 2024

Excellent, thanks for the details!

> FWIW I've also trimmed out the use of state and the explicit lineax.LU() solver, as the former is done already inside the solve and the latter is the default.

I am surprised that the vmap is not triggering multiple internal init?

@patrick-kidger
Owner

> I am surprised that the vmap is not triggering multiple internal init?

init is called only on the non-vmap'd input A, so it won't be vmap'd.
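
The same pattern can be sketched in plain JAX as an analogy. This is not Lineax's internal code: `jax.scipy.linalg.lu_factor` and `lu_solve` stand in for a solver's init and compute steps, and the diagonal shift on `A` is an assumption to keep the example well conditioned. The factorisation depends only on `A`, which is closed over rather than mapped, so `vmap` batches only the per-column solve.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.linalg import lu_factor, lu_solve


@jax.jit
def solve_columns(A, B):
    # Depends only on A, which vmap does not map over, so it stays unbatched.
    lu_and_piv = lu_factor(A)

    def solve_single(b):
        # Only this step sees the mapped column b.
        return lu_solve(lu_and_piv, b)

    # Map over the columns of B and write the solutions back as columns of X.
    return jax.vmap(solve_single, in_axes=1, out_axes=1)(B)


N = 20
# Assumption for this sketch: shift the diagonal so the matrix is well conditioned.
A = jr.uniform(jr.PRNGKey(0), (N, N)) + N * jnp.eye(N)
B = jnp.eye(N)
X = solve_columns(A, B)
print(jnp.linalg.norm(B - A @ X))
```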

@vboussange
Author

Of course, makes total sense. Thanks for the details!
