You can try this:
import jax
import jax.numpy as jnp

# `sharding` is assumed to be a jax.sharding.Sharding defined elsewhere
@jax.jit
def f():
    x = jnp.arange(8)
    return jax.lax.with_sharding_constraint(x, sharding)

f()
Another way is this:
from functools import partial

@partial(jax.jit, out_shardings=sharding)
def f():
    return jnp.arange(8)
This will instantiate x directly on the devices as a sharded Array. In other words, x will never be on the default device.
The two ways above are similar and differ only in style and taste. I personally prefer the first way.
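For completeness, here is a minimal self-contained sketch, assuming two CPU devices forced via XLA_FLAGS (as in the examples further down), that defines the `sharding` used above and checks that the second approach creates the array already sharded:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'  # 2 CPU devices

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())

@partial(jax.jit, out_shardings=sharding)
def f():
    return jnp.arange(8)

x = f()
print(x.sharding)                      # PositionalSharding over both devices
jax.debug.visualize_array_sharding(x)  # 4 elements per device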
That worked like a charm, thank you! Follow-up question: does donate_argnums in jit not play well with the sharding API? I can't seem to use donate_argnums in a way where the donated buffer is usable. I keep getting:
"/admin/home-willb/cleanba2/venv39/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py:766: UserWarning: Some donated buffers were not usable: ShapedArray(uint8[5000,2,4,84,84]), ShapedArray(uint8[5000,2]), ShapedArray(float32[5000,2]), ShapedArray(uint8[5000,2]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
The documentation still references the deprecated pjit, which is why I ask.
You need to set out_shardings too, so that jit will donate properly.
Donation works by looking at the sharded shape. If you don't specify out_shardings, we don't know what the sharding is going to be until after compilation, and that's too late in the stack to set the donation bits.
There is a fix for this, but I just need some time to get it submitted. Until then, you can set out_shardings :)
Hmm, I could very well be misusing it, but here is a minimal example where the donated buffers aren't used:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'  # Use 2 CPU devices

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

devices = jax.devices()
sharding = PositionalSharding(devices)

@partial(jax.jit, donate_argnums=0)
def insert(rb_state, update):
    rb_state = jax.lax.dynamic_update_slice_in_dim(rb_state, update, 0, axis=0)
    return jax.lax.with_sharding_constraint(rb_state, sharding)

rb_state = jnp.zeros((100,))
rb_state = jax.device_put(rb_state, sharding)

update = jnp.ones((10,))
update = jax.device_put(update, sharding)

insert(rb_state, update)
I think I'm conflating jax.lax.with_sharding_constraint with explicitly passing the out_shardings argument... However, the output of the jitted function is a pytree. Can I pass a pytree of shardings to out_shardings?
Yeah, pass in the out_shardings to jit instead of with_sharding_constraint. I guess that's one advantage of using out_shardings.
Can I pass a pytree of shardings to out_shardings?
Yeah.
I'll fix this though so this never happens again.
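For illustration, here is a minimal sketch of passing a pytree of shardings to out_shardings; the names and shapes are hypothetical, the structure just has to match the function's output pytree, and two forced CPU devices are assumed as above:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())

# One sharding per output leaf; 2-D leaves need a 2-D device layout.
buffer_shardings = {"obs": sharding.reshape(2, 1), "rewards": sharding}

@partial(jax.jit, out_shardings=buffer_shardings)
def init_buffer():
    return {
        "obs": jnp.zeros((100, 4)),    # sharded along axis 0 across the 2 devices
        "rewards": jnp.zeros((100,)),  # sharded along axis 0
    }

buffer = init_buffer()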
This works for me:
In [3]: from functools import partial
   ...:
   ...: import jax
   ...: import jax.numpy as jnp
   ...: from jax.sharding import PositionalSharding
   ...:
   ...: devices = jax.devices()
   ...: sharding = PositionalSharding(devices)
   ...:
   ...: @partial(jax.jit, donate_argnums=0, out_shardings=sharding)
   ...: def insert(rb_state, update):
   ...:     rb_state = jax.lax.dynamic_update_slice_in_dim(rb_state, update, 0, axis=0)
   ...:     return rb_state
   ...:
   ...: rb_state = jnp.zeros((100,))
   ...: rb_state = jax.device_put(rb_state, sharding)
   ...:
   ...: update = jnp.ones((10,))
   ...: update = jax.device_put(update, sharding)
   ...:
   ...: insert(rb_state, update)
Out[3]:
Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
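As a quick sanity check (a sketch continuing the session above), a successfully donated input is consumed by the call: the original rb_state buffer is marked deleted and the "donated buffers were not usable" warning no longer appears.

print(rb_state.is_deleted())  # True: the input buffer was donated to the output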
Likewise, I should have closed this on my last comment. Greatly appreciate your help!
I'd like to initialize an array according to a sharding, rather than initializing it on the default device and then moving it to the sharding. This is required when instantiating arrays that are larger than a single GPU/TPU's memory but smaller than the combined memory of many GPUs/TPUs.
I'm building an on-device replay buffer to work with Podracer-style architectures.
In cases where an algorithm requires a large replay buffer (ApeX-DQN, MuZero, Muesli), that replay buffer will need to be instantiated according to a sharding to prevent OOM errors.
There are implementations of the tooling I'm talking about in brax, dejax, and very recently flashbax. These implementations work well with stateless environments, as you can just pmap over the training loop to shard the replay buffer across devices, effectively increasing the buffer size. However, this doesn't make use of the sharding tooling available through the new unified jax.Array API.
Similar to https://github.com/google/jax/issues/4221#issue-695968528.