hr0nix opened this issue 1 year ago
I struggle with the same issue.
Also relevant for me, would be great to have it solved.
@jakevdp Hey, sorry for mentioning you directly, but this issue hasn't received any attention for several weeks. Can someone from the jax team please take a look? Thanks!
@froystig Hey, sorry for mentioning you directly, but can someone take a look at this issue? It's a big blocker for me.
I've made a repro for this bug. Turns out it has nothing to do with jax_threefry_partitionable; it's perfectly reproducible without it.
The repro was made for an A100 80GB, so tensor shapes might need to be adjusted for a GPU with a different amount of memory.
./repro.py
fails, because without rematerialization it needs ~122.75 GB of GPU RAM.
./repro.py --remat
works perfectly fine with remat, because it now needs just 63 GB of GPU RAM.
./repro.py --remat --dropout-rate 0.1
OOMs again, requiring ~118 GB of GPU RAM. Looking at the peak buffers makes it clear that the dropout mask is not being rematerialized: tensors corresponding to the full dropout masks for different layers are occupying memory.
Peak buffers:
Buffer 1:
Size: 512.00MiB
Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_5._apply_block/Block_5/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
XLA Label: custom-call
Shape: u32[134217728]
==========================
Buffer 2:
Size: 512.00MiB
Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
XLA Label: custom-call
Shape: u32[134217728]
==========================
Buffer 3:
Size: 512.00MiB
Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
XLA Label: custom-call
Shape: u32[134217728]
==========================
Buffer 4:
Size: 512.00MiB
Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
XLA Label: custom-call
Shape: u32[134217728]
==========================
Buffer 5:
Size: 512.00MiB
Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
XLA Label: custom-call
Shape: u32[134217728]
==========================
...
Repro code:
import functools

import click
import flax
import flax.linen as nn
import flax.training.train_state
import jax
import jax.numpy as jnp
import optax


class Dropout(nn.Module):
    rate: float

    @nn.compact
    def __call__(self, inputs, rng):
        if self.rate == 0.0:
            return inputs
        if self.rate == 1.0:
            return jnp.zeros_like(inputs)
        keep_prob = 1.0 - self.rate
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


class Block(nn.Module):
    dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, input, rng):
        scale = 32  # We want large memory consumption without remat
        emb = nn.Dense(features=self.dim * scale)(input)
        emb = nn.relu(emb)
        emb = Dropout(rate=self.dropout_rate)(emb, rng)
        emb = nn.Dense(features=self.dim)(emb)
        return emb


class Model(nn.Module):
    dim: int
    dropout_rate: float
    num_layers: int
    remat: bool

    @nn.compact
    def __call__(self, input, rng):
        def _apply_block(block, block_input, rng):
            return block(block_input, rng)

        if self.remat:
            _apply_block = nn.checkpoint(
                _apply_block,
                policy=jax.checkpoint_policies.nothing_saveable,
                prevent_cse=True,
            )

        emb = input
        for _ in range(self.num_layers):
            rng, block_rng = jax.random.split(rng)
            block = Block(dim=self.dim, dropout_rate=self.dropout_rate)
            emb = _apply_block(block, emb, block_rng)
        return emb


def loss_fn(params, train_state, batch, rng):
    outputs = train_state.apply_fn(params, batch, rng)
    return jnp.mean(outputs * outputs)


@functools.partial(jax.jit, donate_argnames=("train_state",))
def train_step(train_state, batch, rng):
    grad_fn = jax.grad(loss_fn)
    grad = grad_fn(train_state.params, train_state, batch, rng)
    train_state = train_state.apply_gradients(grads=grad)
    return train_state


def make_batch(batch_size, dim):
    return jnp.zeros(shape=(batch_size, dim), dtype=jnp.float32)


@click.command()
@click.option("--dim", type=int, default=1024)
@click.option("--batch-size", type=int, default=8192)
@click.option("--dropout-rate", type=float, default=0.0)
@click.option("--num-layers", type=int, default=64)
@click.option("--remat", is_flag=True)
def main(
    dim: int,
    batch_size: int,
    dropout_rate: float,
    num_layers: int,
    remat: bool,
):
    model = Model(
        dim=dim, dropout_rate=dropout_rate, num_layers=num_layers, remat=remat
    )
    batch = make_batch(batch_size=batch_size, dim=dim)
    rng = jax.random.PRNGKey(0)
    params = model.init({"params": rng}, batch, rng)
    optimizer = optax.adam(learning_rate=1e-3)
    train_state = flax.training.train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer
    )
    train_state = train_step(train_state, batch, rng)


if __name__ == "__main__":
    main()
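As a side note, XLA's memory estimate for the compiled step can also be inspected without running it. A rough sketch reusing train_step, train_state, batch and rng from the script above (what memory_analysis() reports, and whether it is available at all, varies across jaxlib versions and backends):

# Lower and compile train_step ahead of time, then print XLA's memory statistics
# (argument/output/temp buffer sizes) for the resulting executable.
compiled = train_step.lower(train_state, batch, rng).compile()
print(compiled.memory_analysis())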
Thanks for the repro!
@hr0nix sorry that this slipped through the cracks. Thanks for the pings, everyone.
Can you check that this repros with jaxlib 0.4.20? IIRC there was one GPU-specific remat fix that happened recently, though I don't have a link to it at the moment. EDIT: https://github.com/openxla/xla/pull/6527
Thanks for the pointer!
Unfortunately, it looks like the problem is still present with jaxlib==0.4.20
Thanks for checking.
I think our next step is to try to repro on TPU, to see if it's GPU-specific. We can do that on our end.
Hey, any updates on this?
@mattjj @froystig Happy new year, gentlemen! Do you think 2024 is the year this bug finally gets fixed? ;-)
Ping!
Hmm, looks like using jax_default_prng_impl=rbg fixes this issue.
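For anyone else hitting this, a minimal sketch of switching the default PRNG implementation (set the config flag before creating any keys, or equivalently set the JAX_DEFAULT_PRNG_IMPL=rbg environment variable):

import jax

# Switch the default PRNG implementation from threefry to rbg.
jax.config.update("jax_default_prng_impl", "rbg")

key = jax.random.PRNGKey(0)  # keys created from here on use the rbg implementation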
Thanks, this is a useful additional bit of info. This is still in our queue, but we haven't dug in yet.
I understood your most recent comment to mean that you have a workaround. Is that right? At large scales, jax_default_prng_impl=rbg can be a good idea to try anyway, as it can drastically speed up compilation times.
Looks like it. Interestingly, it also seems to fix another rng-related issue: https://github.com/google/jax/issues/19893
Btw, can you elaborate a bit on how the rng implementation works when keys are sharded? E.g. does it require any additional communication?
On GPU, for a fixed key, I do not expect that sharded number generation under rbg would require communication. E.g. I expect the following to print False:
import jax
import jax.numpy as jnp

@jax.jit
def f(key, x):
    numbers = jax.random.uniform(key, x.shape)
    return x + numbers

key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24.), x_sharding)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
(and the same if we check for other collectives in the HLO.)
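For example, one way to scan the compiled HLO for several common collectives at once, assuming the f_exe from the snippet above (the names are the standard XLA collective ops):

# Check the compiled HLO text for each collective op by name.
hlo_text = f_exe.as_text()
collectives = ("collective-permute", "all-reduce", "all-gather", "all-to-all", "reduce-scatter")
print({name: name in hlo_text for name in collectives})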
Meanwhile I also expect the output sharding of f(key, x) to be, e.g., PositionalSharding([{GPU 0} {GPU 1}], shape=(2,)) when jax.devices() is a list of two GPUs.
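A quick way to confirm that on a multi-GPU host, assuming the f, key and x defined above (the exact repr of the sharding differs across JAX versions):

out = f(key, x)
print(out.sharding)                      # how the result is laid out across devices
jax.debug.visualize_array_sharding(out)  # ASCII diagram (requires the optional rich dependency)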
Your comment however asks "when keys are sharded." Do you mean that you are sharding a computation that vmaps a random number generation operation over a batch of keys (in the form of a sharded key array)? If so, then there's a current, unrelated issue to watch specifically regarding vmap of rbg over keys, covered by #19085. The workaround there is not to vmap number generation over keys, but instead to hoist the generation step: draw the entire batch of random numbers from a single key outside of the vmapped function, and pass that in.
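A minimal sketch of that hoisting workaround (per_example and the shapes are hypothetical, just to illustrate the pattern):

import jax
import jax.numpy as jnp

def per_example(x, noise):
    # Hypothetical per-example computation that consumes pre-drawn noise
    # instead of generating it from a per-example key inside vmap.
    return x + noise

key = jax.random.key(0)
xs = jnp.ones((8, 128))
# Draw the entire batch of random numbers from a single key, outside the vmapped function...
noise = jax.random.normal(key, xs.shape)
# ...and pass it in alongside the data.
ys = jax.vmap(per_example)(xs, noise)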
Description
I have a transformer model where each transformer block is rematerialized. The model is distributed over multiple devices using jit. Each transformer block has dropout enabled.
To prevent the rng implementation from inserting synchronization operations, I'm also enabling jax_threefry_partitionable, as suggested in the docs.
Problem is, jax_threefry_partitionable doesn't seem to play nicely with rematerialization. As soon as I enable dropout, I get a GPU OOM because JAX decides to preserve huge arrays containing an rng key per activation tensor component for each transformer block, despite them being rematerialized. It should be possible for JAX to reconstruct this key array from a single key during rematerialization, but it doesn't seem to do that.
I'm happy to provide a reproduction if you can confirm that this is unexpected behavior. If not, can you please suggest a workaround? Currently it doesn't seem possible to efficiently train large models with dropout.
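For context, a minimal sketch of enabling the flag (it can also be set via the JAX_THREEFRY_PARTITIONABLE environment variable):

import jax

# Enable the partitionable threefry implementation so that sharded random number
# generation does not force cross-device synchronization.
jax.config.update("jax_threefry_partitionable", True)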
A relevant discussion with OOM error message example here: https://github.com/google/flax/discussions/3090
What jax/jaxlib version are you using?
0.4.14
Which accelerator(s) are you using?
GPU
Additional system info
python3.10
NVIDIA GPU info