Open zaccharieramzi opened 1 year ago
This looks very tricky. Naive question, but is the fixed point well-defined when using dropout? Isn't the randomness problematic?
Indeed, the fixed point is not well-defined if dropout is used without precaution. I would like to make sure that the mask stays the same between different calls. It's already a bit easier with https://github.com/google/flax/pull/3114, but it is still painful to handle cases where you are re-using pre-defined functions that do not let you specify the rng for the dropout, typically attention.
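To make the "same mask between calls" idea concrete, here is a minimal sketch in plain NumPy (hypothetical helper, not Flax API): the mask is sampled once and closed over, so every call a solver makes to the function sees the identical mask and the function becomes deterministic.

```python
import numpy as np

def make_fixed_dropout(rate, shape, seed):
    """Sample a dropout mask once and return a function that reuses it.

    This mimics the behavior we want from Flax: the mask is frozen across
    all calls made by a fixed-point solver, instead of being resampled.
    """
    rng = np.random.default_rng(seed)
    keep = (rng.random(shape) >= rate).astype(np.float64)
    scale = 1.0 / (1.0 - rate)  # inverted-dropout rescaling

    def dropout(x):
        return x * keep * scale

    return dropout

drop = make_fixed_dropout(rate=0.5, shape=(4,), seed=0)
x = np.ones(4)
# The mask is identical between calls, so repeated evaluation is deterministic.
assert np.array_equal(drop(x), drop(x))
```

With Flax's `Dropout`, the mask is instead drawn from `make_rng` on each call, which is exactly why repeated calls inside a solver disagree.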
I have the following problem: I want to optimize/find the fixed point of a function that uses Flax's `Dropout`, and therefore under the hood it uses `make_rng`. What this means is that when the function is called multiple times inside the optimizer/fixed-point solver, the dropout mask will change between calls.

For recurrent neural networks whose cells use dropout, this is fixed in Flax using the `split_rngs` mechanism available in the `scan` function. The same is available for the while loop.

In my case, I would basically like to do something like `split_rngs={"dropout": False}`.

I think the best way to tackle this would be to implement special while loops for Flax cases in here. Happy to discuss the API and whether this feature makes sense.
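The intended semantics of `split_rngs={"dropout": False}` can be illustrated with a small NumPy stand-in (this is not Flax code, just a sketch of the two behaviors): with splitting, each loop step derives a fresh rng and hence a fresh mask; without splitting, the same rng is broadcast to every step, so one mask is shared.

```python
import numpy as np

def loop_with_rngs(n_steps, base_seed, split_rngs):
    """Illustrate `split_rngs` semantics (plain NumPy, not Flax code).

    split_rngs=True : each iteration gets a fresh rng -> a new dropout mask.
    split_rngs=False: one rng is broadcast -> the same mask at every step.
    """
    masks = []
    for step in range(n_steps):
        seed = base_seed + step if split_rngs else base_seed
        rng = np.random.default_rng(seed)
        masks.append(rng.random(64) >= 0.5)
    return masks

same = loop_with_rngs(3, base_seed=0, split_rngs=False)
fresh = loop_with_rngs(3, base_seed=0, split_rngs=True)
# Broadcasting the rng yields one shared mask across all iterations.
assert all(np.array_equal(m, same[0]) for m in same)
```

For a fixed-point solver, the `split_rngs=False` behavior is the one that makes the iterated function deterministic.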
My basic use case is to implement Deep Equilibrium Models, where the function defining the fixed point uses dropout.
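A minimal DEQ-style sketch in plain NumPy (hypothetical weights and shapes, not the actual model) shows why the frozen mask matters: once the mask is sampled a single time and reused, naive iteration of z ↦ f(z) converges to a well-defined fixed point; if the mask were resampled on every call, the iteration would chase a moving target.

```python
import numpy as np

# DEQ-style fixed point: find z* with z* = f(z*), where f applies a
# dropout mask that is frozen across all solver iterations.
rng = np.random.default_rng(0)
W = 0.05 * rng.standard_normal((4, 4))      # small weights => f is a contraction
x = rng.standard_normal(4)                  # the "input injection"
mask = (rng.random(4) >= 0.5) / 0.5         # dropout mask, sampled ONCE

def f(z):
    # Same mask on every call, so f is deterministic and the
    # fixed point is well-defined.
    return np.tanh(W @ (mask * z) + x)

z = np.zeros(4)
for _ in range(100):
    z_next = f(z)
    if np.max(np.abs(z_next - z)) < 1e-10:
        break
    z = z_next

assert np.max(np.abs(f(z) - z)) < 1e-8  # z is (numerically) a fixed point
```

In the Flax setting, "sampled once" is exactly what a `split_rngs={"dropout": False}`-style option on the while loop would provide.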