jlperla opened 2 days ago
To add to this: the performance of Linen seems to be similar to NNX, though I am even less clear on how to profile there. Here was my implementation:
import typing as tp
import jax
import jax.numpy as jnp
import flax.linen as linen
from flax.training import train_state
from flax.typing import Dtype, PrecisionLike
import optax
import time
class MLPLinen(linen.Module):
in_features: int
out_features: int
width: int
depth: int
activation: tp.Callable
use_bias: bool = True
use_final_bias: bool = True
final_activation: tp.Optional[tp.Callable] = None
dtype: tp.Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
precision: PrecisionLike = None
@linen.compact
def __call__(self, x: jax.Array) -> jax.Array:
x = linen.Dense(
self.width,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
for _ in range(self.depth - 1):
x = self.activation(x)
x = linen.Dense(
self.width,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
x = linen.Dense(
self.out_features,
use_bias=self.use_final_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
if self.final_activation is not None:
x = self.final_activation(x)
return x
def create_train_state_linen(rng, model, learning_rate):
params = model.init(rng, jnp.ones([1, model.in_features]))['params']
tx = optax.adam(learning_rate)
return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def compute_loss_linen(params, batch, model_apply_fn):
logits = model_apply_fn({'params': params}, batch)
loss = jnp.mean(logits)
return loss
@jax.jit
def train_step_linen(state, batch):
grad_fn = jax.value_and_grad(compute_loss_linen)
loss, grads = grad_fn(state.params, batch, state.apply_fn)
state = state.apply_gradients(grads=grads)
return state, loss
if __name__ == "__main__":
rng = jax.random.PRNGKey(0)
n_in = 64
n_out = 1
depth = 3
width = 128
activation = linen.relu
model = MLPLinen(n_in, n_out, width=width, depth=depth, activation=activation)
state = create_train_state_linen(rng, model, learning_rate=0.001)
my_batch = jax.random.normal(rng, (20, n_in))
# Time Linen
state, loss_val = train_step_linen(state, my_batch)
jax.block_until_ready(loss_val)
start_time = time.time()
state, loss_val = train_step_linen(state, my_batch)
jax.block_until_ready(loss_val)
end_time = time.time()
print(f"Time taken Linen: {end_time - start_time:.5f} seconds")
Hey @jlperla, can you use `timeit` or similar to report the results? A single step, especially the first one that involves compilation, is not very meaningful.
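For example, something like this (a sketch reusing the `state` and `my_batch` from your script):

import timeit

def bench():
    _, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)  # wait for the async dispatch to finish

bench()  # warm up once so compilation is excluded from the timings
times = timeit.repeat(bench, number=100, repeat=5)
print(f"Linen: {min(times) / 100 * 1e6:.1f} us per step")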
That said, this is what I would expect: the `params` structure is a simple dictionary and you are using regular `jax.jit`, and JAX has optimized code for traversing dicts.

@jlperla Maybe useful to note here: for small MLPs you will likely be in the overhead regime. To overcome the framework overhead (in NNX or Equinox) you can use the `nnx.{split,merge}` or `eqx.{partition,combine}` pattern with non-lifted JAX transforms.
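A minimal sketch of that pattern for NNX, assuming a `model` and `batch` like the ones in this thread (the fuller comparison below does the same thing end to end):

graphdef, state = nnx.split(model)  # split once, outside the jitted function

@jax.jit  # plain, non-lifted transform
def loss_and_grad(state, batch):
    model = nnx.merge(graphdef, state)  # rebuild the module inside the trace
    def loss_closure(model):
        return jnp.mean(jax.vmap(model)(batch))
    return nnx.value_and_grad(loss_closure)(model)

The Equinox analogue swaps `nnx.split`/`nnx.merge` for `eqx.partition`/`eqx.combine`.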
@ASEM000 Correct. Ideally we'll document how to overcome the overhead problem in the near future.
@cgarciae @ASEM000 Absolutely. But the issue here is comparing the relative overhead of NNX vs. Equinox for the same pattern. I find `timeit` hard to use for this, but I made sure things were compiled and I retried multiple times.
Why would the Equinox code be so much faster than NNX (which seems roughly similar to Flax Linen)? What overhead would be so much more significant there, using the same coding pattern? If you look at my code, I am isolating a single "value and grad" call, with no optimizer overhead or training loop, and precompiling it before timing.
So either:
0) It looks like my two sets of code are doing the same thing, but they really aren't.
1) I implemented the MLP poorly in NNX, which is VERY likely, and the one in Equinox is done correctly.
2) There is some sort of overhead in the filtering process which is significantly more expensive in NNX than in Equinox. Maybe a manual split and combine (which can be done in both) would make it disappear.
@jlperla I do imagine the NNX overhead being greater than the Equinox overhead, as we do more bookkeeping and it's not optimized. If performance is critical you should just train using `split`/`merge`. Here is a modified version comparing both NNX and Equinox using low-overhead versions:
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time
class MLP(nnx.Module):
def __init__(
self,
in_features: int,
out_features: int,
*,
width: int,
depth: int,
activation: tp.Callable,
rngs: rnglib.Rngs,
use_bias: bool = True,
use_final_bias: bool = True,
final_activation: tp.Optional[tp.Callable] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
):
self.in_features = in_features
self.out_features = out_features
self.width = width
self.depth = depth
self.use_bias = use_bias
self.use_final_bias = use_final_bias
self.activation = activation
self.final_activation = final_activation
assert depth > 0 # skipping specialization of no hidden layers
self.layers = []
self.layers.append(
nnx.Linear(
in_features,
width,
use_bias=use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
    # hidden blocks: activation first, then Linear, matching the Linen version
    # above (otherwise the first two Linear layers are stacked with no
    # nonlinearity between them)
    for _ in range(self.depth - 1):
      self.layers.append(self.activation)
      self.layers.append(
        nnx.Linear(
          width,
          width,
          use_bias=self.use_bias,
          dtype=dtype,
          param_dtype=param_dtype,
          precision=precision,
          rngs=rngs,
        )
      )
self.layers.append(
nnx.Linear(
width,
out_features,
use_bias=self.use_final_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
if self.final_activation is not None:
self.layers.append(self.final_activation)
def __call__(self, x: jax.Array) -> jax.Array:
for layer in self.layers:
x = layer(x)
return x
if __name__ == '__main__':
rngs = nnx.Rngs(0)
@jax.jit
def my_test(batch, graphdef, state):
model = nnx.merge(graphdef, state)
def loss_closure(model):
return jnp.mean(jax.vmap(model)(batch))
    loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
    # return the grads too so XLA cannot dead-code-eliminate the backward pass
    return loss_val, loss_grad
n_in = 64
n_out = 1
depth = 3
width = 128
activation = nnx.relu
model = MLP(
n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs
)
my_batch = jax.random.normal(rngs(), (20, n_in))
graphdef, state = nnx.split(model)
# Time NNX
  out = jax.block_until_ready(my_test(my_batch, graphdef, state))
start_time = time.time()
  out = jax.block_until_ready(my_test(my_batch, graphdef, state))
end_time = time.time()
print(f'Time taken NNX: {end_time - start_time:.5f} seconds')
# -----------
# Equinox
# -----------
@eqx.filter_jit
def my_test_eqx(batch, treedef, leaves):
model = jax.tree.unflatten(treedef, leaves)
@eqx.filter_jit
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
    loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
    # likewise return grads so the backward pass is not dead-code-eliminated
    return loss_val, loss_grad
equinox_model = eqx.nn.MLP(
n_in,
n_out,
width_size=width,
depth=depth,
activation=activation,
key=rngs(),
)
leaves, treedef = jax.tree.flatten(equinox_model)
# Time Equinox
  out = jax.block_until_ready(my_test_eqx(my_batch, treedef, leaves))
start_time = time.time()
  out = jax.block_until_ready(my_test_eqx(my_batch, treedef, leaves))
end_time = time.time()
print(f'Time taken EQX: {end_time - start_time:.5f} seconds')
Output on my M1:
Time taken NNX: 0.00007 seconds
Time taken EQX: 0.00019 seconds
This version might still be suboptimal for Equinox because of the use of `eqx.filter_jit` instead of `jax.jit`.
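For reference, a sketch of that lower-overhead variant (assuming the `equinox_model` and `my_batch` defined above): the non-array leaves, such as the activation function, have to be split out with `eqx.partition` so that plain `jax.jit` only ever sees arrays:

params, static = eqx.partition(equinox_model, eqx.is_array)

@jax.jit
def my_test_eqx_raw(batch, params):
    model = eqx.combine(params, static)  # the static part is closed over
    def loss_closure(f):
        return jnp.mean(jax.vmap(f)(batch))
    # return grads as well so the backward pass is not dead-code-eliminated
    return eqx.filter_value_and_grad(loss_closure)(model)

out = jax.block_until_ready(my_test_eqx_raw(my_batch, params))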
We will add a guide on NNX transforms explaining how they work under the hood in the future.
Some documentation would be very useful; I also ran into this when profiling NNX vs. Linen.
Linen is already low-overhead; I'll try to add it to the benchmark.
@cgarciae thanks, this helps a lot. I don't feel like you need to compare to Equinox in your docs. My main concern was that it seemed to be 3x slower for the same task. But if you are doing more bookkeeping, then it isn't really the same task.
And just to confirm: is my MLP implementation as high-performance as possible? If so, maybe it would be helpful to have in the docs for people to adapt.
I believe so. @jlperla do you want to contribute it as an NNX example? While we don't let people directly import examples, we can point to it in the documentation, and it could serve as a reference implementation that people can easily copy into their codebase.
I decided to try NNX vs. Equinox for performance and am seeing significant differences (NNX roughly 3x slower). It could be that I wrote a poor MLP implementation or made a colossal profiling mistake.
My apologies if the benchmarking itself is flawed or the MLP implementation is incorrect in some way. But if it is the latter, that shows a documented MLP implementation for NNX to copy/paste might help.
System information
Problem you have encountered:
The performance of my test suite on my CPU is
And on the colab T4 GPU runtime
Steps to reproduce:
Test Suite:
On Colab you need to do
! pip install equinox