google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Slow training step occasionally due to slow graph flatten #4336

Open kriscao-cohere opened 1 week ago

kriscao-cohere commented 1 week ago

I'm using NNX for a toy transformer on Wikitext-103, and I'm observing that one in every ~100 steps takes much longer than the rest (on the order of 2 seconds vs 0.02 seconds). I managed to track down the culprit with a profile, and it seems that some NNX internal machinery in nnx.split is taking the bulk of the time:

[profiler screenshot: the slow step is dominated by NNX graph flattening inside nnx.split]

Is there anything NNX-related that could be causing this to take a long time?
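For reference, the slow steps show up with a simple per-step wall-clock timer around the training step, roughly like this sketch (`train_step`, `model`, `optimizer` and `batches` stand in for my actual setup):

```python
import time

import jax

# rough per-step timer; train_step / model / optimizer / batches are
# placeholders for the real training setup
step_times = []
for step, batch in enumerate(batches):
  t0 = time.perf_counter()
  out = train_step(model, optimizer, batch)  # assumed to return a device value
  jax.block_until_ready(out)                 # don't let async dispatch hide the cost
  dt = time.perf_counter() - t0
  step_times.append(dt)
  if dt > 10 * (sum(step_times) / len(step_times)):
    print(f'slow step {step}: {dt:.3f}s (typical ~{min(step_times):.3f}s)')
```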

cgarciae commented 5 days ago

Thanks for posting this @kriscao-cohere! We do have a global context that keeps track of graph references during jit and all other transforms. That might be the first place I would look.

https://github.com/google/flax/blob/e4dad9ca2453b37da77cba790e49b35f4492fde9/flax/nnx/graph.py#L740

Hopefully the update_context context manager is not messing anything up:

https://github.com/google/flax/blob/e4dad9ca2453b37da77cba790e49b35f4492fde9/flax/nnx/graph.py#L1046

I'll look into it. But if this is blocking you, consider using regular jax.jit as detailed in Performance Considerations for NNX.
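Roughly, the idea there is to call nnx.split once outside the loop and jit a pure function over (graphdef, state), so the NNX graph traversal is not paid on every step. A minimal sketch (MLP and the data iterator are placeholders, not the exact code from the guide):

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))  # toy model
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))

# split once, outside the hot loop
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step(graphdef, state, x, y):
  model, optimizer = nnx.merge(graphdef, state)

  def loss_fn(model):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)

  _, state = nnx.split((model, optimizer))  # traced, so cheap inside jit
  return state, loss

for x, y in batches:  # placeholder data iterator
  state, loss = train_step(graphdef, state, x, y)

nnx.update((model, optimizer), state)  # sync the final state back into the objects
```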

cgarciae commented 5 days ago

Just as a sanity check, I ran the simple training code below, printing nnx.graph.GRAPH_CONTEXT after each step, but all the context stacks are empty. Maybe the Python garbage collector is hitting a spike?

training code:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import nnx

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

def dataset(batch_size):
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    return x @ self.w.value + self.b.value

class Count(nnx.Variable):
  pass

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.count = Count(jnp.array(0))
    self.linear1 = Linear(din, dhidden, rngs=rngs)
    self.linear2 = Linear(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count.value += 1
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
  x, y = batch

  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  grads: nnx.State = nnx.grad(loss_fn)(model)
  optimizer.update(grads)

@nnx.jit
def test_step(model: MLP, batch):
  x, y = batch
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}

total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)
  print(nnx.graph.GRAPH_CONTEXT)

  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    print(f"step: {step}, loss: {logs['loss']}")

  if step >= total_steps - 1:
    break

print('times called:', model.count.value)

y_pred = model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
```
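If the garbage collector is the culprit, one way to check (just a sketch) is to time collections via gc.callbacks and see whether the spikes line up with the slow steps:

```python
import gc
import time

# time each GC collection; if the ~2s spikes coincide with these prints,
# the garbage collector is the likely cause
_gc_start = {}

def _gc_timer(phase, info):
  if phase == 'start':
    _gc_start[info['generation']] = time.perf_counter()
  else:  # 'stop'
    dt = time.perf_counter() - _gc_start.pop(info['generation'], time.perf_counter())
    if dt > 0.01:
      print(f"gc gen {info['generation']} took {dt * 1000:.1f} ms")

gc.callbacks.append(_gc_timer)
```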
kriscao-cohere commented 5 days ago

Hi @cgarciae, thanks for the update! And thanks for pointing me to the JAX-only pattern for using NNX; I tried it and it eliminated the wait time (and it also sped up my experiment loop a lot, though I am doing small-model experiments).

As for a repro, I only saw the issue with certain model architectures (mainly Transformers past a certain layer depth). I tried to reproduce the slow nnx.split by running nnx.split on its own 100 times, but saw no unexpected slowdowns. It was only when I did 100 model forward passes (even on dummy data) that I would see the occasional slow step caused by NNX graph traversal.
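Roughly, the repro attempt looked like this sketch (`model` stands in for my transformer, the inputs are dummy data); only the second loop showed the occasional multi-second iteration:

```python
import time

import jax
import jax.numpy as jnp
from flax import nnx

def worst_time(fn, n=100):
  # run fn n times and return the slowest iteration in seconds
  worst = 0.0
  for _ in range(n):
    t0 = time.perf_counter()
    fn()
    worst = max(worst, time.perf_counter() - t0)
  return worst

x = jnp.ones((8, 128))            # dummy batch; `model` is a placeholder transformer
fwd = nnx.jit(lambda m, x: m(x))
fwd(model, x)                     # warm up so compilation isn't counted

# splitting alone: consistently fast
print('worst split:  ', worst_time(lambda: nnx.split(model)))
# full forward passes: occasionally much slower due to NNX graph traversal
print('worst forward:', worst_time(lambda: jax.block_until_ready(fwd(model, x))))
```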