Linear solvers (jax.numpy.linalg.solve, jax.scipy.linalg.solve) fail when allocating large amounts of memory #19431
Comments
I have been running some tests on different batched solvers and decompositions (see below). Cholesky (and probably LU) seems to be the main issue here. I tried switching to a QR-based solver, but for large batch sizes it is roughly 300 times slower than Cholesky and ~100 times slower than SVD. SVD seems to be the most reliable, though I have seen it fail too.
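The test script from this comment was not preserved. The following is a minimal sketch of how such a comparison could be set up, assuming the identity test systems from the original report and wrapping per-matrix solves in `jax.vmap`. The helper names (`solve_cholesky`, `solve_lu`, `solve_qr`, `solve_svd`, `max_error`) are hypothetical, and the sketch only checks accuracy; it does not reproduce the timings quoted above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, lu_factor, lu_solve, solve_triangular


def solve_cholesky(A, b):
    # Assumes A is symmetric positive definite (true for the identity test case).
    return cho_solve(cho_factor(A), b)


def solve_lu(A, b):
    # LU factorization with partial pivoting, then forward/back substitution.
    return lu_solve(lu_factor(A), b)


def solve_qr(A, b):
    # A = QR, so x = R^{-1} Q^T b (R is upper triangular).
    Q, R = jnp.linalg.qr(A)
    return solve_triangular(R, Q.T @ b, lower=False)


def solve_svd(A, b):
    # A = U diag(s) Vt, so x = Vt^T diag(1/s) U^T b.
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    return Vt.T @ ((U.T @ b) / s)


SOLVERS = {"cholesky": solve_cholesky, "lu": solve_lu, "qr": solve_qr, "svd": solve_svd}


def max_error(solver, n, dim=10):
    # Batch of n identity systems with b = ones, so the exact solution is all ones.
    A = jnp.broadcast_to(jnp.eye(dim), (n, dim, dim))
    b = jnp.ones((n, dim))
    x = jax.jit(jax.vmap(solver))(A, b)
    return float(jnp.max(jnp.abs(x - 1.0)))


if __name__ == "__main__":
    for name, solver in SOLVERS.items():
        print(name, max_error(solver, n=1000))
```

Raising `n` toward 10^7 gives the large-batch case; a nonzero `max_error` there would indicate the silent failure discussed in this thread.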
Output:
Note that JAX doesn't throw an error or warning. Here is an SVD-based solver to use in the meantime, in case someone else needs a stopgap:
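The commenter's SVD solver itself was not preserved. A minimal sketch of an SVD-based batched solver is shown below, assuming inputs of shape `(n, d, d)` and `(n, d)`; the function name `svd_solve` is hypothetical, and the cutoff for negligible singular values is an added pseudo-inverse-style safeguard, not something stated in the thread.

```python
import jax
import jax.numpy as jnp


@jax.jit
def svd_solve(A, b):
    """Solve A[i] @ x[i] = b[i] for a batch of square systems via SVD.

    A: (n, d, d), b: (n, d). Returns x: (n, d).
    """
    # Batched SVD: A[i] = U[i] @ diag(s[i]) @ Vh[i], computed per matrix.
    U, s, Vh = jnp.linalg.svd(A, full_matrices=False)
    # Project b onto the left singular vectors: c[i] = U[i]^T b[i].
    c = jnp.einsum("nij,ni->nj", U, b)
    # Treat singular values below eps * d * s_max as zero (pseudo-inverse style).
    cutoff = jnp.finfo(s.dtype).eps * A.shape[-1] * s[..., :1]
    keep = s > cutoff
    y = jnp.where(keep, c / jnp.where(keep, s, 1.0), 0.0)
    # Map back with the right singular vectors: x[i] = Vh[i]^T y[i].
    return jnp.einsum("nji,nj->ni", Vh, y)
```

For the identity systems in this issue (A = identity, b = all ones), `svd_solve` should return all ones. Batching is handled by the batched `jnp.linalg.svd` and `einsum`, so no `vmap` is needed.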
As far as I can tell, the LU/Cholesky bugs are fixed by updating JAX to 0.4.23. See patrick-kidger/lineax#79 (comment) for more info.
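To guard against running on an affected version, a check along these lines could be used. This is a sketch that assumes, per the comment above, that 0.4.23 is the first release without the problem; the thread does not identify the change that fixed it. `version_tuple`, `has_solver_fix`, and `FIXED_IN` are hypothetical names, and only `jax.__version__` is checked.

```python
import re

import jax

# Assumed first unaffected release, taken from the comment above.
FIXED_IN = (0, 4, 23)


def version_tuple(version):
    # "0.4.23" -> (0, 4, 23); any suffix after the patch number (e.g. "rc1", ".dev...") is ignored.
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(g) for g in m.groups())


def has_solver_fix(version=jax.__version__):
    return version_tuple(version) >= FIXED_IN


if not has_solver_fix():
    print(f"jax {jax.__version__} predates 0.4.23; batched LU/Cholesky solves may be affected")
```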
Well, fixed is fixed, I guess. We could dig into why, but it would mostly be of historical interest. Please reopen if it happens again!
Description
Hi JAX team,
I have noticed what appears to be a bug in the linear solvers that occurs when the computation takes up a significant part of GPU memory.
Below is an example. I solve well-conditioned linear systems (A = identity, b = all ones), where each A is a 10x10 matrix.
In one test the number of matrices is n = 10^6 (so the matrices A take 0.4 GB in total), and in the other n = 10^7 (4 GB in total). Note that JAX reserves 63 GB on the GPU, so this is well below capacity, and the reported peak usage is 13 GB.
This happens with both jax.numpy.linalg.solve and jax.scipy.linalg.solve.
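The reproduction script was not preserved in this copy of the issue. A minimal sketch of the described setup follows, assuming float32 inputs, which is consistent with the quoted sizes (10^6 × 10 × 10 × 4 bytes = 0.4 GB). `matrix_bytes` and `reproduce` are hypothetical helpers; `b` is passed to `jnp.linalg.solve` as explicit column vectors to keep the batching unambiguous, and `jax.scipy.linalg.solve` is mapped over the batch with `jax.vmap`.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg


def matrix_bytes(n, dim=10, dtype=jnp.float32):
    # Memory for n matrices of shape (dim, dim): n * dim * dim * itemsize.
    return n * dim * dim * jnp.dtype(dtype).itemsize


def reproduce(n, dim=10):
    # n identity systems with b = ones; the exact solution is all ones.
    A = jnp.broadcast_to(jnp.eye(dim, dtype=jnp.float32), (n, dim, dim))
    b = jnp.ones((n, dim, 1), dtype=jnp.float32)  # explicit column vectors
    x_np = jnp.linalg.solve(A, b)[..., 0]
    x_sp = jax.vmap(jax.scipy.linalg.solve)(A, b[..., 0])
    err = lambda x: float(jnp.max(jnp.abs(x - 1.0)))
    # Max abs error for jax.numpy.linalg.solve and jax.scipy.linalg.solve.
    return err(x_np), err(x_sp)
```

`reproduce(10**6)` and `reproduce(10**7)` correspond to the two tests above; both errors should be close to zero on an unaffected setup.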
The output is as follows:
This same code works fine on CPU, and I have also tried on a different GPU with the same results.
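One way to cross-check a GPU result against the CPU within a single process is to commit the inputs to each device with `jax.device_put` and compare. This is a sketch, not part of the original report; `solve_on` and `cpu_vs_default` are hypothetical helpers, and on a CPU-only machine both runs land on the CPU and the difference is trivially zero.

```python
import numpy as np
import jax
import jax.numpy as jnp


def solve_on(device, A, b):
    # Committing the inputs to `device` makes the solve run on that device.
    A = jax.device_put(A, device)
    b = jax.device_put(b, device)
    return jnp.linalg.solve(A, b)


def cpu_vs_default(n=1000, dim=10):
    """Max abs difference between the default-device and CPU solutions."""
    A = jnp.broadcast_to(jnp.eye(dim), (n, dim, dim))
    b = jnp.ones((n, dim, 1))  # explicit column vectors
    x_default = solve_on(jax.devices()[0], A, b)  # GPU if one is available
    x_cpu = solve_on(jax.devices("cpu")[0], A, b)
    # Copy both results to host memory before comparing.
    return float(np.max(np.abs(np.asarray(x_default) - np.asarray(x_cpu))))
```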
Thank you very much for all your great work!
What jax/jaxlib version are you using?
jax v0.4.18; jaxlib v0.4.18
Which accelerator(s) are you using?
GPU/CPU
Additional system info?
Python 3.9.18, OS Linux
NVIDIA GPU info
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:01:00.0 Off |                    0 |
| N/A   37C    P0              62W / 500W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+