mmcenta closed 4 years ago
Hi!
Basically, something like (1) is the canonical way we do it! We use a with nn.stochastic(prng_key): context around the model evaluation inside the training-step function, with a top-level PRNG key fed into the training-step function from an outside split in the training loop. The important thing is that the training-step function is jitted in its entirety; the one thing not to do is to jit across an nn.stochastic context. Provided it's used inside a jit and the PRNG keys are fed in like any other function argument, there should be no trouble using it. Using the stochastic context can save a lot of PRNG-key plumbing boilerplate in models.
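For readers without access to the colab, here is a rough sketch of the key-handling pattern described above. It is not the colab's code: it assumes plain JAX with a hand-written dropout in place of the nn.stochastic context, and the names dropout, loss_fn, train_step and run are made up for illustration. The point it shows is only that the whole training step is jitted and that the per-step key comes from a split done outside, in the training loop.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate):
    # Hand-written inverted dropout (an assumption standing in for the
    # library's stochastic context): drop units with probability `rate`.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def loss_fn(params, key, x, y):
    # Tiny two-layer model; the key is just another function argument.
    h = jnp.tanh(x @ params["w1"])
    h = dropout(key, h, rate=0.5)
    pred = h @ params["w2"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # the training step is jitted in its entirety
def train_step(params, key, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, key, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


def run(seed, num_steps=3):
    key = jax.random.PRNGKey(seed)
    key, k1, k2, kx = jax.random.split(key, 4)
    params = {
        "w1": 0.1 * jax.random.normal(k1, (4, 8)),
        "w2": 0.1 * jax.random.normal(k2, (8, 1)),
    }
    x = jax.random.normal(kx, (16, 4))
    y = jnp.sum(x, axis=1, keepdims=True)
    losses = []
    for _ in range(num_steps):
        # Split outside the jitted function and pass a fresh key per step.
        key, step_key = jax.random.split(key)
        params, loss = train_step(params, step_key, x, y)
        losses.append(float(loss))
    return losses
```

Calling run twice with the same seed should give the same losses, since all randomness flows from the one top-level key.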
I've typed up a quick demo of the canonical way in a colab at https://colab.research.google.com/drive/1eDXEVd8NPXgaSwn7jEMxsZTDVHNUtHYK
Let me know if that helps or if anything remains unclear!
Thank you for the amazing answer, I will get working on my project right away! Everything is clear, you even mentioned the part about passing the PRNG key itself while initializing a Module following (3), which is something I had problems with.
I'm closing the issue, thanks for the help!
Hello!
First of all, thanks for this project - it is a lifesaver! I wanted to get familiar with JAX, so I decided to implement a few deep reinforcement learning algorithms as a side project. I initially approached the problem by subclassing flax.nn.Module as follows:
I very quickly ran into the problem of having a duplicate parameter 'rng'. I dug into the flax code and discovered that 1) dropout was not implemented as a module as I believed and 2) the ModuleFrame has an rng param that I can't access (apparently). I came up with three solutions:
1. Get rng from the nn.stochastic context, but that would require wrapping the entire training function with it, which seems a little weird to me.
2. Use the same solution as in the VAE example and pass rng each time as a positional argument.
3. Mix both solutions and try to get rng from a kwarg and, if that fails, fall back to the context. This may lead to a problem if someone sets the kwarg with a call to partial...

I wanted to ask you how you would go about this? My main concerns are code reusability and reproducibility.
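To make the partial concern in option 3 concrete, here is a small sketch of one way it could go wrong. This is a guess at the kind of problem meant, not something stated above: it assumes plain JAX, and noisy_forward is a made-up stand-in for a module that takes rng as a kwarg. Binding the key with functools.partial freezes it, so every call draws the same noise, whereas splitting a key per call gives independent draws.

```python
import functools

import jax
import jax.numpy as jnp


def noisy_forward(x, rng):
    # Hypothetical stand-in for a stochastic layer that takes its key as a kwarg.
    return x + jax.random.normal(rng, x.shape)


x = jnp.zeros(3)

# Key bound once with partial: the same key is reused on every call.
frozen = functools.partial(noisy_forward, rng=jax.random.PRNGKey(0))
a = frozen(x)
b = frozen(x)

# Key split per call: each call gets its own independent key.
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
c = noisy_forward(x, k1)
d = noisy_forward(x, k2)

same_with_partial = bool(jnp.array_equal(a, b))
same_with_split = bool(jnp.array_equal(c, d))
```

Here same_with_partial should come out True and same_with_split False, which is the silent key reuse that would hurt anything relying on fresh randomness per step.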