Open dionhaefner opened 5 months ago
`train` is the 3rd argument, so you have to change `static_argnums` like this:
MLP = nn.remat(MLP, static_argnums=(2,))
I see. I guess I got confused because sometimes our models are used like this:
model.apply(variables, inputs, train=False)
which triggers this error:
ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (3,)
So I assumed it wasn't counting the `self` argument. Any chance we could support something akin to `static_argnames` from `jax.jit` to support kwargs?
I am using `flax.linen.remat` on a module that has a `train` flag (used to check if the model is training). I'm using `static_argnums` on that flag, but am still getting a `ConcretizationTypeError` on model init.

Reproducer:
Traceback:
Tested with `flax==0.8.4`.