google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Force no split in `make_rng` #3115

Open cgarciae opened 1 year ago

cgarciae commented 1 year ago

Discussed in https://github.com/google/flax/discussions/3113

Originally posted by **zaccharieramzi** May 24, 2023

I have the following situation: I am using a `Dropout` layer multiple times without using [`nn.scan` or `nn.while_loop`](https://github.com/google/flax/discussions/2920#discussioncomment-5180446), so I cannot use `split_rngs={"dropout": False}`. However, I would still like to use the same dropout mask twice. Is it possible to specify "no split" in `make_rng` for certain collections?

If I take the original dropout example, I would like to do something like:

```python
# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn

# Randomness.
seed = 0
root_key = jax.random.PRNGKey(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)

x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.)
variables = my_model.init(params_key, x)

# Perform the forward pass with `flax.linen.apply()`.
# Flax modules are frozen dataclasses, so `my_model.training = True` raises
# an error; create a training copy with `clone()` instead.
y = my_model.clone(training=True).apply(variables, x, rngs={'dropout': dropout_key})
```

and still have `jnp.sum(y == 0.) / (3*4*3)` be approximately `0.5`.

For more context, I am actually trying to implement [Deep Equilibrium Models](https://arxiv.org/abs/1909.01377) using [`jaxopt`](https://jaxopt.github.io/stable/) and `flax`, where the fixed-point defining function uses dropout.
I also [tried to see](https://github.com/google/jaxopt/issues/432) whether the `split_rngs` functionality could be extended to `jaxopt`, but I think it's going to be difficult.
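The target of roughly `0.5` follows from a short probability argument. With two independent masks at rate 0.5, an entry survives only if both masks keep it, so about 1 - 0.5 * 0.5 = 0.75 of the outputs are zero. Reusing a single mask leaves only about 0.5 zero. The pure-JAX sketch below checks this numerically. The `dropout` helper is hypothetical. It assumes only that a dropout mask is a Bernoulli draw determined by the key and the shape, so the same key always produces the same mask.

```python
import jax
import jax.numpy as jnp


def dropout(x, key, rate=0.5):
    # Hypothetical helper: inverted dropout with a Bernoulli keep-mask
    # drawn from `key`. Survivors are rescaled by 1 / (1 - rate).
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


x = jnp.ones((100_000,))
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Two independent masks: an entry is zero if either mask drops it.
y_indep = dropout(dropout(x, key_a), key_b)
# One shared mask: the second call drops exactly the same entries
# (the survivors are simply rescaled a second time).
y_shared = dropout(dropout(x, key_a), key_a)

frac_indep = float(jnp.mean(y_indep == 0.0))    # close to 0.75
frac_shared = float(jnp.mean(y_shared == 0.0))  # close to 0.5
```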
cgarciae commented 1 year ago

Hey @zaccharieramzi, I've converted the discussion into an issue, as it seems like something we should improve. I've created #3114, which would let you optionally specify the rng key for each `Dropout` layer, e.g.:

```python
import flax.linen as nn

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    # Draw one key (only needed when dropout is active) and share it.
    key = self.make_rng('dropout') if self.training else None
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
    return x
```

This way both layers will produce the same mask.

zaccharieramzi commented 1 year ago

Would there be a way to propagate this information rather than having to pass it to each dropout layer? In my case, I would need to call `key = self.make_rng("dropout")` and pass the key down to the actual dropout layers, which are nested deep inside different `nn.Module`s.

Something like:

```python
key = self.make_rng('dropout')
x = MyModule(...)(x, rng=key)
```

where `MyModule` does not originally have an `rng` parameter in its API.

zaccharieramzi commented 1 year ago

Of course, I understand it might be much more complex to do, so it's really just a question.

zaccharieramzi commented 1 year ago

@cgarciae I see that this was closed, so maybe you missed my earlier question. For example, in modules like `dot_product_attention`, the dropout is hardcoded with no way to set the rng. Do you think it's best, then, to reimplement all these modules so that the rng can be passed in?

chiamp commented 1 year ago

FYI @zaccharieramzi, I added a `dropout_rng` argument to `nn.MultiHeadDotProductAttention` in #3384, so you can get the same dropout mask.