froystig opened this issue 1 year ago
@jjyyxx, do you have a minimal code example that reproduces this?
@froystig I could give it a try, but can a snippet containing Haiku be considered minimal? I suspect at least a medium-sized model is needed to reveal the difference, but I will try it anyway.
"""
conda create -p .conda/jax python=3.9 numpy scipy ipykernel
conda activate .conda/jax
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install optax dm-haiku
# Run on Single GPU
CUDA_VISIBLE_DEVICES=0 python rng.py
# Run on Multi GPU
CUDA_VISIBLE_DEVICES=0,1 python rng.py
"""
from typing import NamedTuple
import jax, jax.numpy as jnp
import haiku as hk
import optax
class TrainState(NamedTuple):
params: hk.Params
opt_state: optax.OptState
class Batch(NamedTuple):
x: jax.Array
y: jax.Array
def model_fn(x: jax.Array) -> jax.Array:
def dropout(x: jax.Array) -> jax.Array:
# return x
return hk.dropout(hk.next_rng_key(), 0.1, x)
def attn(x: jax.Array) -> jax.Array:
residual = x
x = hk.MultiHeadAttention(8, 64, w_init=hk.initializers.VarianceScaling(1.0), model_size=256)(x, x, x)
x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
x = dropout(x)
x += residual
return x
def ff(x: jax.Array) -> jax.Array:
residual = x
x = hk.Linear(256 * 4)(x)
x = jax.nn.relu(x)
x = dropout(x)
x = hk.Linear(256)(x)
x = jax.nn.relu(x)
x += residual
return x
for _ in range(10):
x = attn(x)
x = ff(x)
x = hk.Linear(10)(x)
x = jnp.max(x, axis=-2)
return x
transformed = hk.transform(model_fn)
optimizer = optax.adam(1e-3)
@jax.jit
def init_fn(key: jax.random.KeyArray, batch: Batch) -> TrainState:
x, _ = batch
params = transformed.init(key, x)
opt_state = optimizer.init(params)
return TrainState(params, opt_state)
def loss_fn(params: hk.Params, key: jax.random.KeyArray, batch: Batch) -> jax.Array:
x, y = batch
logits = transformed.apply(params, key, x)
return optax.softmax_cross_entropy(logits=logits, labels=y).mean()
@jax.jit
def step_fn(key: jax.random.KeyArray, state: TrainState, batch: Batch) -> TrainState:
params, opt_state = state
grads = jax.grad(loss_fn)(params, key, batch)
updates, new_opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return TrainState(new_params, new_opt_state)
if __name__ == "__main__":
master_key = jax.random.PRNGKey(42)
init_key, step_key = jax.random.split(master_key)
batch_size = 256
batch = Batch(
x=jnp.zeros((batch_size, 64, 256)),
y=jnp.zeros((batch_size, 10)),
)
state = init_fn(init_key, batch)
sharding = jax.sharding.PositionalSharding(jax.devices())
batch = jax.device_put(batch, jax.tree_util.tree_map(lambda x: sharding.reshape(-1, *[1] * (x.ndim-1)), batch))
state = jax.device_put(state, sharding.replicate())
state = step_fn(step_key, state, batch)
# Time it
import time
start = time.perf_counter()
n = 100
for _ in range(n):
state = step_fn(step_key, state, batch)
end = time.perf_counter()
elapsed = end - start
per_step_ms = elapsed / n * 1000
print(f"Time per step: {per_step_ms:.1f} ms")
@froystig Here is a hopefully not-too-long code example (dispatch costs for small models obscure the difference). Compare `return x` with `return hk.dropout(hk.next_rng_key(), 0.1, x)` in `dropout`, then run single- and multi-GPU training to see the difference.
My local experiment on two RTX 3090 Ti shows:

|  | No dropout | With dropout |
|---|---|---|
| Single GPU | 66.4 ms | 73.0 ms |
| Multi GPU (2) | 45.8 ms | 119.6 ms |
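One optional sanity check, not part of the original script: right after the `device_put` calls one can print the resulting placements to confirm the batch is split and the state is replicated (this assumes the `batch` and `state` variables from the script above):

```python
# Hypothetical check appended after the device_put calls in the script above:
# the batch should be split along its leading axis, the train state replicated.
print("batch.x:", batch.x.sharding)
print("batch.y:", batch.y.sharding)
print("params leaf:", jax.tree_util.tree_leaves(state.params)[0].sharding)
```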
Thank you!
Try using `jax.lax.with_sharding_constraint` on the dropout?

Specifically, try this in `def dropout`: `return jax.lax.with_sharding_constraint(hk.dropout(...), some_sharding)`
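A minimal sketch of that suggestion applied to the repro's `dropout`, assuming the batch-axis `PositionalSharding` that the script already uses (the concrete sharding passed here is an assumption, not part of the suggestion):

```python
import jax
import haiku as hk

def dropout(x: jax.Array) -> jax.Array:
    out = hk.dropout(hk.next_rng_key(), 0.1, x)
    # Assumed sharding: split the leading (batch) axis across all visible
    # devices, matching how the batch itself is placed in the repro script.
    batch_sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(
        len(jax.devices()), *([1] * (x.ndim - 1))
    )
    return jax.lax.with_sharding_constraint(out, batch_sharding)
```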
@yashk2810 I am afraid it is not a direct sharding issue. I further simplified the `dropout` function to this:
```python
def dropout(x: jax.Array) -> jax.Array:
    def print_sharding(x: jax.Array, name: str):
        if not hk.running_init():
            jax.debug.inspect_array_sharding(x, callback=lambda sharding: print(name, sharding))

    print_sharding(x, "x before")
    keep = jax.random.bernoulli(jax.random.PRNGKey(42), 0.9, shape=x.shape)
    print_sharding(keep, "keep")
    x *= keep
    print_sharding(x, "x after")
    return x
```
In the two-GPU case, all calls to `dropout` print:

```
x before GSPMDSharding({devices=[2,1,1]0,1})
keep GSPMDSharding({devices=[2,1,1]0,1})
x after GSPMDSharding({devices=[2,1,1]0,1})
```
It should also be noted that even a constant PRNG key `jax.random.PRNGKey(42)` for `bernoulli` reproduces this. But of course the slowdown goes away with:

```python
with jax.ensure_compile_time_eval():
    keep = jax.random.bernoulli(jax.random.PRNGKey(42), 0.9, shape=x.shape)
    # The following line is optional
    keep = jax.lax.with_sharding_constraint(keep, jax.sharding.PositionalSharding(jax.devices()).reshape(2, 1, 1))
```
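For context, this is roughly how the workaround slots into the repro's `dropout` (my arrangement of the lines above, not verbatim what I ran; with the constant key the mask is identical every step, so this only isolates the slowdown rather than implementing real dropout):

```python
def dropout(x: jax.Array) -> jax.Array:
    # Evaluate the mask at trace time so it is baked into the computation as a
    # constant, instead of being generated (and resharded) on every step.
    with jax.ensure_compile_time_eval():
        keep = jax.random.bernoulli(jax.random.PRNGKey(42), 0.9, shape=x.shape)
    return x * keep
```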
Off-topic: could you leave a comment on #15734? The sharding attached to an array must be in a "broadcastable" form, which is especially counter-intuitive to use with PyTrees (see the sketch below).
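To illustrate what I mean (my reading of the current API, not an authoritative statement): a single `PositionalSharding` over the visible devices has to be reshaped per leaf so its rank matches that leaf before it can be used with `device_put` on a PyTree, which is the same dance the repro script does.

```python
import jax
import jax.numpy as jnp

# Hypothetical standalone example of the per-leaf reshape dance.
sharding = jax.sharding.PositionalSharding(jax.devices())
tree = {"x": jnp.zeros((8, 4, 2)), "y": jnp.zeros((8, 3))}
leaf_shardings = jax.tree_util.tree_map(
    lambda leaf: sharding.reshape(-1, *([1] * (leaf.ndim - 1))), tree
)
tree = jax.device_put(tree, leaf_shardings)
```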
Discussed in https://github.com/google/jax/discussions/15783