arthurflor23 opened 4 months ago
Thanks for the bug! Definitely seems plausible, the dropout code for RNNs changed substantially between Keras 2 and 3.
Do you have any minimal reproductions you could share?
Do you have any minimal reproductions you could share?
Unfortunately, I can't share any right now.
Just the snippet I mentioned:
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=128, dropout=0.5, return_sequences=True))
versus:
decoder1 = tf.keras.layers.LSTM(units=128, dropout=0.0, go_backwards=False, return_sequences=True)(decoder)
decoder1 = tf.keras.layers.Dropout(rate=0.5)(decoder1)
decoder2 = tf.keras.layers.LSTM(units=128, dropout=0.0, go_backwards=True, return_sequences=True)(decoder)
decoder2 = tf.keras.layers.Dropout(rate=0.5)(decoder2)
decoder = tf.keras.layers.Concatenate()([decoder1, decoder2])
Thanks! I'll see if I can repro with some simple sequence data. Hopefully it's easy enough to tell where the dropout mask is going wrong just by comparing the two.
@arthurflor23 actually a little more info would be useful. Can you tell if you are using the GPU fused ops for LSTMs?
Might need to step in with a debugger, or add a print in the function to see if it's actually getting run.
@fchollet it looks to me like these lines are not equivalent to the Keras 2 version, but I am not sure.
I see how that's trying to imitate a cached dropout that is the same at each time step. But as far as I can tell, in Keras 2 on GPU the dropout is just computed once for the sequence without this broadcasting trick. Am I reading this wrong?
@mattdangerw I think you're right -- Keras 2 applied input dropout randomly across timesteps (while using a temporally constant mask for recurrent dropout), whereas Keras 3 uses a temporally constant dropout mask for both input dropout and recurrent dropout. So there is a behavior difference.
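To make the difference concrete, here is a minimal sketch of the two masking strategies (the shapes and rate are just placeholders, not the actual implementation in either version):

import tensorflow as tf

batch, timesteps, features = 4, 10, 8
rate = 0.5
x = tf.random.normal((batch, timesteps, features))

# Keras 2 input dropout: an independent mask is drawn for every timestep.
per_step_mask = tf.cast(tf.random.uniform((batch, timesteps, features)) >= rate, x.dtype) / (1.0 - rate)

# Keras 3 input dropout: one mask per sequence, broadcast (held constant) across timesteps,
# which is how recurrent dropout was already handled in both versions.
constant_mask = tf.cast(tf.random.uniform((batch, 1, features)) >= rate, x.dtype) / (1.0 - rate)

x_per_step = x * per_step_mask   # dropped units change from step to step
x_constant = x * constant_mask   # the same units are dropped at every step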
The Keras 3 implementation is consistent with Deep Learning with Python, quoting:
the same dropout mask (the same pattern of dropped units) should be applied at every timestep, instead of a dropout mask that varies randomly from timestep to timestep. What's more, in order to regularize the representations formed by the recurrent gates of layers such as GRU and LSTM, a temporally constant dropout mask should be applied to the inner recurrent activations of the layer (a recurrent dropout mask). Using the same dropout mask at every timestep allows the network to properly propagate its learning error through time; a temporally random dropout mask would disrupt this error signal and be harmful to the learning process.
If you want to replicate the Keras 2 behavior you can simply not pass dropout= to the LSTM layer, and instead add a Dropout layer before the LSTM layer.
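Roughly like this (the layer sizes are placeholders, just to show the substitution):

import tensorflow as tf

inputs = tf.keras.Input(shape=(None, 64))

# Keras 3 behavior: dropout= inside the layer, mask held constant across timesteps
keras3_style = tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True)(inputs)

# Keras 2-style behavior: a Dropout layer before the LSTM, mask varies per timestep
x = tf.keras.layers.Dropout(0.5)(inputs)
keras2_style = tf.keras.layers.LSTM(128, return_sequences=True)(x)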
If you think we should change the Keras 3 implementation, let me know.
@arthurflor23 actually a little more info would be useful. Can you tell if you are using the GPU fused ops for LSTMs?
Might need to step in with a debugger, or add a print in the function to see if it's actually getting run.
@mattdangerw I added a print in the functions and both are called in training (keras and tf-keras).
If you want to replicate the Keras 2 behavior you can simply not pass dropout= to the LSTM layer, and instead add a Dropout layer before the LSTM layer.
I couldn't replicate the expected results in Keras 3 using Dropout before the BLSTM, nor before separate forward and backward LSTMs. However, somehow using TimeDistributed(Dropout) before the BLSTM/LSTM worked; the tradeoff is the increased computational cost. Does that make any sense?
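Concretely, what worked looked something like this (the input shape here is a placeholder):

import tensorflow as tf

# `decoder` stands in for the 3D tensor feeding the recurrent block
decoder = tf.keras.Input(shape=(100, 64))

x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dropout(rate=0.5))(decoder)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=128, return_sequences=True))(x)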
Anyway, I understand that the current implementation is correct under the concept presented above, so let me know what you think @mattdangerw; if so, I can close the issue.
I finally managed to replicate the same behavior between Keras versions. For anyone who ends up here, two key points: apply dropout before each LSTM layer (both forward and backward), and flip the output of the backward pass along the time axis after the call (I didn't know that was needed):
forwards = tf.keras.layers.Dropout(rate=0.5)(decoder)
forwards = tf.keras.layers.LSTM(units=128, return_sequences=True, go_backwards=False)(forwards)
backwards = tf.keras.layers.Dropout(rate=0.5)(decoder)
backwards = tf.keras.layers.LSTM(units=128, return_sequences=True, go_backwards=True)(backwards)
backwards = tf.keras.ops.flip(backwards, axis=1)
decoder = tf.keras.layers.Concatenate(axis=-1)([forwards, backwards])
I tried simplifying by placing a dropout before the Bidirectional LSTM layer, but it resulted in different behavior. So, it's necessary to include dropouts before both the forward and backward LSTM layers.
Here, I implemented a custom Bidirectional layer with an additional dropout parameter that applies dropout before each layer. With this, it preserves both dropout approaches: within the LSTM layer (as in Keras 3) and before the LSTM layer (as in Keras 2).
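The idea is roughly the following (a simplified sketch of the approach, not the exact code linked above; the class name and sizes are placeholders):

import tensorflow as tf

class BidirectionalWithInputDropout(tf.keras.layers.Layer):
    """Sketch: applies Dropout before each direction, then concatenates both outputs."""

    def __init__(self, units, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.forward_dropout = tf.keras.layers.Dropout(dropout)
        self.backward_dropout = tf.keras.layers.Dropout(dropout)
        self.forward_lstm = tf.keras.layers.LSTM(units, return_sequences=True, go_backwards=False)
        self.backward_lstm = tf.keras.layers.LSTM(units, return_sequences=True, go_backwards=True)
        self.concat = tf.keras.layers.Concatenate(axis=-1)

    def call(self, inputs, training=None):
        forward = self.forward_lstm(self.forward_dropout(inputs, training=training))
        backward = self.backward_lstm(self.backward_dropout(inputs, training=training))
        # go_backwards=True returns the sequence reversed in time, so flip it back
        backward = tf.keras.ops.flip(backward, axis=1)
        return self.concat([forward, backward])

# usage: outputs = BidirectionalWithInputDropout(128, dropout=0.5)(decoder)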
Another suggestion would be to add a noise_shape=None parameter to the LSTM layer, just like the Dropout layer API has, allowing flexibility in how the binary dropout mask is shaped.
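For context, this is what noise_shape already controls on the Dropout layer; a 1 on the time axis produces the temporally constant mask (the shapes here are placeholders):

import tensorflow as tf

batch, timesteps, features = 4, 10, 64
x = tf.random.normal((batch, timesteps, features))

# Default: an independent mask for every (timestep, feature) entry
per_step = tf.keras.layers.Dropout(rate=0.5)(x, training=True)

# noise_shape with 1 on the time axis: one mask per sequence, broadcast over all timesteps,
# i.e. the temporally constant behavior that Keras 3 bakes into dropout= on the LSTM
constant = tf.keras.layers.Dropout(rate=0.5, noise_shape=(batch, 1, features))(x, training=True)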
Hi,
As mentioned here, a drop in model results was observed in the new version of TensorFlow using the new Keras. I tracked some tests using tensorflow with keras and with tf-keras here as well.
I'm using a CNN+BLSTM+CTC model for handwriting recognition tests, using only the following APIs:
So I expanded the tests again and tried to find the root of the problem by looking for the differences in default parameters between keras and tf.keras. Below are the tests I did.
Recently, I found some implementation differences regarding the LSTM layer dropout (Keras 3 vs Keras 2). So I tested some variations using tensorflow with keras 3:
So, using dropout outside the LSTM layer, the model achieves the expected results.