google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Apple Silicon: error: failed to legalize operation 'mhlo.cholesky' #16321

Open adam-hartshorne opened 1 year ago

adam-hartshorne commented 1 year ago

Description

After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on CPU (and on GPU on Linux) and got the following error.

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky' /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>

The full error message is very long, so it is attached here.

cholesky_full_error.txt.zip

I did try a minimal example shown below which also calls the cholesky operator, but I couldn't reproduce the same error. I am more than happy to try another more in-depth test code. Any suggestions?

from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100,100))

def calc_cholesky_decomp(test_matrix):
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

calc_cholesky_decomp(A)

jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)
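
One possible reason the minimal example did not reproduce the error is that it ran on a different backend than the failing model. The sketch below is only a diagnostic aid, not a confirmed reproduction: it prints the backend JAX selected and runs a jitted Cholesky on an input with the same shape and dtype as the op in the error message (tensor<50x50xf32>, lower = true). The names `B`, `spd` and `chol`, and the identity shift used to keep the matrix well conditioned, are illustrative assumptions and do not come from the original model.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

# Report which backend JAX actually selected for this process.
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# Same shape and dtype as the op in the error message: tensor<50x50xf32>.
key = jnr.PRNGKey(0)
B = jnr.normal(key, (50, 50), dtype=jnp.float32)

# Assumption for illustration: shift by a multiple of the identity so the
# matrix is comfortably positive definite in float32.
spd = B @ B.T + 50.0 * jnp.eye(50, dtype=jnp.float32)

# Jitted lower-triangular Cholesky, matching {lower = true} in the error.
chol = jax.jit(lambda m: jsp.linalg.cholesky(m, lower=True))(spd)
print("result lives on:", chol.devices())
```

If the printed backend is the CPU, the example never exercised the Metal lowering at all, which would explain why it ran cleanly.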

What jax/jaxlib version are you using?

jaxlib 0.4.10 (metal), jax 0.4.11

Which accelerator(s) are you using?

CPU/GPU

Additional system info

Python v3.10.10, Apple M2

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

@shuhand0 @kulinseth

benjaminvatterj commented 8 months ago

Can confirm that this is still broken in jax-metal version 0.0.5.

c0g commented 6 months ago

Any update on ETA here? I am trying to use Brax on Metal and it wants the cholesky decomp.

@kulinseth

shuhand0 commented 6 months ago

Looking into adding the conversion for the op.

benjaminvatterj commented 6 months ago

I just wanted to note that it's still not implemented in version 0.0.6, in case anyone noticed the new release.

vhaasteren commented 5 months ago

I'm also eagerly awaiting this

mvanaltvorst commented 5 months ago

Would love to use multivariate normal distributions, which depend on the Cholesky decomposition. Am eagerly awaiting this.

driesmarzougui commented 4 months ago

Still not working in jax-metal v0.0.7

benjaminvatterj commented 4 months ago

We're approaching the one year mark on this. Any hope that this would be resolved soon?

c0g commented 4 months ago

Is jax-metal open source? I can’t find it but would consider contributing.

benjaminvatterj commented 4 months ago

As far as I know, it's maintained by people at Apple (@kulinseth). I believe they don't share their code.

vhaasteren commented 3 months ago

I can report that v0.1.0 still does not address this

yangfengzzz commented 3 months ago

I saw that WWDC24 showed JAX supporting MuJoCo, but when I try MJX it still triggers this issue.

Screenshot 2024-06-15 06 51 07

yangfengzzz commented 2 months ago

After looking at the code, I found that cholesky is defined in jaxlib. Does that mean plugging in the Metal backend through PJRT cannot solve this problem?

mmattamala commented 1 month ago

The problem persists with jax-metal v0.1.0, jax v0.4.31, and jaxlib v0.4.31.
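
Until the op is supported, one possible stopgap is to run only the factorization on the CPU backend. The sketch below is a minimal illustration, not a fix from the jax-metal maintainers. It assumes a CPU device is available alongside the default backend, and `cholesky_on_cpu` is a hypothetical helper name. It applies to eager calls only: inside a `jit`-compiled function the op would still be lowered for the default device.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

# Assumption: the CPU backend is registered alongside the default backend.
cpu = jax.devices("cpu")[0]

def cholesky_on_cpu(matrix):
    """Hypothetical helper: factorize on the CPU instead of the default device."""
    with jax.default_device(cpu):
        # Copy the input to the CPU so the op is dispatched there.
        matrix_cpu = jax.device_put(matrix, cpu)
        return jsp.linalg.cholesky(matrix_cpu, lower=True)

# Small symmetric positive definite example.
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
L = cholesky_on_cpu(A)
print(L.devices())
```

The result stays on the CPU; moving it back with `jax.device_put` adds a host-to-device copy, so this is only reasonable when the factorization is a small part of the workload.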