
Apple Silicon: error: failed to legalize operation 'mhlo.cholesky' #16321

@adam-hartshorne

Description

After building jaxlib following the instructions and installing jax-metal, I get the following error when testing an existing model that runs fine on CPU (and on GPU under Linux):

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>
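The failing operation is `mhlo.cholesky` with `lower = true` on a 50x50 float32 tensor. As a hedged sketch for checking which functions in a model (or in a candidate reproducer) emit this op, each one can be traced with `jax.make_jaxpr` and the printed program searched for the `cholesky` primitive. This assumes the primitive prints as `cholesky`, as it does in JAX 0.4.x. The helper name `uses_cholesky` is hypothetical, and the check only inspects the traced program. It says nothing about whether the Metal backend can compile it.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def uses_cholesky(fn, *args):
    """Hypothetical helper: True if tracing fn(*args) emits a Cholesky op."""
    # make_jaxpr traces fn without executing it; str() also prints nested
    # jaxprs (e.g. from inner jit calls), so a substring search finds them.
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return "cholesky" in str(closed_jaxpr)


# Usage: a 50x50 float32 input, matching the shape in the error message.
spd = 2.0 * jnp.eye(50, dtype=jnp.float32)
print(uses_cholesky(lambda m: jsl.cholesky(m, lower=True), spd))  # True
print(uses_cholesky(lambda m: m @ m.T, spd))  # False
```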

The full error message is very long and is attached here:

cholesky_full_error.txt.zip

I tried the minimal example below, which also calls the Cholesky operation, but it did not reproduce the error. I am happy to try a more in-depth test case. Any suggestions?

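from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100,100))

def calc_cholesky_decomp(test_matrix):
    # A @ A.T is symmetric positive semi-definite (positive definite when A has full rank)
    psd_test_matrix = test_matrix @ test_matrix.T
    # Lower-triangular Cholesky factor L with psd_test_matrix = L @ L.T
    chol_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return chol_decomp

# Un-jitted call: the outer function runs op by op
calc_cholesky_decomp(A)

# Same computation compiled as a single XLA program
jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)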
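A slightly more in-depth test could mirror the failing call more closely. It assumes (this is not taken from sparse_gps.py) that the model differentiates through the factorization, as a Gaussian-process log marginal likelihood typically does. The sketch uses the 50x50 float32 shape and `lower=True` from the error message, and computes log det(S) = 2 Σ log L_ii through both `jit` and `jit(grad)`. The function name `logdet_via_cholesky` and the `50.0 * I` diagonal shift are illustrative choices, not from the original model.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy.linalg as jsl


def logdet_via_cholesky(x):
    # Symmetric positive-definite 50x50 matrix; the diagonal shift keeps it
    # well conditioned (illustrative choice, not from the original model).
    spd = x @ x.T + 50.0 * jnp.eye(x.shape[0], dtype=x.dtype)
    # Lower-triangular factor L with spd = L @ L.T, as in the failing op.
    chol = jsl.cholesky(spd, lower=True)
    # log det(spd) = 2 * sum(log(diag(L)))
    return 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))


key = jnr.PRNGKey(0)
X = jnr.normal(key, (50, 50), dtype=jnp.float32)

# Forward pass and reverse-mode gradient, both compiled with jit.
value = jax.jit(logdet_via_cholesky)(X)
grad_X = jax.jit(jax.grad(logdet_via_cholesky))(X)

# Wait for the results so any runtime error surfaces at this point.
value.block_until_ready()
grad_X.block_until_ready()
print(float(value), grad_X.shape)
```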

What jax/jaxlib version are you using?

jaxlib 0.4.10 (metal), jax 0.4.11

Which accelerator(s) are you using?

CPU/GPU

Additional system info

Python v3.10.10, Apple M2

NVIDIA GPU info

No response
