
Linear solvers (jax.numpy.linalg.solve, jax.scipy.linalg.solve) bugged when allocating large memory #19431

Closed
ma-gilles opened this issue Jan 19, 2024 · 3 comments
Labels: bug (Something isn't working)

ma-gilles commented Jan 19, 2024

Description

Hi JAX team,

I have noticed what seems to be a bug in the linear solvers that appears when a significant fraction of GPU memory is in use.

Below is an example. I solve batches of well-conditioned linear systems (A = identity, b = all ones), where each A is a 10x10 matrix.
In one test the number of matrices n is 10^6 (so the matrices A take 0.4 GB in total), and in the other it is 10^7 (4 GB in total). Note that JAX reserves 63 GB on the GPU, so this is well below capacity, and the reported "peak" usage is 13 GB.

This happens with both jax.numpy.linalg.solve and jax.scipy.linalg.solve.

import jax
import jax.numpy as jnp

device = jax.local_devices()[0]
print('on device:', device)

solve_fns = [jax.scipy.linalg.solve, jnp.linalg.solve]
m = 10
for solve_fn in solve_fns:
    for n in [int(1e6), int(1e7)]:
        # Batch of n well-conditioned m x m systems: A = identity, x = all ones.
        A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
        x = jnp.ones([n, m])

        # Right-hand side b = A @ x, so the exact solution is x itself.
        b = jax.lax.batch_matmul(A, x[..., None])[..., 0]
        x_solved = solve_fn(A, b)
        print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
        print("Memory info ", device.memory_stats())

The output is as follows:

on device: cuda:0
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 520001024, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 804000512, 'largest_free_block_bytes': 0, 'num_allocs': 42, 'peak_bytes_in_use': 1324001536, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, nan 
Memory info  {'bytes_in_use': 5200000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 71, 'peak_bytes_in_use': 13280000512, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 520000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 100, 'peak_bytes_in_use': 13280000512, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, nan 
Memory info  {'bytes_in_use': 5200000256, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 129, 'peak_bytes_in_use': 13280000512, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}

This same code works fine on CPU, and I have also tried on a different GPU with the same results.
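For what it's worth, here is a minimal sketch of how the CPU comparison can be run, assuming jax.default_device is used to pin the computation to the CPU backend (not the exact script I used, just the idea):

import jax
import jax.numpy as jnp

m, n = 10, int(1e7)  # the failing batch size; A alone is ~4 GB of host memory

# Pin the computation to CPU to compare against the GPU results above.
with jax.default_device(jax.devices("cpu")[0]):
    A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
    x = jnp.ones([n, m])
    b = jax.lax.batch_matmul(A, x[..., None])[..., 0]
    x_solved = jnp.linalg.solve(A, b)
    print("CPU average error:", jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1)))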

Thank you very much for all your great work!

What jax/jaxlib version are you using?

jax v0.4.18; jaxlib v0.4.18

Which accelerator(s) are you using?

GPU/CPU

Additional system info?

Python 3.9.18, OS Linux

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:01:00.0 Off |                    0 |
| N/A   37C    P0              62W / 500W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

ma-gilles commented Jan 23, 2024

I have run some tests on different batched solvers and decompositions; see below.

Cholesky (and probably LU) seems to be the main issue here. I tried switching to a QR-based solver, but it is somehow ~300x slower than Cholesky at large batch sizes, and ~100x slower than SVD. SVD seems to be the most reliable, though I have seen it fail too.

import time

import jax
import jax.numpy as jnp

device = jax.local_devices()[0]
print('on device:', device)

m = 10

for n in [1e5, 1e6, 1e7]:
    n = int(n)
    # Batch of n identity matrices, so every decomposition should be exact.
    A = jnp.repeat(jnp.identity(m)[None], n, axis=0).block_until_ready()
    print("n=", n)

    # SVD: reconstruct A = U @ diag(S) @ Vh and measure the relative error.
    st_time = time.time()
    U, S, Vh = jax.scipy.linalg.svd(A)
    A2 = jax.lax.batch_matmul(U * S[..., None, :], Vh)
    print(f"SVD error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2))/jnp.linalg.norm(A, axis=(-1,-2)))}, time = {time.time() - st_time} ")

    # Cholesky: reconstruct A from the triangular factor.
    st_time = time.time()
    L = jax.scipy.linalg.cholesky(A)
    A2 = jax.lax.batch_matmul(L, L.swapaxes(-1, -2)).block_until_ready()
    print(f"Cholesky error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2))/jnp.linalg.norm(A, axis=(-1,-2)))}, time = {time.time() - st_time} ")

    # QR is far too slow at the largest batch size, so skip it there.
    if n <= 1e6:
        st_time = time.time()
        Q, R = jnp.linalg.qr(A)
        A2 = jax.lax.batch_matmul(Q, R).block_until_ready()
        print(f"QR error {jnp.mean(jnp.linalg.norm(A2 - A, axis=(-1, -2))/jnp.linalg.norm(A, axis=(-1,-2)))}, time = {time.time() - st_time} ")

Output:

on device: cuda:0
n= 100000
SVD error 0.0, time = 0.6100842952728271 
Cholesky error 0.0, time = 0.15522980690002441 
QR error 0.0, time = 3.7535462379455566 
n= 1000000
SVD error 0.0, time = 0.5560173988342285 
Cholesky error 0.0, time = 0.1310713291168213 
QR error 0.0, time = 35.838584184646606 
n= 10000000
SVD error 0.0, time = 2.056659460067749 
Cholesky error nan, time = 0.27480244636535645 

Note that JAX doesn't raise any error or warning.
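One general-purpose way to at least surface these silent NaNs while debugging is the jax_debug_nans flag; in principle it should turn the NaN output of the solve into a FloatingPointError (this is just standard JAX debugging machinery, not a fix):

import jax
import jax.numpy as jnp

# Opt-in NaN checking: when enabled, JAX raises FloatingPointError as soon as
# a computation produces a NaN, instead of silently propagating it.
jax.config.update("jax_debug_nans", True)

m, n = 10, int(1e7)
A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
b = jnp.ones([n, m])
x = jnp.linalg.solve(A, b)  # should now error out rather than return NaNs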

Here is an SVD-based solver to use in the meantime, in case someone else needs a stopgap:

import jax
import jax.numpy as jnp


def solve_by_SVD(A, b):
    # Solve A x = b via the SVD: x = V @ diag(1/S) @ U^H @ b.
    U, S, Vh = jax.scipy.linalg.svd(A)

    # Promote a batch of vectors to a batch of single-column matrices.
    if b.ndim == A.ndim - 1:
        expand = True
        b = b[..., None]
    else:
        expand = False

    Uhb = jax.lax.batch_matmul(jnp.conj(U.swapaxes(-1, -2)), b) / S[..., None]
    x = jax.lax.batch_matmul(jnp.conj(Vh.swapaxes(-1, -2)), Uhb)

    if expand:
        x = x[..., 0]

    return x
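And a quick sanity check of this stopgap against jnp.linalg.solve on a small, diagonally dominant batch (the sizes and constants here are arbitrary, just to show the two agree):

import jax
import jax.numpy as jnp

m, n = 10, 1000
key = jax.random.PRNGKey(0)
# Diagonally dominant random batch, so the systems are reasonably conditioned.
A = jax.random.normal(key, (n, m, m)) + 3 * m * jnp.identity(m)
b = jnp.ones((n, m))

x_svd = solve_by_SVD(A, b)
x_ref = jnp.linalg.solve(A, b)
print("max abs difference:", jnp.max(jnp.abs(x_svd - x_ref)))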

ma-gilles commented

As far as I can tell, the LU/Cholesky bugs are fixed by updating jax to 0.4.23. See patrick-kidger/lineax#79 (comment) for more info.
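For anyone who can't upgrade immediately, a tiny version guard along these lines is one possible stopgap (my own suggestion, not an official recommendation; it assumes a plain x.y.z version string in jax.__version__):

import jax

# Versions before 0.4.23 showed the silent-NaN behaviour in large batched
# LU/Cholesky solves on GPU, so warn if an older jax is in use.
if tuple(int(p) for p in jax.__version__.split(".")[:3]) < (0, 4, 23):
    print("Warning: jax < 0.4.23; large batched GPU solves may silently return NaN.")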

hawkinsp commented Apr 4, 2024

Well, fixed is fixed, I guess. We could dig into why, but it would mostly be of historical interest. Please reopen if it happens again!

hawkinsp closed this as completed Apr 4, 2024