google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Metal: missing functionality #20375

Open Gattocrucco opened 3 months ago

Gattocrucco commented 3 months ago

Laundry list of things missing in the Metal backend when I tried using it:

shuhand0 commented 3 months ago

Good questions. Apart from plugin compilation and runtime issues, these items come down to platform-specific JAX APIs and missing lowering support for primitives.
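The reply separates plugin compilation and runtime problems from unsupported platform APIs and primitive lowerings. One way to assemble a laundry list like this for a given backend is to try ahead-of-time lowering and compilation of each operation and record which ones fail. Below is a minimal sketch using JAX's `jax.jit(f).lower(*args).compile()` path. The helper `probe_lowering` and the chosen operations are hypothetical illustrations, not a statement of what the Metal plugin actually lacks. Which exception type a missing lowering raises on Metal is also an assumption, so the sketch catches `Exception` broadly.

```python
import jax
import jax.numpy as jnp
import numpy as np


def probe_lowering(cases):
    """Lower and compile each case for JAX's default backend.

    `cases` maps a label to (function, example_args). Returns a dict mapping
    each label to None on success, or to "ExceptionType: message" on failure.
    """
    results = {}
    for name, (fn, args) in cases.items():
        try:
            # Ahead-of-time path: trace -> lower -> compile, without executing.
            jax.jit(fn).lower(*args).compile()
            results[name] = None
        except Exception as e:  # assumption: the failure type varies by backend
            results[name] = f"{type(e).__name__}: {e}"
    return results


# Hypothetical, illustrative cases; not a claim about what Metal supports.
# NumPy inputs avoid running device computations just to build arguments.
CASES = {
    "sin": (jnp.sin, (np.ones(3, np.float32),)),
    "sort": (jnp.sort, (np.arange(5, dtype=np.float32),)),
    "fft": (jnp.fft.fft, (np.ones(8, np.float32),)),
    "cholesky": (jnp.linalg.cholesky, (np.eye(3, dtype=np.float32),)),
}

if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    for name, err in probe_lowering(CASES).items():
        print(f"{name:10s}", "ok" if err is None else err)
```

Running it once per installed backend (for example, with and without the Metal plugin active) gives a side-by-side view of which operations fail at lowering or compile time, as distinct from failures that only show up at runtime.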