jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Metal: missing functionality #20375

Status: Open. Gattocrucco opened this issue 8 months ago.

Gattocrucco commented 8 months ago

Laundry list of things missing in the Metal backend when I tried using it:

shuhand0 commented 8 months ago

Good questions. Apart from plugin compilation and runtime issues, these come down to platform-specific JAX APIs and missing lowering support for some primitives.
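One way to build such a list on your own machine is to jit-compile and run a handful of small functions and record which ones raise. The following is a minimal sketch, assuming only standard JAX calls (`jax.jit`, `jax.block_until_ready`, `jax.default_backend`). The helper `probe` and the sample operations in `checks` are hypothetical illustrations, not a list of known Metal gaps.

```python
import jax
import jax.numpy as jnp


def probe(fn, *args):
    """Hypothetical helper: compile and run `fn` under jit, return "ok" or the error text."""
    try:
        jax.block_until_ready(jax.jit(fn)(*args))
        return "ok"
    except Exception as e:  # lowering, compilation, and runtime errors all land here
        return f"{type(e).__name__}: {e}"


x = jnp.arange(1.0, 5.0)

# Hypothetical sample of operations to try; which of them fail on Metal is not claimed here.
checks = {
    "sin": (jnp.sin, (x,)),
    "cumsum": (jnp.cumsum, (x,)),
    "sort": (jnp.sort, (x,)),
    "matmul": (lambda a: a[:, None] @ a[None, :], (x,)),
}

results = {name: probe(fn, *args) for name, (fn, args) in checks.items()}

print("backend:", jax.default_backend())
for name, status in results.items():
    print(f"{name:8s} {status}")
```

Any entry whose status is not `ok` points at an op (or its lowering) that the active backend cannot handle.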

leonard-gleyzer commented 1 month ago

Any update on jax.pure_callback? Thanks!
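For anyone who wants to check `jax.pure_callback` on their own setup, this is a minimal sketch, assuming a recent JAX where `jax.pure_callback(callback, result_shape_dtypes, *args)` is available. `host_sin` is a hypothetical example callback. Nothing here claims the call succeeds on the Metal plugin; on the CPU backend it should return the sine of the inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Hypothetical callback: runs on the host with NumPy, outside JAX tracing.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)


@jax.jit
def f(x):
    # JAX needs the callback's output shape/dtype up front, since it cannot trace NumPy code.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)


x = jnp.arange(3.0)
print("backend:", jax.default_backend())
print(f(x))
```

If the backend lacks host-callback support, the error should surface when `f(x)` is first compiled or run.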