Description
After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on CPU (and on GPU under Linux) and got the following error:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>
The full error message is very long, and is attached here.
I did try the minimal example shown below, which also calls the Cholesky operator, but I couldn't reproduce the same error with it. I am more than happy to try another, more in-depth test. Any suggestions?
from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100, 100))

def calc_cholesky_decomp(test_matrix):
    # A @ A.T is symmetric positive semi-definite
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

# Eager call, then the jitted version
calc_cholesky_decomp(A)
jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)
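One possible reason the minimal example does not fail is that it differs from the failing call (the error shows a `tensor<50x50xf32>` input) or that it is not actually running on the Metal backend. The sketch below is a closer variant I could try. It is only a guess at a better reproduction, not a confirmed one: the names `cholesky_of_gram` and `run_repro` are hypothetical, and the diagonal jitter is an assumption added to keep the float32 factorization well conditioned.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp


def cholesky_of_gram(test_matrix):
    # Build a symmetric positive definite matrix. The small diagonal jitter
    # is an assumption added here, not part of the original model.
    n = test_matrix.shape[0]
    jitter = 1e-3 * jnp.eye(n, dtype=test_matrix.dtype)
    spd_matrix = test_matrix @ test_matrix.T + jitter
    return jsp.linalg.cholesky(spd_matrix, lower=True)


def run_repro(n=50, seed=0):
    # Record the active backend so it is clear which platform ran the op.
    backend = jax.default_backend()
    key = jnr.PRNGKey(seed)
    # Match the shape and dtype from the error message: 50x50, float32.
    a = jnr.normal(key, (n, n), dtype=jnp.float32)
    factor = jax.jit(cholesky_of_gram)(a)
    return backend, factor


if __name__ == "__main__":
    backend, factor = run_repro()
    print("backend:", backend)
    print("devices:", jax.devices())
    print("factor shape/dtype:", factor.shape, factor.dtype)
```

If the reported backend is the CPU one, the minimal example never exercised the Metal lowering, which would explain why it did not hit the `mhlo.cholesky` error.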
What jax/jaxlib version are you using?
jaxlib 0.4.10 (metal), jax 0.4.11
Which accelerator(s) are you using?
CPU/GPU
Additional system info
Python v3.10.10, Apple M2
NVIDIA GPU info
No response