cgarciae opened this issue 1 year ago
Hey @zaccharieramzi, I've converted the discussion into an issue as it seems like something we should improve.
I've created #3114, which would allow you to optionally specify the `rng` key for each `Dropout` layer, e.g.:
```python
import flax.linen as nn

# A simple network.
class MyModel(nn.Module):
    num_neurons: int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.num_neurons)(x)
        # Set the dropout layers with a rate of 50%.
        # When the `deterministic` flag is `True`, dropout is turned off.
        key = self.make_rng('dropout')
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x, rng=key)
        return x
```
This way both layers will produce the same mask.
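For reference, a minimal usage sketch (the shapes and PRNG seeds below are just placeholders): the `'dropout'` stream supplied via `rngs` is what `self.make_rng('dropout')` draws from, so both layers end up with the same mask.

```python
import jax
import jax.numpy as jnp

model = MyModel(num_neurons=3, training=True)
x = jnp.ones((1, 4))

# `make_rng('dropout')` also runs during init, so both streams are needed here.
variables = model.init(
    {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}, x)

# At apply time, supply the 'dropout' stream that the shared key is derived from.
y = model.apply(variables, x, rngs={'dropout': jax.random.PRNGKey(2)})
```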
Would there be a way to propagate this information rather than having to pass it around to each dropout?
Indeed, in my case I would need to do `key = self.make_rng("dropout")` and pass it down to the actual dropout layers, which are nested deep in different `nn.Module`s.
Something like:

```python
key = self.make_rng('dropout')
x = MyModule(...)(x, rng=key)
```

where originally `MyModule` does not have the `rng` parameter in its API.
Of course, I understand it might be much more complex to do, so it's really just a question.
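To make the plumbing concrete, here is a rough sketch of what that currently entails (`Inner`, `Outer`, and `Top` are hypothetical modules, not anything in Flax): every intermediate `__call__` has to grow an `rng` parameter so the key can reach the nested `Dropout` layers.

```python
import flax.linen as nn

class Inner(nn.Module):
    @nn.compact
    def __call__(self, x, rng=None):
        # Reuse the key handed down from above; with no key, dropout is disabled.
        return nn.Dropout(rate=0.5, deterministic=rng is None)(x, rng=rng)

class Outer(nn.Module):
    @nn.compact
    def __call__(self, x, rng=None):
        x = nn.Dense(8)(x)
        x = Inner()(x, rng=rng)  # rng must be forwarded explicitly
        x = Inner()(x, rng=rng)  # same key -> same dropout mask
        return x

class Top(nn.Module):
    @nn.compact
    def __call__(self, x):
        key = self.make_rng('dropout')
        return Outer()(x, rng=key)
```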
@cgarciae I see that this was closed, so maybe you missed my earlier question. Typically, in modules like `dot_product_attention`, the dropout is hardcoded without the possibility to set the rng. Do you think it's best, then, to reimplement all these modules with the possibility to pass the rng?
FYI @zaccharieramzi, I added a `dropout_rng` argument to `nn.MultiHeadDotProductAttention` in #3384 so you can get the same dropout mask.
Discussed in https://github.com/google/flax/discussions/3113