jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

RNG slows down data parallel training #15895

Open froystig opened 1 year ago

froystig commented 1 year ago

Discussed in https://github.com/google/jax/discussions/15783

Originally posted by **jjyyxx** April 27, 2023

I was working with a transformer model in JAX and Haiku, and found that dropout greatly slows down data parallel training. The main training step looks like

```python
self._train_state, batch_scalars = self._train_step(train_key, self._train_state, batch)
```

where

- the sharding is created with `sharding = jax.sharding.PositionalSharding(jax.devices())`, containing GPU:0 and GPU:1
- `train_key` is a `PRNGKeyArray`, not sharded
- `self._train_state` is a PyTree of params and opt_states, replicated with `jax.device_put(train_state, sharding.replicate())`
- `batch` is a PyTree of data and labels, sharded with `jax.device_put(batch, sharding)`

Every operation in this model (except the final loss reduction) is independent across samples in the batch, so training should be trivially data parallel. Without `x = hk.dropout(hk.next_rng_key(), self.dropout, x)` (which boils down to a `jax.random.split` and a `jax.random.bernoulli`), everything works well (single device: 4.2 it/s, two devices: 7.5 it/s). But when dropout is enabled (called 20 times), I get:

- Single device: 3.75 it/s
- Two devices: 1.9 it/s
- Two devices with `jax.config.update('jax_threefry_partitionable', True)`: 5.32 it/s (I was aware of the [documentation](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers))

which is far from what I expected. Did I miss something? Could this performance be optimized?
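For reference, a minimal, self-contained sketch of the setup described above (the shapes, the toy `train_step`, and the dict of params are placeholders, not the actual model):

```python
import jax, jax.numpy as jnp

# The config knob mentioned above: switch to the partitionable threefry RNG.
jax.config.update('jax_threefry_partitionable', True)

devices = jax.devices()
sharding = jax.sharding.PositionalSharding(devices)

# Toy stand-ins for the real train state and batch.
params = {'w': jnp.ones((256, 256))}
batch = jnp.ones((256, 64, 256))

params = jax.device_put(params, sharding.replicate())                # replicated
batch = jax.device_put(batch, sharding.reshape(len(devices), 1, 1))  # sharded along the batch axis

train_key = jax.random.PRNGKey(0)  # a single, unsharded PRNG key

@jax.jit
def train_step(key, params, batch):
    # Dropout-style masking: the RNG op whose behaviour under sharding is in question.
    keep = jax.random.bernoulli(key, 0.9, shape=batch.shape)
    return params, (batch * keep).mean()

params, batch_scalars = train_step(train_key, params, batch)
```
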
froystig commented 1 year ago

@jjyyxx, do you have a minimal code example that reproduces this?

jjyyxx commented 1 year ago

@froystig I can give it a try. But can a snippet that uses Haiku be considered minimal? I suspect only a medium-sized model would reveal the difference, but I will try anyway.

jjyyxx commented 1 year ago
"""
conda create -p .conda/jax python=3.9 numpy scipy ipykernel
conda activate .conda/jax
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install optax dm-haiku

# Run on Single GPU
CUDA_VISIBLE_DEVICES=0 python rng.py
# Run on Multi GPU
CUDA_VISIBLE_DEVICES=0,1 python rng.py
"""

from typing import NamedTuple

import jax, jax.numpy as jnp
import haiku as hk
import optax

class TrainState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState

class Batch(NamedTuple):
    x: jax.Array
    y: jax.Array

def model_fn(x: jax.Array) -> jax.Array:
    def dropout(x: jax.Array) -> jax.Array:
        # return x
        return hk.dropout(hk.next_rng_key(), 0.1, x)

    def attn(x: jax.Array) -> jax.Array:
        residual = x
        x = hk.MultiHeadAttention(8, 64, w_init=hk.initializers.VarianceScaling(1.0), model_size=256)(x, x, x)
        x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
        x = dropout(x)
        x += residual
        return x

    def ff(x: jax.Array) -> jax.Array:
        residual = x
        x = hk.Linear(256 * 4)(x)
        x = jax.nn.relu(x)
        x = dropout(x)
        x = hk.Linear(256)(x)
        x = jax.nn.relu(x)
        x += residual
        return x

    for _ in range(10):
        x = attn(x)
        x = ff(x)

    x = hk.Linear(10)(x)
    x = jnp.max(x, axis=-2)
    return x

transformed = hk.transform(model_fn)
optimizer = optax.adam(1e-3)

@jax.jit
def init_fn(key: jax.random.KeyArray, batch: Batch) -> TrainState:
    x, _ = batch
    params = transformed.init(key, x)
    opt_state = optimizer.init(params)
    return TrainState(params, opt_state)

def loss_fn(params: hk.Params, key: jax.random.KeyArray, batch: Batch) -> jax.Array:
    x, y = batch
    logits = transformed.apply(params, key, x)
    return optax.softmax_cross_entropy(logits=logits, labels=y).mean()

@jax.jit
def step_fn(key: jax.random.KeyArray, state: TrainState, batch: Batch) -> TrainState:
    params, opt_state = state
    grads = jax.grad(loss_fn)(params, key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return TrainState(new_params, new_opt_state)

if __name__ == "__main__":
    master_key = jax.random.PRNGKey(42)
    init_key, step_key = jax.random.split(master_key)

    batch_size = 256
    batch = Batch(
        x=jnp.zeros((batch_size, 64, 256)),
        y=jnp.zeros((batch_size, 10)),
    )
    state = init_fn(init_key, batch)

    sharding = jax.sharding.PositionalSharding(jax.devices())
    batch = jax.device_put(batch, jax.tree_util.tree_map(lambda x: sharding.reshape(-1, *[1] * (x.ndim-1)), batch))
    state = jax.device_put(state, sharding.replicate())

    state = step_fn(step_key, state, batch)

    # Time it
    import time
    start = time.perf_counter()
    n = 100
    for _ in range(n):
        state = step_fn(step_key, state, batch)
    end = time.perf_counter()
    elapsed = end - start
    per_step_ms = elapsed / n * 1000
    print(f"Time per step: {per_step_ms:.1f} ms")

@froystig Here is a hopefully not-too-long code example (dispatch cost obscures the difference for small models). Compare `return x` with `return hk.dropout(hk.next_rng_key(), 0.1, x)` in `dropout`, then run single- and multi-GPU training to see the difference.

My local experiment on two RTX 3090 Ti cards shows:

|               | No dropout | With dropout |
| ------------- | ---------- | ------------ |
| Single GPU    | 66.4 ms    | 73.0 ms      |
| Multi GPU (2) | 45.8 ms    | 119.6 ms     |

froystig commented 1 year ago

Thank you!

yashk2810 commented 1 year ago

Try using `jax.lax.with_sharding_constraint` on the dropout output?

Specifically, try this in `def dropout`: `return jax.lax.with_sharding_constraint(hk.dropout(...), some_sharding)`
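
For instance, a minimal sketch of that suggestion applied to the `dropout` helper in the repro script above (it assumes the script's imports; the batch-axis sharding used for `some_sharding` is an assumption, mirroring how `batch` is sharded in the script):

```python
    def dropout(x: jax.Array) -> jax.Array:
        out = hk.dropout(hk.next_rng_key(), 0.1, x)
        # Constrain the dropout output to the same batch-axis sharding as the inputs.
        some_sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(-1, *[1] * (x.ndim - 1))
        return jax.lax.with_sharding_constraint(out, some_sharding)
```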

jjyyxx commented 1 year ago

@yashk2810 I am afraid it's not a direct sharding issue. I further simplified the dropout function to this:

```python
    def dropout(x: jax.Array) -> jax.Array:
        def print_sharding(x: jax.Array, name: str):
            if not hk.running_init():
                jax.debug.inspect_array_sharding(x, callback=lambda sharding: print(name, sharding))

        print_sharding(x, "x before")

        keep = jax.random.bernoulli(jax.random.PRNGKey(42), 0.9, shape=x.shape)
        print_sharding(keep, "keep")

        x *= keep
        print_sharding(x, "x after")

        return x
```

For the two-GPU case, all calls to `dropout` print

```
x before GSPMDSharding({devices=[2,1,1]0,1})
keep GSPMDSharding({devices=[2,1,1]0,1})
x after GSPMDSharding({devices=[2,1,1]0,1})
```

It should also be noted that even a constant PRNG key, `jax.random.PRNGKey(42)`, passed to `bernoulli` reproduces this. But of course the slowdown goes away with

```python
        with jax.ensure_compile_time_eval():
            keep = jax.random.bernoulli(jax.random.PRNGKey(42), 0.9, shape=x.shape)
        # The following line is optional
        keep = jax.lax.with_sharding_constraint(keep, jax.sharding.PositionalSharding(jax.devices()).reshape(2, 1, 1))
```

Off-topic: could you leave a comment on #15734? The sharding attached to an array must be in a "broadcastable" form, which is especially counter-intuitive to use with a PyTree.
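
(For context, that "broadcastable" requirement is why the repro script above reshapes the sharding per leaf instead of passing `sharding` directly for the whole PyTree; a small excerpt of that pattern, reusing the script's `sharding` and `batch`:)

```python
# Each leaf of `batch` has a different rank, so the sharding is reshaped per
# leaf to broadcast against that leaf's shape (same pattern as in the script).
shardings = jax.tree_util.tree_map(
    lambda leaf: sharding.reshape(-1, *[1] * (leaf.ndim - 1)), batch)
batch = jax.device_put(batch, shardings)
```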