jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax_threefry_partitionable + rematerialization doesn't seem to be working together in distributed training #17982

Open hr0nix opened 1 year ago

hr0nix commented 1 year ago

Description

I have a transformer model where each transformer block is rematerialized. The model is distributed over multiple devices using jit. Each transformer block has dropout enabled.

To prevent the RNG implementation from inserting synchronization operations, I'm also enabling jax_threefry_partitionable, as suggested in the docs.
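For reference, the flag is set through `jax.config` before any random numbers are drawn. This is a minimal sketch that assumes a JAX release where `jax_threefry_partitionable` is still a user-configurable flag (newer releases may already enable it by default). With the flag on, a draw is still a pure function of its key and shape.

```python
import jax

# Opt in to the shardable threefry lowering. Set this at startup, before
# generating any random numbers, so every draw uses the same scheme.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.key(0)
# Same key and same shape give identical values: the draw is deterministic.
a = jax.random.uniform(key, (4, 8))
b = jax.random.uniform(key, (4, 8))
```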

The problem is that jax_threefry_partitionable doesn't seem to play nicely with rematerialization. As soon as I enable dropout, I get a GPU OOM, because JAX decides to preserve huge arrays holding RNG key data for every component of each transformer block's activation tensor, even though those blocks are rematerialized. It should be possible for JAX to reconstruct these arrays from a single key during rematerialization, but it doesn't seem to do that.
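The reason regeneration should be possible is that a dropout mask is a pure function of its key and shape. The following minimal sketch is plain JAX, not the reporter's Flax model; the names `dropout` and `block` and the toy shapes are illustrative. It wraps such a block in `jax.checkpoint` with the `nothing_saveable` policy, which asks JAX to keep only the block's inputs and recompute everything else, including the random bits, during the backward pass. It only checks semantics (the gradients match the non-rematerialized version) and says nothing about the memory behavior reported in this issue.

```python
import jax
import jax.numpy as jnp

def dropout(x, key, rate):
    # The mask depends only on (key, shape), so recomputing it from the same
    # key during the backward pass reproduces it exactly.
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

def block(w, x, key):
    # Toy stand-in for a transformer block: matmul, dropout, nonlinearity.
    return jnp.tanh(dropout(x @ w, key, rate=0.1))

# Save nothing but the inputs (w, x, key); recompute the rest on the backward pass.
remat_block = jax.checkpoint(block, policy=jax.checkpoint_policies.nothing_saveable)

w = 0.1 * jnp.ones((8, 8))
x = jnp.ones((4, 8))
key = jax.random.key(0)

g_plain = jax.grad(lambda w: jnp.sum(block(w, x, key) ** 2))(w)
g_remat = jax.grad(lambda w: jnp.sum(remat_block(w, x, key) ** 2))(w)
```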

I'm happy to provide a reproduction if you can confirm that this is unexpected behavior. If it isn't, can you please suggest a workaround? Currently it doesn't seem possible to efficiently train large models with dropout.

A relevant discussion, with an example of the OOM error message, is here: https://github.com/google/flax/discussions/3090

What jax/jaxlib version are you using?

0.4.14

Which accelerator(s) are you using?

GPU

Additional system info

python3.10

NVIDIA GPU info

Fri Oct  6 14:22:06 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8A:00.0 Off |                    0 |
| N/A   35C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:8B:00.0 Off |                    0 |
| N/A   31C    P0              71W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:8C:00.0 Off |                    0 |
| N/A   31C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   35C    P0              74W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:9C:00.0 Off |                    0 |
| N/A   36C    P0              76W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:9D:00.0 Off |                    0 |
| N/A   31C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:9E:00.0 Off |                    0 |
| N/A   31C    P0              71W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:9F:00.0 Off |                    0 |
| N/A   35C    P0              73W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
Karina1997 commented 1 year ago

I'm struggling with the same issue.

minotru commented 1 year ago

Also relevant for me, would be great to have it solved.

hr0nix commented 1 year ago

@jakevdp Hey, sorry for mentioning you directly, but this issue hasn't received any attention for several weeks. Can someone from the jax team please take a look? Thanks!

hr0nix commented 1 year ago

@froystig Hey, sorry for mentioning you directly, but can someone take a look at this issue? It's a big blocker for me.

hr0nix commented 1 year ago

I've made a repro for this bug. It turns out it has nothing to do with jax_threefry_partitionable: it reproduces perfectly well without it.

The repro was made for an A100 80GB, so tensor shapes might need to be adjusted for a GPU with a different amount of memory.

- ./repro.py fails, because without rematerialization it needs ~122.75 GB of GPU RAM.
- ./repro.py --remat works perfectly fine, because with remat it needs just 63 GB of GPU RAM.
- ./repro.py --remat --dropout-rate 0.1 OOMs again, requiring ~118 GB of GPU RAM.

From looking at the peak buffers, it becomes clear that the dropout mask is not being rematerialized: tensors corresponding to the full dropout masks of different layers are occupying memory.

Peak buffers:
        Buffer 1:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_5._apply_block/Block_5/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================
        Buffer 2:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 3:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 4:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 5:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================
...
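The buffer sizes above line up with the repro's default shapes. Below is a back-of-the-envelope check; the variable names are ad hoc, and the inputs are the `repro.py` defaults, the hard-coded `scale = 32` in `Block`, and the `u32[134217728]` shape from the dump. Block_4 and Block_3 each appear twice in the list, which is consistent with `threefry2x32` emitting its output as two u32 halves (an interpretation, not something the dump states). Together, the two halves hold one 32-bit random word per activation element, which is as much memory as the f32 activation itself.

```python
MiB, GiB = 2**20, 2**30

# repro.py defaults, plus the hard-coded `scale = 32` in Block.
batch_size, dim, scale, num_layers = 8192, 1024, 32, 64
hidden = dim * scale

mask_elems = batch_size * hidden     # elements in one block's dropout mask
buffer_elems = 134217728             # from u32[134217728] in the dump
buffer_bytes = buffer_elems * 4      # uint32 = 4 bytes

per_block_bits = 2 * buffer_bytes    # two threefry2x32 buffers per block
activation_bytes = mask_elems * 4    # the f32 activation being masked

print(f"one buffer:        {buffer_bytes / MiB:.0f} MiB")
print(f"random bits/block: {per_block_bits / GiB:.0f} GiB")
print(f"activation/block:  {activation_bytes / GiB:.0f} GiB")
print(f"all {num_layers} blocks:    {num_layers * per_block_bits / GiB:.0f} GiB")
```

This prints 512 MiB per buffer, 1 GiB of random bits per block, and 64 GiB if all 64 blocks kept their bits alive at once. That is a rough upper estimate, not a measurement, but it is the right order of magnitude for the jump from ~63 GB to ~118 GB reported above.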

Repro code:

import functools

import click
import flax
import flax.linen as nn
import flax.training.train_state
import jax
import jax.numpy as jnp
import optax

class Dropout(nn.Module):
    rate: float

    @nn.compact
    def __call__(self, inputs, rng):
        if self.rate == 0.0:
            return inputs

        if self.rate == 1.0:
            return jnp.zeros_like(inputs)

        keep_prob = 1.0 - self.rate
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

class Block(nn.Module):
    dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, input, rng):
        scale = 32  # We want large memory consumption without remat
        emb = nn.Dense(features=self.dim * scale)(input)
        emb = nn.relu(emb)
        emb = Dropout(rate=self.dropout_rate)(emb, rng)
        emb = nn.Dense(features=self.dim)(emb)
        return emb

class Model(nn.Module):
    dim: int
    dropout_rate: float
    num_layers: int
    remat: bool

    @nn.compact
    def __call__(self, input, rng):
        def _apply_block(block, block_input, rng):
            return block(block_input, rng)

        if self.remat:
            _apply_block = nn.checkpoint(
                _apply_block,
                policy=jax.checkpoint_policies.nothing_saveable,
                prevent_cse=True,
            )

        emb = input
        for _ in range(self.num_layers):
            rng, block_rng = jax.random.split(rng)
            block = Block(dim=self.dim, dropout_rate=self.dropout_rate)
            emb = _apply_block(block, emb, block_rng)

        return emb

def loss_fn(params, train_state, batch, rng):
    outputs = train_state.apply_fn(params, batch, rng)
    return jnp.mean(outputs * outputs)

@functools.partial(jax.jit, donate_argnames=("train_state",))
def train_step(train_state, batch, rng):
    grad_fn = jax.grad(loss_fn)
    grad = grad_fn(train_state.params, train_state, batch, rng)
    train_state = train_state.apply_gradients(grads=grad)
    return train_state

def make_batch(batch_size, dim):
    return jnp.zeros(shape=(batch_size, dim), dtype=jnp.float32)

@click.command()
@click.option("--dim", type=int, default=1024)
@click.option("--batch-size", type=int, default=8192)
@click.option("--dropout-rate", type=float, default=0.0)
@click.option("--num-layers", type=int, default=64)
@click.option("--remat", is_flag=True)
def main(
    dim: int,
    batch_size: int,
    dropout_rate: float,
    num_layers: int,
    remat: bool,
):
    model = Model(
        dim=dim, dropout_rate=dropout_rate, num_layers=num_layers, remat=remat
    )
    batch = make_batch(batch_size=batch_size, dim=dim)
    rng = jax.random.PRNGKey(0)
    params = model.init({"params": rng}, batch, rng)
    optimizer = optax.adam(learning_rate=1e-3)
    train_state = flax.training.train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer
    )
    train_state = train_step(train_state, batch, rng)

if __name__ == "__main__":
    main()
froystig commented 1 year ago

Thanks for the repro!

mattjj commented 1 year ago

@hr0nix sorry that this slipped through the cracks. Thanks for the pings, everyone.

Can you check that this repros with jaxlib 0.4.20? IIRC there was one GPU-specific remat fix that happened recently, though I don't have a link to it at the moment. EDIT: https://github.com/openxla/xla/pull/6527

hr0nix commented 1 year ago

Thanks for the pointer!

Unfortunately, it looks like the problem is still present with jaxlib==0.4.20

mattjj commented 1 year ago

Thanks for checking.

I think our next step is to try to repro on TPU, to see if it's GPU-specific. We can do that on our end.

hr0nix commented 1 year ago

Hey, any updates on this?

hr0nix commented 11 months ago

@mattjj @froystig Happy new year, gentlemen! Do you think 2024 is the year this bug finally gets fixed? ;-)

hr0nix commented 9 months ago

Ping!

hr0nix commented 9 months ago

Hmm, looks like using jax_default_prng_impl=rbg fixes this issue.

froystig commented 9 months ago

> Hmm, looks like using jax_default_prng_impl=rbg fixes this issue.

Thanks, this is a useful additional bit of info. This is still in our queue, but we haven't dug in yet.

I understood your most recent comment to mean that you have a workaround. Is that right? At large scales, jax_default_prng_impl=rbg can be a good idea to try anyway, as it can drastically speed up compilation times.
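For anyone trying the same workaround: the implementation can be chosen globally with the `jax_default_prng_impl` flag, or per key through the `impl` argument of `jax.random.key`. Below is a minimal sketch. Note that rbg produces different random streams than threefry, and its raw key data has shape `(4,)` instead of `(2,)`, so code that hard-codes the threefry key shape may need adjusting.

```python
import jax

# Option 1: switch the default implementation for the whole program.
# Do this at startup, before any keys are created.
jax.config.update("jax_default_prng_impl", "rbg")
key_global = jax.random.key(0)

# Option 2: choose the implementation for one key explicitly.
key_local = jax.random.key(0, impl="rbg")

# rbg key data is four uint32 words (threefry2x32 uses two).
print(jax.random.key_data(key_global).shape)
print(jax.random.key_data(key_local).shape)

# Downstream sampling code is unchanged.
dropout_mask = jax.random.bernoulli(key_local, p=0.9, shape=(4, 8))
```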

hr0nix commented 8 months ago

> I understood your most recent comment to mean that you have a workaround. Is that right?

Looks like it. Interestingly, it also seems to fix another rng-related issue: https://github.com/google/jax/issues/19893

Btw, can you elaborate a bit on how the RNG implementation works when keys are sharded? For example, does it require any additional communication?

froystig commented 8 months ago

On GPU, for a fixed key, I do not expect that sharded number generation under rbg would require communication. E.g. I expect the following to print False:

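import jax
import jax.numpy as jnp

# Select the rbg PRNG implementation, the scenario this comment is about.
jax.config.update('jax_default_prng_impl', 'rbg')

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
# Shard x across all visible devices (its length must divide evenly among them).
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24.), x_sharding)

# Compile, then look for a cross-device collective in the optimized HLO.
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())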

(and the same if we check for other collectives in the HLO.)

Meanwhile I also expect the output sharding of f(key, x) to be, e.g.:

PositionalSharding([{GPU 0} {GPU 1}], shape=(2,))

when jax.devices() is a list of two GPUs.
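To extend the check to the other collectives, one option is to scan the compiled HLO text for each collective opcode. This minimal sketch uses `NamedSharding` over a one-axis `Mesh` in place of `PositionalSharding`, which newer JAX releases deprecate. The helper `find_collectives` and its opcode list are illustrative and not a JAX API. On a single device the check trivially finds nothing; the interesting case is running it on several GPUs.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Opcodes of XLA cross-device collectives (async variants such as
# "all-gather-start" are matched by substring too).
COLLECTIVES = ("all-reduce", "all-gather", "all-to-all",
               "reduce-scatter", "collective-permute")

def find_collectives(hlo_text):
    """Return the collective opcodes that appear in the HLO text."""
    return [op for op in COLLECTIVES if op in hlo_text]

@jax.jit
def f(key, x):
    return x + jax.random.uniform(key, x.shape)

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
# Length 24 * n always divides evenly across the n devices.
x = jax.device_put(jnp.arange(24.0 * n), NamedSharding(mesh, P("x")))
key = jax.random.key(42, impl="rbg")

f_exe = f.lower(key, x).compile()
print("Collectives:", find_collectives(f_exe.as_text()))
print("Output sharding:", f(key, x).sharding)
```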

Your comment, however, asks about "when keys are sharded." Do you mean that you are sharding a computation that vmaps a random number generation operation over a batch of keys (in the form of a sharded key array)? If so, there is a separate open issue to watch specifically regarding vmap of rbg over keys, covered by #19085. The workaround there is not to vmap number generation over keys, but instead to hoist the generation step: draw the entire batch of random numbers from a single key outside of the vmapped function, and pass that in.
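The hoisting workaround can be sketched as follows. The function names and shapes are illustrative and do not come from #19085. Rather than vmapping a sampler over a batch of keys, draw the whole batch of numbers from a single key up front and vmap only over the resulting array. In real code, `noisy_given` would be the per-example computation. The two patterns produce different random values and are interchangeable only in distribution.

```python
import jax
import jax.numpy as jnp

batch, features = 8, 16
key = jax.random.key(0, impl="rbg")
xs = jnp.ones((batch, features))

# Pattern the comment suggests avoiding (shown for contrast, not run):
#   keys = jax.random.split(key, batch)
#   jax.vmap(lambda k, x: x + jax.random.uniform(k, x.shape))(keys, xs)

def noisy_given(noise, x):
    # Per-example computation that receives pre-drawn random numbers.
    return x + noise

# Hoisted: one draw for the whole batch from a single key, outside the vmap.
noise = jax.random.uniform(key, (batch, features))
out_hoisted = jax.vmap(noisy_given)(noise, xs)
```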