google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Implement a `split_rngs` mechanism, like in `flax.linen.scan` #432

Open · zaccharieramzi opened this issue 1 year ago

zaccharieramzi commented 1 year ago

I have the following problem: I want to optimize / find the fixed point of a function that uses Flax's `Dropout`, which under the hood uses `make_rng`. This means that when the function is called multiple times inside the optimizer / fixed-point solver, the dropout mask changes between calls.
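To illustrate (a minimal standalone sketch, with illustrative names, not jaxopt code), assume the solver loop is unrolled inside a parent module's `apply()`; each iteration then sees a different dropout mask because `make_rng("dropout")` derives a new key on every call:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class InnerCell(nn.Module):
    @nn.compact
    def __call__(self, z):
        # nn.Dropout calls self.make_rng("dropout") under the hood; within a
        # single apply(), every call produces a different key, hence a
        # different mask.
        return nn.Dropout(rate=0.5, deterministic=False)(z)

class UnrolledSolver(nn.Module):
    """Stand-in for a fixed-point solver: just calls the cell a few times."""

    @nn.compact
    def __call__(self, z):
        cell = InnerCell()
        for _ in range(3):
            z = cell(z)  # each iteration sees a *different* dropout mask
        return z

x = jnp.ones((4,))
model = UnrolledSolver()
variables = model.init({"dropout": jax.random.PRNGKey(0)}, x)
out = model.apply(variables, x, rngs={"dropout": jax.random.PRNGKey(1)})
```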

For recurrent neural networks whose cells use dropout, this is handled in Flax via the `split_rngs` mechanism of the `scan` transform. Namely, the docs say:

> `split_rngs` – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations.

The same option is available for Flax's `while_loop`.
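For concreteness, here is a minimal toy sketch of that mechanism with `flax.linen.scan` (again a standalone example, not jaxopt code): with `split_rngs={"dropout": False}`, the same dropout key is reused at every scan step, so the mask is identical across iterations.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class DropoutCell(nn.Module):
    @nn.compact
    def __call__(self, carry, _):
        carry = nn.Dropout(rate=0.5, deterministic=False)(carry)
        return carry, carry

class ScannedModel(nn.Module):
    @nn.compact
    def __call__(self, z, xs):
        # split_rngs={"dropout": False} reuses the same "dropout" key at every
        # scan step, so the dropout mask does not change across iterations.
        scanned = nn.scan(
            DropoutCell,
            variable_broadcast="params",
            split_rngs={"dropout": False},
            in_axes=0,
            out_axes=0,
        )
        return scanned()(z, xs)

z0 = jnp.ones((4,))
xs = jnp.zeros((3, 1))  # 3 scan steps
model = ScannedModel()
variables = model.init({"dropout": jax.random.PRNGKey(0)}, z0, xs)
z_final, trajectory = model.apply(
    variables, z0, xs, rngs={"dropout": jax.random.PRNGKey(1)}
)
```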

In my case, I would basically like to do something like `split_rngs={"dropout": False}`.

I think the best way to tackle this would be to implement special while loops for the Flax use cases here. Happy to discuss the API and whether this feature makes sense.

My basic use case is to implement Deep Equilibrium Models where the fixed-point-defining function uses dropout.
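To make the API discussion concrete, here is a purely hypothetical sketch of what the user-facing side could look like for this DEQ use case. The `split_rngs` argument on `jaxopt.FixedPointIteration` does not exist today, and calling a submodule inside the solver's internal loop is exactly what currently breaks; only `FixedPointIteration(fixed_point_fun=..., maxiter=...)` and `.run(...)` reflect the real API.

```python
import jax.numpy as jnp
import flax.linen as nn
import jaxopt

class DEQBlock(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        # Cell whose fixed point we want; it uses dropout internally.
        def cell(z, x):
            h = nn.Dense(self.features, name="dense")(z) + x
            return nn.Dropout(rate=0.1, deterministic=False)(nn.relu(h))

        solver = jaxopt.FixedPointIteration(
            fixed_point_fun=cell,
            maxiter=50,
            # Hypothetical argument (not in jaxopt), mirroring
            # flax.linen.scan / while_loop: reuse the same "dropout" key at
            # every solver iteration so the fixed point is well-defined.
            split_rngs={"dropout": False},
        )
        return solver.run(jnp.zeros_like(x), x).params
```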

mblondel commented 1 year ago

This looks very tricky. Naive question, but is the fixed point well-defined when using dropout? Isn't the randomness problematic?

zaccharieramzi commented 1 year ago

Indeed, the fixed point is not well-defined if dropout is used without precautions. I would like to make sure that the mask stays the same between calls. This has become a bit easier with https://github.com/google/flax/pull/3114, but it is still painful to handle cases where you are reusing pre-defined modules that do not let you specify the RNG for the dropout, typically attention.
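For what it's worth, when the whole dropout-using model can be re-applied from the outside at every solver iteration, one way to pin the mask today is to re-enter `apply` with the same `"dropout"` key each time, since the keys derived by `make_rng` are deterministic given the key and the call structure. A minimal sketch under that assumption (module and solver choices are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import jaxopt

class Cell(nn.Module):
    @nn.compact
    def __call__(self, z, x):
        h = nn.Dense(z.shape[-1])(z) + x
        return nn.Dropout(rate=0.1, deterministic=False)(nn.relu(h))

cell = Cell()
x = jnp.ones((8,))
z0 = jnp.zeros((8,))
variables = cell.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}, z0, x
)

dropout_key = jax.random.PRNGKey(2)

def f(z, x):
    # Re-entering apply() resets make_rng's internal counter, so passing the
    # same key yields the same dropout mask at every solver iteration.
    return cell.apply(variables, z, x, rngs={"dropout": dropout_key})

solver = jaxopt.FixedPointIteration(fixed_point_fun=f, maxiter=100)
z_star = solver.run(z0, x).params
```

This does not help, however, when the solver has to run *inside* a larger module's `apply()`, which is where a `split_rngs`-aware loop in jaxopt would come in.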