jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Apple Silicon: error: failed to legalize operation 'mhlo.triangular_solve' #17490

Open yu-fz opened 1 year ago

yu-fz commented 1 year ago

Description

Matrix inversion appears to be broken on jax-metal. Apologies in advance if this is not the right place to report the issue.

Repro:

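import jax
import jax.numpy as jnp

# Select the Metal backend provided by the jax-metal plugin.
jax.config.update('jax_platform_name', 'METAL')

# Note: this matrix is singular (det = 0), so even on a working backend the result
# is not a meaningful inverse; the error above is raised while compiling the op.
A = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])
B = jnp.linalg.inv(A)
print(B)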
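The error names `triangular_solve` rather than anything inverse-specific because, in the JAX versions I am aware of, `jnp.linalg.inv` is computed from an LU factorization followed by triangular solves. A minimal sketch to check this on your own install, using only `jax.make_jaxpr`, which traces without compiling and so should not hit the lowering error even on Metal. `primitive_names` is a hypothetical helper written for this example; it assumes nested jaxprs show up in equation params either directly or inside tuples, which is an internal detail that may change between JAX versions. The matrix is a non-singular variant of the repro's.

```python
import jax
import jax.numpy as jnp


def _as_jaxpr(obj):
    """Return the underlying Jaxpr for a Jaxpr or ClosedJaxpr, else None."""
    inner = getattr(obj, "jaxpr", obj)  # unwrap a ClosedJaxpr if needed
    return inner if hasattr(inner, "eqns") else None


def primitive_names(jaxpr) -> set:
    """Collect primitive names in a jaxpr, recursing into nested jaxprs."""
    names = set()
    for eqn in _as_jaxpr(jaxpr).eqns:
        names.add(eqn.primitive.name)
        for param in eqn.params.values():
            # Assumption: sub-jaxprs appear directly or inside (named)tuples.
            candidates = param if isinstance(param, tuple) else (param,)
            for sub in candidates:
                if _as_jaxpr(sub) is not None:
                    names |= primitive_names(sub)
    return names


# Tracing only: nothing is compiled for any backend here.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 10.0]])
names = primitive_names(jax.make_jaxpr(jnp.linalg.inv)(A))
print(sorted(names))
print("triangular_solve" in names)
```

If `triangular_solve` appears in the output, any backend that cannot lower that op will fail on `jnp.linalg.inv` regardless of the matrix values.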

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib 0.4.10

Which accelerator(s) are you using?

GPU/Metal

Additional system info

macOS 13.5.1, M1 Max

NVIDIA GPU info

No response

p-i- commented 1 year ago
Would be a +1 for optics if assignees could 👍 to acknowledge receipt of the task. There's a handful of bugs around core operations with no indication of progress or of what's blocking them. Guessing it's Apple dragging their heels; do they not have the firepower to provide (uniquely) a software stack to support their own hardware that's 1+ years old now? Seems cringe that the UX disappointment should fall unfairly on JAX.

benjaminvatterj commented 10 months ago

Just to add: this and a plethora of other fundamental JAX operations simply don't work on jax-metal, and there has been no insight into updates. At this point, jax-metal is quite useless as far as I have been able to test it. As of jax-metal 0.0.4, I would say simply avoid it and install JAX for CPU.
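A middle ground to the advice above, without uninstalling jax-metal, is to run only the failing computation on JAX's CPU backend. A minimal sketch, assuming the CPU backend is still available alongside the Metal plugin (it normally ships with jaxlib); `jax.default_device` sets the default device for arrays and computations created inside the `with` block. The matrix is a non-singular variant of the repro's.

```python
import jax
import jax.numpy as jnp

# Assumption: the built-in CPU backend is available next to the Metal plugin.
cpu = jax.devices("cpu")[0]

# Arrays and computations created inside this block default to the CPU device,
# so jnp.linalg.inv never needs a Metal lowering for triangular_solve.
with jax.default_device(cpu):
    A = jnp.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0],
                   [7.0, 8.0, 10.0]])  # det = -3, so A is invertible
    B = jnp.linalg.inv(A)

print(B)
```

Everything outside the `with` block keeps using the default (Metal) device, so only the unsupported op pays the cost of running on CPU.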

benjaminvatterj commented 9 months ago

Just FYI, it seems like Apple released a new version of jax-metal a few weeks back: version 0.0.5. This basic feature is still broken.

crondonm commented 8 months ago

Same issue here. Any idea when a fix will be released?

nngabe commented 7 months ago

I will ping this and add that I hope this gets fixed in the next release.

I think the JAX + Apple Silicon combo offers some unique advantages for prototyping models locally (i.e. less compute but more RAM than a comparable NVIDIA workstation). I think JAX users will pick up on this pretty quickly once features like this are fixed.

shuhand0 commented 7 months ago

We are looking into adding the conversion for the op.

mvanaltvorst commented 6 months ago

I also hope this gets fixed soon! Do you happen to have any updates @shuhand0? Thanks

TheSkyentist commented 1 week ago

As of jax-metal 0.1.1 this issue still persists, at least in my configuration.

macOS 15.0.1, Apple M2 Pro

yu-fz commented 1 week ago

😭 😭 😭 bruh