google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Significant performance difference of NNX relative to equinox #4045

Open jlperla opened 2 days ago

jlperla commented 2 days ago

I decided to compare NNX and Equinox for performance and am seeing significant differences (NNX is roughly 3 times slower). It could be that I wrote a poor MLP implementation or made a colossal profiling mistake.

My apologies if the benchmarking itself is flawed or the MLP implementation is incorrect in some way. But if it is the latter, it shows that a documented MLP implementation for NNX to copy/paste might help.

System information

Problem you have encountered:

The performance of my test suite on my CPU is:

Time taken NNX: 0.00055 seconds
Time taken EQX: 0.00019 seconds

And on the Colab T4 GPU runtime:

Time taken NNX: 0.00220 seconds
Time taken EQX: 0.00066 seconds

Steps to reproduce:

Test Suite:

import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time

class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        width: int,
        depth: int,
        activation: tp.Callable,
        rngs: rnglib.Rngs,
        use_bias: bool = True,
        use_final_bias: bool = True,
        final_activation: tp.Optional[tp.Callable] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.use_final_bias = use_final_bias
        self.activation = activation
        self.final_activation = final_activation
        assert depth > 0  # skipping specialization of no hidden layers

        self.layers = []
        self.layers.append(
            nnx.Linear(
                in_features,
                width,
                use_bias=use_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
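        for i in range(self.depth - 1):
            # Apply the activation after every hidden Linear (including the
            # first one), matching the layer ordering of eqx.nn.MLP.
            self.layers.append(self.activation)
            self.layers.append(
                nnx.Linear(
                    width,
                    width,
                    use_bias=self.use_bias,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    rngs=rngs,
                )
            )
        self.layers.append(self.activation)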
        self.layers.append(
            nnx.Linear(
                width,
                out_features,
                use_bias=self.use_final_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        if self.final_activation is not None:
            self.layers.append(self.final_activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x

if __name__ == "__main__":
    rngs = nnx.Rngs(0)

    @nnx.jit
    def my_test(batch, model):
        @nnx.jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
        return loss_val
    n_in = 64
    n_out = 1
    depth = 3
    width = 128
    activation = nnx.relu
    model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)
    my_batch = jax.random.normal(rngs(), (20, n_in))

    # Time NNX
    out = my_test(my_batch, model).block_until_ready()
    start_time = time.time()
    out = my_test(my_batch, model).block_until_ready()
    end_time = time.time()
    print(f"Time taken NNX: {end_time - start_time:.5f} seconds")

    @eqx.filter_jit
    def my_test_eqx(batch, model):
        @eqx.filter_jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
        return loss_val    
    equinox_model = eqx.nn.MLP(n_in, n_out, width_size=width, depth=depth, activation=activation, key=rngs())

    # Time Equinox
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    start_time = time.time()
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    end_time = time.time()
    print(f"Time taken EQX: {end_time - start_time:.5f} seconds")    

On Colab you need to run !pip install equinox first.

jlperla commented 2 days ago

To add to this: the performance of Linen seems similar to NNX, although I am even less clear on how to profile it there. Here is my implementation:

import typing as tp
import jax
import jax.numpy as jnp
import flax.linen as linen
from flax.core import freeze, unfreeze
from flax.training import train_state
from flax.typing import Dtype, PrecisionLike
import optax
import time

class MLPLinen(linen.Module):
    in_features: int
    out_features: int
    width: int
    depth: int
    activation: tp.Callable
    use_bias: bool = True
    use_final_bias: bool = True
    final_activation: tp.Optional[tp.Callable] = None
    dtype: tp.Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    precision: PrecisionLike = None

    @linen.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        x = linen.Dense(
            self.width, 
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )(x)
        for _ in range(self.depth - 1):
            x = self.activation(x)
            x = linen.Dense(
                self.width, 
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
            )(x)
        # Activation after the last hidden layer too, matching eqx.nn.MLP.
        x = self.activation(x)
        x = linen.Dense(
            self.out_features, 
            use_bias=self.use_final_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return x

def create_train_state_linen(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, model.in_features]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def compute_loss_linen(params, batch, model_apply_fn):
    logits = model_apply_fn({'params': params}, batch)
    loss = jnp.mean(logits)
    return loss

@jax.jit
def train_step_linen(state, batch):
    grad_fn = jax.value_and_grad(compute_loss_linen)
    loss, grads = grad_fn(state.params, batch, state.apply_fn)
    state = state.apply_gradients(grads=grads)
    return state, loss

if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    n_in = 64
    n_out = 1
    depth = 3
    width = 128
    activation = linen.relu
    model = MLPLinen(n_in, n_out, width=width, depth=depth, activation=activation)

    state = create_train_state_linen(rng, model, learning_rate=0.001)
    my_batch = jax.random.normal(rng, (20, n_in))

    # Time Linen
    state, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)
    start_time = time.time()
    state, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)
    end_time = time.time()
    print(f"Time taken Linen: {end_time - start_time:.5f} seconds")

cgarciae commented 2 days ago

Hey @jlperla, can you use timeit or similar to report the results? A single step, especially the first one, which involves compilation, is not very meaningful.
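Following the timeit suggestion, here is a minimal sketch of a benchmark helper. The name `benchmark` and its defaults are hypothetical; the parts that matter are a warm-up call so compilation is excluded, many calls per measurement, and `jax.block_until_ready` so JAX's asynchronous dispatch does not hide the device time.

```python
import timeit

import jax
import jax.numpy as jnp


def benchmark(fn, *args, number: int = 100, repeat: int = 5) -> float:
    """Hypothetical helper: best average seconds per call of fn(*args).

    Runs fn once to trigger tracing/compilation, then times `number` calls
    per run for `repeat` runs, blocking on each result.
    """
    jax.block_until_ready(fn(*args))  # warm-up: compile outside the timed region
    times = timeit.repeat(
        lambda: jax.block_until_ready(fn(*args)),
        number=number,
        repeat=repeat,
    )
    return min(times) / number  # best run is the least noisy estimate


@jax.jit
def f(x):
    return jnp.mean(jnp.tanh(x @ x.T))


x = jnp.ones((128, 128))
print(f"{benchmark(f, x):.6f} s per call")
```

With the code from the issue, something like `benchmark(my_test, my_batch, model)` and `benchmark(my_test_eqx, my_batch, equinox_model)` should time both versions under the same conditions.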

That said, this is what I would expect:

ASEM000 commented 1 day ago

@jlperla Maybe useful to note here: for small MLPs you will likely be in the overhead regime. To overcome the framework overhead (in NNX or Equinox) you can use the nnx.{split,merge} or equinox.{partition,combine} pattern with non-lifted JAX transforms.

cgarciae commented 19 hours ago

@ASEM000 correct. Ideally we will document how to overcome the overhead problem in the near future.

jlperla commented 19 hours ago

@cgarciae @ASEM000 Absolutely. But isn't the issue the relative overhead of NNX vs. Equinox for the same pattern? I find timeit hard to use, but I made sure things were compiled and retried multiple times.

Why would the Equinox code be so much faster than NNX (which seems roughly similar to Flax Linen)? What overhead would be so much more significant there, using the same coding pattern? If you look at my code, I am isolating a single "value and grad" call, with no optimizer overhead or training loop, and precompiling it before timing.

So either 0) my two sets of code look like they do the same thing, but they really don't; 1) I implemented the MLP poorly in NNX, which is VERY likely, and the Equinox one is done correctly; or 2) there is some overhead in the filtering process that is significantly more expensive in NNX than in Equinox. Maybe a manual split and combine (which can be done in both) would make it disappear.

cgarciae commented 16 hours ago

@jlperla I do imagine the NNX overhead being greater than the Equinox overhead, as we do more bookkeeping and it's not optimized. If performance is critical you should just train using split / merge. Here is a modified version comparing both NNX and Equinox using low-overhead versions:

import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time

class MLP(nnx.Module):
  def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    width: int,
    depth: int,
    activation: tp.Callable,
    rngs: rnglib.Rngs,
    use_bias: bool = True,
    use_final_bias: bool = True,
    final_activation: tp.Optional[tp.Callable] = None,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    precision: PrecisionLike = None,
  ):
    self.in_features = in_features
    self.out_features = out_features
    self.width = width
    self.depth = depth
    self.use_bias = use_bias
    self.use_final_bias = use_final_bias
    self.activation = activation
    self.final_activation = final_activation
    assert depth > 0  # skipping specialization of no hidden layers

    self.layers = []
    self.layers.append(
      nnx.Linear(
        in_features,
        width,
        use_bias=use_bias,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        rngs=rngs,
      )
    )
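    for i in range(self.depth - 1):
      # Apply the activation after every hidden Linear (including the
      # first one), matching the layer ordering of eqx.nn.MLP.
      self.layers.append(self.activation)
      self.layers.append(
        nnx.Linear(
          width,
          width,
          use_bias=self.use_bias,
          dtype=dtype,
          param_dtype=param_dtype,
          precision=precision,
          rngs=rngs,
        )
      )
    self.layers.append(self.activation)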
    self.layers.append(
      nnx.Linear(
        width,
        out_features,
        use_bias=self.use_final_bias,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        rngs=rngs,
      )
    )
    if self.final_activation is not None:
      self.layers.append(self.final_activation)

  def __call__(self, x: jax.Array) -> jax.Array:
    for layer in self.layers:
      x = layer(x)
    return x

if __name__ == '__main__':
  rngs = nnx.Rngs(0)

  @jax.jit
  def my_test(batch, graphdef, state):
    model = nnx.merge(graphdef, state)

    def loss_closure(model):
      return jnp.mean(jax.vmap(model)(batch))

    loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
    return loss_val

  n_in = 64
  n_out = 1
  depth = 3
  width = 128
  activation = nnx.relu
  model = MLP(
    n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs
  )
  my_batch = jax.random.normal(rngs(), (20, n_in))
  graphdef, state = nnx.split(model)

  # Time NNX
  out = my_test(my_batch, graphdef, state).block_until_ready()
  start_time = time.time()
  out = my_test(my_batch, graphdef, state).block_until_ready()
  end_time = time.time()
  print(f'Time taken NNX: {end_time - start_time:.5f} seconds')

  # -----------
  # Equinox
  # -----------

  @eqx.filter_jit
  def my_test_eqx(batch, treedef, leaves):
    model = jax.tree.unflatten(treedef, leaves)

    @eqx.filter_jit
    def loss_closure(f):
      return jnp.mean(jax.vmap(f)(batch))

    loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
    return loss_val

  equinox_model = eqx.nn.MLP(
    n_in,
    n_out,
    width_size=width,
    depth=depth,
    activation=activation,
    key=rngs(),
  )

  leaves, treedef = jax.tree.flatten(equinox_model)

  # Time Equinox
  out = my_test_eqx(my_batch, treedef, leaves).block_until_ready()
  start_time = time.time()
  out = my_test_eqx(my_batch, treedef, leaves).block_until_ready()
  end_time = time.time()
  print(f'Time taken EQX: {end_time - start_time:.5f} seconds')

Output on my M1:

Time taken NNX: 0.00007 seconds
Time taken EQX: 0.00019 seconds

This version might still be suboptimal for Equinox because of the use of eqx.filter_jit instead of jax.jit.

cgarciae commented 16 hours ago

We will add a guide on NNX transforms explaining how they work under the hood in the future.

JesseFarebro commented 15 hours ago

Some documentation would be very useful; I also ran into this when profiling NNX vs. Linen.

cgarciae commented 13 hours ago

Linen is already low-overhead, I'll try to add it to the benchmark.

jlperla commented 13 hours ago

@cgarciae thanks, this helps a lot. I don't feel like you need to compare to Equinox in your docs. My main concern was that NNX seemed to be 3x slower for the same task. But if you are doing more bookkeeping, then it isn't really the same task.

And just to confirm: is my MLP implementation as high-performance as possible? If so, it might be helpful to have it in the docs for people to adapt.

cgarciae commented 12 hours ago

I believe so. @jlperla do you want to contribute it as an NNX example? While we don't let people directly import examples, we can point to it in the documentation, and it could serve as a reference implementation that people can easily copy into their codebase.