Open aaarrti opened 1 year ago
Hey @aaarrti, currently nn.jit only supports positional arguments, which is why you are getting this error. However, after you pass training as a positional argument you get another error: you must declare training as a static_argnum. Here is the working example:
import jax
import jax.numpy as jnp
import flax.linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel, static_argnums=(2,))(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)


if __name__ == "__main__":
    main()
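(A hedged addition for completeness, not part of the original reply.) Inside main(), applying this jitted module in training mode would look roughly like the line below; training again has to be passed positionally, and Dropout draws from the "dropout" RNG stream:

    # Sketch (untested assumption): training-mode forward pass through the jitted module.
    y = my_model.apply(variables, x, True, rngs={"dropout": dropout_key})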
That said, it's pretty rare to use nn.jit. Usually you just use jax.jit over the train_step, as shown in the Quick Start.
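For context, here is a minimal hedged sketch of that pattern (my own example, not the Quick Start code verbatim; the optax optimizer and MSE loss are illustrative assumptions). The whole train step is jitted with jax.jit, and training is just a Python constant inside it, so no static_argnums bookkeeping is needed:

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


model = MyModel(num_neurons=3)
x = jnp.ones((3, 4, 4))
y = jnp.ones((3, 4, 3))
params = model.init(jax.random.PRNGKey(0), x, False)["params"]
tx = optax.sgd(learning_rate=1e-3)
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, x, y, dropout_key):
    # training=True is a plain Python constant inside the jitted function,
    # so nothing has to be declared static; only the arrays are traced.
    def loss_fn(p):
        preds = model.apply({"params": p}, x, True, rngs={"dropout": dropout_key})
        return jnp.mean((preds - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


params, opt_state = train_step(params, opt_state, x, y, jax.random.PRNGKey(1))

An eval step would simply be a second jitted function with training hard-coded to False.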
Hi @cgarciae,
it looks like I had a misconception about nn.jit. I thought it plays a role similar to tf.keras.Model.compile(). Am I wrong? What is the intended use case for nn.jit?
The example I provided is a basic one, but my original intention was to use nn.jit to compile a bigger model to speed up training. The workaround I came up with so far is:
import jax
import jax.numpy as jnp
import flax.linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.jit(nn.Dense)(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x


def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = MyModel(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)


if __name__ == "__main__":
    main()
It looks redundant for this toy example, but the performance improvements become easily noticeable as model size/complexity increases. I do also jax.jit my training step, though.
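(A hedged sketch added here, not part of the original comment.) One practical difference with this workaround: since __call__ itself is no longer wrapped by nn.jit, training never crosses a jit boundary, so inside main() the keyword form from the guide should work again:

    # Sketch (untested assumption): the keyword argument is fine here because only
    # the inner nn.Dense is jitted, not the module's __call__.
    y = my_model.apply(variables, x, training=True, rngs={"dropout": dropout_key})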
nn.jit intended usage in the documentation
Hi there,
I was following this guide, flax.linen.Dropout. Then I decided to add nn.jit and started getting MyModel.__call__() missing 1 required positional argument: 'training', even though I passed it.

System information
Traceback (most recent call last):
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 34, in <module>
    main()
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 25, in main
    variables = my_model.init(params_key, x, training=False)
TypeError: MyModel.__call__() missing 1 required positional argument: 'training'
Question
I saw the warning, so I decided to change my code, roughly as sketched below.
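(A reconstruction on my part, assuming the change followed the first reply's description of passing training positionally, still without static_argnums:)

    # Reconstruction (assumption): training passed positionally instead of training=False.
    my_model = nn.jit(MyModel)(num_neurons=3)
    variables = my_model.init(params_key, x, False)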
That change then resulted in another error.
My next idea was to add the bool argument to static_argnums, as follows.
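(Again a reconstruction on my part, assuming it matched the working example given earlier in this thread:)

    # Reconstruction (assumption): declare the training argument (index 2, counting
    # self as 0) static so jit can specialize on the Python bool.
    my_model = nn.jit(MyModel, static_argnums=(2,))(num_neurons=3)
    variables = my_model.init(params_key, x, False)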
This one worked, but as per the documentation, "Calling the jitted function with different values for these constants will trigger recompilation." The above-mentioned guide suggests using training=True for training steps and training=False for validation steps, which means I will have to re-compile the full model twice in each training epoch. Is there any way to address this?