Open JunhongXu opened 1 day ago
`mlp.__call__` is not recommended, as you are passing `self` as a capture. Try `MLP.__call__` and pass `mlp` as the first input.
Just to clarify: what is happening is that `mlp.__call__` is not traversing `self`, so it's faster, a lot faster in this case.
We are going to be developing a Rust extension (see #4196), so in the future `nnx.jit` should be fast. For now, consider using this pattern to remove the Python overhead.
I've created this mini guide to clarify the situation around performance: #4224.
System information

`pip show flax jax jaxlib`:

- flax: 0.9.0
- jax: 0.4.30
- jaxlib: 0.4.30

Problem you have encountered:
`nnx.jit(aux_fn)` is slower than directly using `nnx.jit(model.__call__)`, where `aux_fn` is defined in the code below.

From my understanding, using an auxiliary function with `nnx.jit` seems to be a common practice and is required if we want to modify the internal state of the model (https://github.com/google/flax/discussions/3998). However, it seems slower than directly wrapping the `model.__call__` method with `nnx.jit`. See the Colab link below to reproduce.
Steps to reproduce:
Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing
For completeness, I also copy the code here
The outputs using an RTX 4090 are: