google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Dropout + `nn.jit` #3171

Open · aaarrti opened this issue 1 year ago

aaarrti commented 1 year ago

Hi there,

I was following the flax.linen.Dropout guide. Then I decided to add `nn.jit` and started getting `MyModel.__call__() missing 1 required positional argument: 'training'`, even though I passed it.

### System information

- Python version: 3.10
- No hardware acceleration was used

### Problem you have encountered:

Traceback (most recent call last):
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 34, in <module>
    main()
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 25, in main
    variables = my_model.init(params_key, x, training=False)
TypeError: MyModel.__call__() missing 1 required positional argument: 'training'


### What you expected to happen:
I expect the jitted and non-jitted versions to work the same. Or am I missing something?

### Steps to reproduce:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x

def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel)(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, training=False)  # raises the TypeError above

if __name__ == "__main__":
    main()
```

### Question

I saw the warning

/Users/artemsereda/miniconda3/envs/py310/lib/python3.10/site-packages/flax/core/lift.py:111: RuntimeWarning: kwargs are not supported in jit, so "training" is(are) ignored
  warnings.warn(msg.format(name, ', '.join(kwargs.keys())), RuntimeWarning)

so I decided to change my code to

variables = my_model.init(params_key, x, False)

which then resulted in

Traceback (most recent call last):
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 25, in <module>
    main()
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 21, in main
    variables = my_model.init(params_key, x, False)
  File "/Users/artemsereda/Documents/IdeaProjects/kaggle-bliss/bug_report.py", line 12, in __call__
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function. 
The error occurred while tracing the function core_fn at /Users/artemsereda/miniconda3/envs/py310/lib/python3.10/site-packages/flax/linen/transforms.py:305 for jit. This concrete value was not available in Python because it depends on the value of the argument args[2].
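
As far as I can tell, the failure boils down to applying Python `not` to a value that jit is tracing. A stripped-down repro of the same class of error, with a throwaway function `f` rather than my model (the exact error class/message may vary by JAX version):

```python
import jax

@jax.jit
def f(flag):
    # `not` forces a concrete bool, but `flag` is abstract during tracing
    return 1.0 if not flag else 0.0

f(True)  # raises a ConcretizationTypeError-style tracer error
```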

My next idea was to mark the `bool` argument as static via `static_argnums`, as follows:

def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel, static_argnums=2)(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)

This one worked, but per the documentation, "Calling the jitted function with different values for these constants will trigger recompilation." The above-mentioned guide suggests using `training=True` for training steps and `training=False` for validation steps, which means the full model gets traced and compiled separately for each value of `training`.
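
To make the concern concrete, here is a toy sketch with plain `jax.jit` and a made-up `apply_fn` (not my actual model), showing that each distinct value of a static argument triggers its own trace and compile, while repeated values hit the compilation cache:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Toy stand-in for the model call, only to illustrate how static arguments
# interact with jit's compilation cache.
@partial(jax.jit, static_argnums=(1,))
def apply_fn(x, training: bool):
    print(f"tracing with training={training}")  # runs only while (re)tracing
    return x * (0.5 if training else 1.0)

x = jnp.ones((3,))
apply_fn(x, True)   # traces and compiles for training=True
apply_fn(x, False)  # traces and compiles again for training=False
apply_fn(x, True)   # cache hit, no retrace
```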

Is there any way to address this?

cgarciae commented 1 year ago

Hey @aaarrti, currently `nn.jit` only supports positional arguments, which is why you are getting this error. However, once you pass `training` positionally you hit another error, because you also have to mark `training` as static via `static_argnums`. Here is the working example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x

def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = nn.jit(MyModel, static_argnums=(2,))(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)

if __name__ == "__main__":
    main()

cgarciae commented 1 year ago

That said, it's pretty rare to use `nn.jit`. Usually you just use `jax.jit` over the `train_step`, as shown in the Quick Start.
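
For reference, here is a minimal sketch of that pattern, reusing `MyModel` from above and assuming a placeholder MSE loss and an `optax` SGD optimizer (this is not the Quick Start code verbatim):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x

model = MyModel(num_neurons=3)
optimizer = optax.sgd(learning_rate=1e-3)

@jax.jit
def train_step(params, opt_state, x, y, dropout_key):
    # `training=True` is just a Python constant inside the traced function,
    # so no static_argnums are needed.
    def loss_fn(p):
        preds = model.apply(
            {"params": p}, x, training=True, rngs={"dropout": dropout_key}
        )
        return jnp.mean((preds - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

root_key = jax.random.PRNGKey(0)
params_key, dropout_key = jax.random.split(root_key)
x = jnp.ones((3, 4, 4))
y = jnp.ones((3, 4, 3))
params = model.init(params_key, x, training=False)["params"]
opt_state = optimizer.init(params)
# in a real loop you would split/fold_in a fresh dropout key per step
params, opt_state, loss = train_step(params, opt_state, x, y, dropout_key)
```

The whole step is compiled once, `training=True` is baked in as a constant, and the evaluation step can be a separate jitted function with `training=False`.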

aaarrti commented 1 year ago

> That said, it's pretty rare to use `nn.jit`. Usually you just use `jax.jit` over the `train_step`, as shown in the Quick Start.

Hi @cgarciae, it looks like I had a misconception about `nn.jit`. I thought it plays a role similar to `tf.keras.Model.compile()`. Am I wrong? What is the intended use case for `nn.jit`?

The example I provided is a basic one, but my original intention was to use `nn.jit` to compile a bigger model to speed up training. The workaround I have come up with so far is:


import jax
import jax.numpy as jnp
import flax.linen as nn

class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        # jit only the inner Dense; its __call__ takes only array arguments,
        # so no static_argnums are needed
        x = nn.jit(nn.Dense)(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x

def main():
    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    my_model = MyModel(num_neurons=3)
    x = jnp.empty((3, 4, 4))
    variables = my_model.init(params_key, x, False)

if __name__ == "__main__":
    main()

It looks redundant for this toy example, but the performance improvements become easily noticeable as model size/complexity increases. I also `jax.jit` my training step, though.

To sum up: