google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

static_argnums argument to flax.linen.remat not working as expected #3946

Open dionhaefner opened 1 month ago

dionhaefner commented 1 month ago

I am using flax.linen.remat on a module that has a train flag (used to check if the model is training). I'm using static_argnums on that flag, but am still getting a ConcretizationTypeError on model init.

Reproducer:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, train):
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

# Works fine with this line commented out
MLP = nn.remat(MLP, static_argnums=(1,))

model = MLP()
rng_key = jax.random.PRNGKey(42)
input_example = jnp.ones((8, 16))  # placeholder input; shape assumed, not given in the report
variables = model.init(rng_key, input_example, True)

Traceback:

$ python foo.py
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/dion/codes/supersede/foo.py", line 43, in <module>
    variables = model.init(rng_key, input_example, True)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /Users/dion/.virtualenvs/science/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:393 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[2].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html

Tested with flax==0.8.4.

cgarciae commented 1 month ago

train is the 3rd argument (self and x come first), so you have to change static_argnums like this:

MLP = nn.remat(MLP, static_argnums=(2,))

dionhaefner commented 1 month ago

I see. I guess I got confused because sometimes our models are used like this:

model.apply(variables, inputs, train=False)

which triggers this error:

ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (3,)

So I assumed it wasn't counting the self argument. Any chance we could support something akin to static_argnames from jax.jit to support kwargs?