google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Buffer donation for Metal plugin #20212

Open · dlwh opened this issue 4 months ago

dlwh commented 4 months ago

Buffer donation would be nice. I didn't see an existing issue for it, so I'm opening this one to track it and to ask whether it's on the Apple JAX Metal team's roadmap.
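For context, buffer donation lets XLA reuse an input's device memory for an output of the same shape and dtype. This matters most for in-place-style updates such as optimizer steps. The following is a minimal sketch of the usual pattern, assuming a recent JAX release. `sgd_step` and the parameter dict are hypothetical illustrations, not code from this issue. On a backend that does not implement donation (METAL here), the same code should still compute the right result but emit the "Some donated buffers were not usable" warning and allocate a fresh output buffer.

```python
import functools

import jax
import jax.numpy as jnp


# Donating argument 0 (params) allows XLA to reuse its device buffers for the
# returned, same-shaped parameters instead of allocating new ones.
@functools.partial(jax.jit, donate_argnums=(0,))
def sgd_step(params, grads, lr):
    # Plain SGD update applied leaf-by-leaf over the parameter pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {"w": jnp.ones((4, 5)), "b": jnp.zeros((5,))}
grads = {"w": jnp.full((4, 5), 0.5), "b": jnp.ones((5,))}

# Rebind the name: once donated, the old `params` buffers may be invalidated,
# so only the returned value should be used afterwards.
params = sgd_step(params, grads, 0.1)
print(params["w"][0, 0], params["b"][0])
```

The rebinding idiom (`params = sgd_step(params, ...)`) is what makes donation safe. The caller never touches the donated inputs again, whether or not the backend actually reused them.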

>>> import jax
>>> jax.devices()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-12 16:07:16.498439: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

[METAL(id=0)]
>>> import jax.numpy as jnp
>>> x = jax.jit(lambda x: x)(jnp.zeros((4, 5)))
>>> x = jax.jit(lambda x: x, donate_args=True)(jnp.zeros((4, 5)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: jit() got an unexpected keyword argument 'donate_args'
>>> x = jax.jit(lambda x: x, donate_argnums=(0,))(jnp.zeros((4, 5)))
/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[4,5]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
>>>
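The following is a minimal sketch for checking at runtime whether a backend actually honors donation, assuming a recent JAX release where `jax.Array.is_deleted()` is available. `donation_honored` is a hypothetical helper, not part of JAX. It donates a small array to a trivial jitted function, suppresses the `UserWarning` seen in the traceback above, and reports whether the input buffer was consumed. Given the "Donation is not implemented for ('METAL',)" message, it would be expected to return `False` on the Metal plugin.

```python
import warnings

import jax
import jax.numpy as jnp


def donation_honored(device=None) -> bool:
    """Return True if `device` actually consumed a donated input buffer.

    Backends without donation support emit a UserWarning at lowering time and
    leave the input alive; this probe silences that warning and inspects the
    input array's state instead.
    """
    device = device if device is not None else jax.devices()[0]
    x = jax.device_put(jnp.zeros((4, 5), dtype=jnp.float32), device)
    step = jax.jit(lambda a: a + 1.0, donate_argnums=(0,))
    with warnings.catch_warnings():
        # Hide the "Some donated buffers were not usable" warning for the probe.
        warnings.simplefilter("ignore", UserWarning)
        y = step(x)
    y.block_until_ready()
    # A donated-and-reused input is marked deleted; an ignored donation leaves it intact.
    return x.is_deleted()


print(jax.devices()[0].platform, donation_honored())
```

A check like this could gate memory-budget assumptions (for example, how many copies of a large parameter set fit in memory) on backends that do not yet support donation.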
shuhand0 commented 4 months ago

Thanks for the feature request. We will track it and post an update here when we have a plan.