jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Instantiating a very large jax.Array according to a Sharding #18263

Closed wbrenton closed 1 year ago

wbrenton commented 1 year ago

I'd like to initialize an array according to a sharding, rather than initializing it on the default device and then moving it to the sharding. This is required when trying to instantiate arrays that are larger than a single GPU/TPU's memory but smaller than the combined memory of many GPUs/TPUs.
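
For concreteness, here's the pattern I'm trying to avoid (a rough sketch with an illustrative size and dtype): the array is fully materialized on the default device before it ever reaches the sharding, which is exactly where the OOM happens.

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())

# Allocated entirely on the default device first; this is where it OOMs
buf = jnp.zeros((4_000_000_000,), dtype=jnp.uint8)
# Only after this call does the array get spread across the devices
buf = jax.device_put(buf, sharding)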

I'm building an on-device replay buffer to work with Podracer-style architectures.

In cases where an algorithm requires a large replay buffer (Ape-X DQN, MuZero, Muesli), that replay buffer will need to be instantiated according to a sharding to prevent OOM errors.

There are implementations of the tooling I'm talking about in brax, dejax, and very recently flashbax. These implementations work well with stateless environments, as you can simply pmap over the training loop to shard the replay buffer across devices, effectively increasing buffer size. However, this doesn't make use of the sharding tooling available through the new unified jax.Array API.
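
Roughly, the pmap-style pattern I mean looks like this (just a sketch; shapes and names are illustrative, not taken from brax/dejax/flashbax):

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

@jax.pmap
def init_buffer(_):
    # Each device creates and keeps its own shard of the buffer,
    # so total capacity scales with the number of devices.
    return jnp.zeros((1000, 4))

per_device_buffer = init_buffer(jnp.arange(n_dev))  # shape (n_dev, 1000, 4)

@jax.pmap
def train_step(buffer):
    # ... each device updates only its own slice of the buffer here ...
    return buffer

per_device_buffer = train_step(per_device_buffer)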

Similar to https://github.com/google/jax/issues/4221#issue-695968528.

yashk2810 commented 1 year ago

You can try this:

import jax
import jax.numpy as jnp

# `sharding` is whatever jax.sharding.Sharding you want the result to follow
@jax.jit
def f():
  x = jnp.arange(8)
  return jax.lax.with_sharding_constraint(x, sharding)

f()

Another way is this:

from functools import partial

@partial(jax.jit, out_shardings=sharding)
def f():
  return jnp.arange(8)

This will instantiate x directly on the devices as a sharded Array. In other words, x will never be on the default device.

The above two ways are similar and just differ in style and taste. I personally like the first way.
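
If it helps, a quick way to check that the result really came out sharded (assuming a jax version that has jax.debug.visualize_array_sharding):

y = f()
print(y.sharding)                      # should report the sharding you asked for
jax.debug.visualize_array_sharding(y)  # prints a small diagram of the layout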

wbrenton commented 1 year ago

That worked like a charm, thank you. Follow-up question: does donate_argnums in jit not play well with the sharding API? I can't seem to use donate_argnums in a way where the donated buffer is usable. I keep getting:

"/admin/home-willb/cleanba2/venv39/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py:766: UserWarning: Some donated buffers were not usable: ShapedArray(uint8[5000,2,4,84,84]), ShapedArray(uint8[5000,2]), ShapedArray(float32[5000,2]), ShapedArray(uint8[5000,2]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"

The documentation still references the deprecated pjit, which is why I ask.

yashk2810 commented 1 year ago

You need to set out_shardings too; then jit will donate properly.

Donation works by looking at the sharded shape. If you don't specify out_shardings, we don't know what the sharding is going to be until after compilation, and that's too late in the stack to set the donation bits.

There is a fix for this, but I just need some time to get it submitted. Until then, you can set out_shardings :)

wbrenton commented 1 year ago

Hmm, I could very well be misusing it, but here is a minimal example where the donated buffers aren't used:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2' # Use 2 CPU devices
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

devices = jax.devices()
sharding = PositionalSharding(devices)

@partial(jax.jit, donate_argnums=0)
def insert(rb_state, update):
    rb_state = jax.lax.dynamic_update_slice_in_dim(rb_state, update, 0, axis=0)
    return jax.lax.with_sharding_constraint(rb_state, sharding)

rb_state = jnp.zeros((100,))
rb_state = jax.device_put(rb_state, sharding)

update = jnp.ones((10,))
update = jax.device_put(update, sharding)

insert(rb_state, update)

wbrenton commented 1 year ago

I think I'm conflating jax.lax.with_sharding_constraint with explicitly passing the out_shardings argument... However, the output of the jitted function is a pytree. Can I pass a pytree of shardings to out_shardings?

yashk2810 commented 1 year ago

Yeah, pass the out_shardings to jit instead of with_sharding_constraint (wsc). I guess that's one advantage of using out_shardings.

Can I pass a pytree of shardings to out_shardings?

Yeah

I'll fix this though so this never happens again.
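
A hedged sketch of what that looks like, with a made-up dict-of-arrays buffer state (not the exact structure from your code, and assuming the leaves are 1-D so the same 1-D sharding applies to each):

buffer_shardings = {"obs": sharding, "rewards": sharding, "dones": sharding}

@partial(jax.jit, donate_argnums=0, out_shardings=buffer_shardings)
def insert(rb_state, update):
    # rb_state and update are dicts with the same keys as buffer_shardings
    return jax.tree_util.tree_map(
        lambda buf, x: jax.lax.dynamic_update_slice_in_dim(buf, x, 0, axis=0),
        rb_state, update)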

yashk2810 commented 1 year ago

This works for me:

In [3]: from functools import partial
   ...:
   ...: import jax
   ...: import jax.numpy as jnp
   ...: from jax.sharding import PositionalSharding
   ...:
   ...: devices = jax.devices()
   ...: sharding = PositionalSharding(devices)
   ...:
   ...: @partial(jax.jit, donate_argnums=0, out_shardings=sharding)
   ...: def insert(rb_state, update):
   ...:     rb_state = jax.lax.dynamic_update_slice_in_dim(rb_state, update, 0, axis=0)
   ...:     return rb_state
   ...:
   ...: rb_state = jnp.zeros((100,))
   ...: rb_state = jax.device_put(rb_state, sharding)
   ...:
   ...: update = jnp.ones((10,))
   ...: update = jax.device_put(update, sharding)
   ...:
   ...: insert(rb_state, update)
Out[3]:
Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
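
One quick way to confirm the donation took effect (illustrative; note that donation is skipped on the CPU backend, so this check only makes sense on GPU/TPU):

print(rb_state.is_deleted())  # True once the input buffer has been donated and consumed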

wbrenton commented 1 year ago

Likewise, it works for me; I should have closed this after my last comment. Greatly appreciate your help!