richardmkit opened 1 month ago

The model can be trained successfully when I add two dropout layers and don't use `jax.jit`. However, as soon as I try to accelerate training with `jax.jit`, it raises an error. Something seems to be wrong with the `training` flag. How can I solve this? Thanks.

You have to mark `training` as a static argument when you jit your functions, so the compiler knows you're OK with the function being recompiled whenever its value changes. See: https://jax.readthedocs.io/en/latest/jit-compilation.html#marking-arguments-as-static

In short, changing your `@jax.jit` decorators to `@partial(jax.jit, static_argnames=['training'])` should do the trick. I know it's a bit confusing, because the Flax dropout guide neglects to mention this.
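For concreteness, here is a minimal sketch of the fix. The `MLP` model, layer sizes, and the `apply_model` helper are illustrative stand-ins, not from the original post:

```python
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        # Dropout must know whether we are training; `deterministic=True`
        # disables it at eval time. `not training` requires a concrete
        # Python bool, which is exactly why `training` can't be traced.
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return nn.Dense(10)(x)


model = MLP()


# `training` changes the computation itself (dropout on vs. off), so it
# must be static. Without this, jit traces it and `not training` fails
# with a tracer-to-bool conversion error.
@partial(jax.jit, static_argnames=['training'])
def apply_model(params, x, dropout_rng, training):
    return model.apply(
        {'params': params}, x, training,
        rngs={'dropout': dropout_rng} if training else {},
    )


rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 32))
params = model.init({'params': rng, 'dropout': rng}, x, training=True)['params']

train_out = apply_model(params, x, rng, training=True)   # dropout active
eval_out = apply_model(params, x, rng, training=False)   # dropout disabled
```

Since `training` only ever takes the values True and False, marking it static costs at most two compilations of the function, one per value.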