Batched JAX linear solves bugged for large batches #79

Open

ma-gilles opened this issue Jan 24, 2024 · 5 comments
Labels: bug (Something isn't working)

@ma-gilles

Hello,

I opened a similar issue on the main JAX repository (jax-ml/jax#19431), but I thought it might get more attention here.

The batched JAX linear solves seem to be bugged for large batches on GPU, even when the problem still fits comfortably in GPU memory.
In short, if you try to solve a large batch of linear systems, the JAX LU/Cholesky solvers will sometimes return NaNs or other incorrect results without throwing an error or warning. The SVD-based solve seems to work better, though it also fails once you get close to filling the full GPU memory. The QR-based solve is, strangely, too slow for me to test at large batch sizes. The Lineax solvers show the same behavior, although they do throw an error upon seeing NaNs.

Below is a test and its output, in which solving Ax = b, with A the identity and b all ones, returns NaNs. I am curious whether someone can reproduce this behavior and has any ideas on what to do.

Thank you for making this nice library!
Best,
Marc

import jax
import jax.numpy as jnp
import lineax as lx

device = jax.local_devices()[0]
print('on device:', device)

m = 10

# Batched solvers built by vmapping lineax.linear_solve over the leading axis.
batched_solve_lu = jax.vmap(lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()).value)
batched_solve_SVD = jax.vmap(lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.SVD()).value)

solve_fns = [jax.scipy.linalg.solve, batched_solve_SVD, batched_solve_lu]

for solve_fn in solve_fns:
    for n in [int(1e6), int(1e7)]:
        # A is a batch of n identity matrices, so the exact solution of Ax = b is x = b.
        A = jnp.repeat(jnp.identity(m)[None], n, axis=0)

        x = jnp.ones([n, m])
        b = jax.lax.batch_matmul(A, x[..., None])[..., 0]

        x_solved = solve_fn(A, b)
        print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
        print("Memory info ", device.memory_stats())

Output:

on device: cuda:0
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 520001024, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 804000512, 'largest_free_block_bytes': 0, 'num_allocs': 29, 'peak_bytes_in_use': 1324001536, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, nan 
Memory info  {'bytes_in_use': 5200002048, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 58, 'peak_bytes_in_use': 13280002560, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 520000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 98, 'peak_bytes_in_use': 13280002560, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, 0.0 
Memory info  {'bytes_in_use': 5200000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 136, 'peak_bytes_in_use': 18120001024, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 560002560, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 174, 'peak_bytes_in_use': 18120001024, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
2024-01-24 15:17:43.130711: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

At:
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_errors.py(70): raises
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(258): _flat_callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(52): pure_callback_impl
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(188): _callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(2267): _wrapped_callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1152): __call__
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1151): _pjit_call_impl_python
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1195): call_impl_cache_miss
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1211): _pjit_call_impl
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(869): process_primitive
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1427): _pjit_batcher
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/batching.py(433): process_primitive
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(255): cache_miss
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(200): _call
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_module.py(935): __call__
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(206): __call__
  /tmp/ipykernel_3270959/20536111.py(9): <lambda>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/linear_util.py(190): call_wrapped
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/api.py(1260): vmap_f
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /tmp/ipykernel_3270959/20536111.py(22): <module>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3526): run_code
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3466): run_ast_nodes
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3284): run_cell_async
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3079): _run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3024): run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/zmqshell.py(546): run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/ipkernel.py(422): do_execute
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(740): execute_request
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(505): process_one
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/events.py(80): _run
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(1905): _run_once
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(601): run_forever
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/tornado/platform/asyncio.py(195): start
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelapp.py(736): start
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/traitlets/config/application.py(1053): launch_instance
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(87): _run_code
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(197): _run_module_as_main

2024-01-24 15:17:43.130798: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2711] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

At:
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_errors.py(70): raises
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(258): _flat_callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(52): pure_callback_impl
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(188): _callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(2267): _wrapped_callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1152): __call__
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1151): _pjit_call_impl_python
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1195): call_impl_cache_miss
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1211): _pjit_call_impl
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(869): process_primitive
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1427): _pjit_batcher
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/batching.py(433): process_primitive
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(255): cache_miss
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(200): _call
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_module.py(935): __call__
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(206): __call__
  /tmp/ipykernel_3270959/20536111.py(9): <lambda>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/linear_util.py(190): call_wrapped
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/api.py(1260): vmap_f
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /tmp/ipykernel_3270959/20536111.py(22): <module>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3526): run_code
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3466): run_ast_nodes
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3284): run_cell_async
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3079): _run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3024): run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/zmqshell.py(546): run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/ipkernel.py(422): do_execute
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(740): execute_request
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(505): process_one
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/events.py(80): _run
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(1905): _run_once
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(601): run_forever
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/tornado/platform/asyncio.py(195): start
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelapp.py(736): start
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/traitlets/config/application.py(1053): launch_instance
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(87): _run_code
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(197): _run_module_as_main
; current tracing scope: custom-call.101; current profiling annotation: XlaModule:#hlo_module=jit_linear_solve,program_id=40#.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 22
     19 x = jnp.ones([n,m])
     20 b = jax.lax.batch_matmul(A,x[...,None])[...,0]
---> 22 x_solved = solve_fn(A,b)
     23 print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
     24 print("Memory info ", device.memory_stats())

    [... skipping hidden 3 frame]

Cell In[1], line 9, in <lambda>(matrix, vector)
      5 print('on device:', device)
      7 m = 10
----> 9 batched_solve_lu = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()).value)
     10 batched_solve_SVD = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.SVD()).value)
     12 solve_fns = [jax.scipy.linalg.solve, batched_solve_SVD, batched_solve_lu]

    [... skipping hidden 14 frame]

File ~/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1152, in ExecuteReplicated.__call__(self, *args)
   1150   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1151 else:
-> 1152   results = self.xla_executable.execute_sharded(input_bufs)
   1153 if dispatch.needs_check_special():
   1154   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also https://docs.kidger.site/equinox/api/errors/#equinox.error_if for more
information.
@patrick-kidger (Owner)

Hmm. So JAX and Lineax both basically do the same thing for the LU/QR/Cholesky solvers, which is to use the JAX (and thus probably CUDA) implementation of those decompositions.

The fact that the QR solve is slow is expected I think -- IIRC there's no CUDA implementation of a batched QR decomposition, so vmap is handled by computing the decomposition for each batch element sequentially.

I suspect the issue is probably somewhere in the underlying CUDA (cuSolver?) implementations. I think resolving this will probably need someone to go digging through things at that level, I'm afraid.
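For what it's worth, a minimal check (not from the original thread) of whether the NaNs also appear when calling the JAX batched solve directly, bypassing Lineax, might look like the following; it assumes jnp.linalg.solve dispatches to the same batched LU factorization on GPU:

import jax.numpy as jnp

n, m = int(1e7), 10
# Same setup as the reproduction above: a batch of n identity matrices.
A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
b = jnp.ones([n, m])
# Solve with a stacked right-hand side; if this also returns NaNs,
# the problem lies below Lineax, in JAX/cuSolver.
x = jnp.linalg.solve(A, b[..., None])[..., 0]
print("any NaN:", bool(jnp.isnan(x).any()))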

@patrick-kidger added the bug (Something isn't working) label on Jan 29, 2024
@ma-gilles (Author)

Hi Patrick,

Thanks for your answer!

I can't say I really understand how JAX/torch/cupy interact with CUDA code, but what is surprising to me is that this seems to be a bug only in JAX. Both torch and cupy seem to work, even though I would assume they use the same backend.

E.g.:

  import numpy as np
  import torch

  n = int(1e7); m = 10
  # A batch of n identity matrices; the Cholesky factor of the identity is the identity,
  # so the norm of A - L should be exactly zero.
  A = torch.tensor(np.repeat(np.identity(m)[None], n, axis=0))
  L = torch.linalg.cholesky(A)
  print(torch.linalg.norm(A - L))

Outputs:

  tensor(0., dtype=torch.float64)

And the same thing for cupy, but JAX returns NaNs.
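For completeness, a hypothetical cupy version of the same check (not in the original post) might look like the following, assuming cupy.linalg.cholesky accepts a batch of matrices the same way numpy does:

  import cupy as cp

  n = int(1e7); m = 10
  # A batch of n identity matrices, as in the torch snippet above.
  A = cp.repeat(cp.identity(m)[None], n, axis=0)
  L = cp.linalg.cholesky(A)
  # The Cholesky factor of the identity is the identity, so this should print 0.
  print(cp.linalg.norm(A - L))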

@patrick-kidger (Owner)

Oh interesting! Hmm, in that case I'm less certain of the reason. Maybe check that it's not a version issue? PyTorch and JAX tend to use different versions of the underlying NVIDIA libraries.
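As a side note, a quick way to compare the versions in play (not from the original thread) might be something like the following; torch.version.cuda reports the CUDA version the PyTorch wheel was built against, and recent JAX versions also provide jax.print_environment_info() for a fuller report:

import jax
import jaxlib
import torch

# Print the framework versions; a mismatch in the bundled CUDA libraries
# between the JAX and PyTorch installs would show up here.
print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
print("torch:", torch.__version__, "built against CUDA:", torch.version.cuda)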

@ma-gilles (Author)

Thanks for the suggestion! I tried a few different CUDA versions with no change, but updating JAX seems to fix the problem, or at least it passes the few tests I have tried.

@patrick-kidger (Owner)

Curious! Well, I'm glad it's fixed. :)
Possibly an issue with a particular version of jaxlib then, if updating the version fixed things.
