carlthome / tensorflow-convlstm-cell

A ConvLSTM cell with layer normalization and peepholes for TensorFlow's RNN API.
MIT License

Incorrect recurrent initialization leads to exploding gradients #33

Open JohnMBrandt opened 4 years ago

JohnMBrandt commented 4 years ago

Keras's implementation of Convolutional LSTM (https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D) uses recurrent_initializer='orthogonal' by default to avoid exploding gradients.

This implementation uses the default initializer for the recurrent weight matrix, which is Glorot uniform. Long time series will have very bad exploding gradients when compared to the Keras layer.

This can be solved with: W = tf.get_variable('kernel', self._kernel + [n, m], initializer=tf.initializers.orthogonal())
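The intuition behind this suggestion can be illustrated with a heavily simplified NumPy sketch. It assumes a plain linear recurrence with a dense square matrix (no convolution, no gates, no nonlinearity), and the hidden size and step count are hypothetical. It only shows that repeated multiplication by an orthogonal matrix preserves the norm of a backpropagated vector, whereas a Glorot-uniform matrix gives no such guarantee. It is not a reproduction of the behavior reported in this issue.

```python
import numpy as np

rng = np.random.default_rng(0)
n, steps = 64, 30  # hypothetical hidden size and sequence length

# Glorot uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
limit = np.sqrt(6.0 / (n + n))
w_glorot = rng.uniform(-limit, limit, size=(n, n))

# Orthogonal matrix taken from the QR decomposition of a Gaussian matrix.
w_orth, _ = np.linalg.qr(rng.standard_normal((n, n)))


def backprop_norms(w, steps):
    """Norm of a unit vector after each repeated multiplication by w.T."""
    g = np.ones(w.shape[0]) / np.sqrt(w.shape[0])
    norms = []
    for _ in range(steps):
        g = w.T @ g
        norms.append(np.linalg.norm(g))
    return np.array(norms)


glorot_norms = backprop_norms(w_glorot, steps)
orth_norms = backprop_norms(w_orth, steps)

# The orthogonal norms stay at 1; the Glorot norms depend on the random draw.
print("glorot, last step:    ", glorot_norms[-1])
print("orthogonal, last step:", orth_norms[-1])
```

In a real ConvLSTM the gates and nonlinearities change the picture, so this is only a motivation for trying the initializer, not evidence that it fixes the problem.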

carlthome commented 4 years ago

Interesting find! I'm not entirely convinced that exploding gradients are a risk for convolution kernels though, particularly if this output scaling is used. The orthogonal initialization paper (https://arxiv.org/abs/1312.6120) doesn't have a straight answer as far as I can see.

"Long time series will have very bad exploding gradients when compared to the Keras layer."

Do you have any empirical evidence of this? Would be very cool to see!

I find it a bit hard to reason about orthogonal initialization for convolution kernels (https://hjweide.github.io/orthogonal-initialization-in-convolutional-layers) and even harder for the recurrent case. What precisely should be orthonormal? Feels like what we want is to encourage channels to not just be linear combinations of each other (redundant information). Is that what happens by configuring recurrent_initializer=orthogonal? Guess this is the relevant source: https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/init_ops_v2.py#L448-L515
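One way to make the "what should be orthonormal?" question concrete is a NumPy sketch of the usual QR-based construction, which the linked initializer appears to follow: flatten every dimension except the last, orthogonalize that 2-D matrix, and reshape. The function name `orthogonal_kernel` and the kernel sizes are hypothetical, and this is an approximation of the approach rather than the TensorFlow code itself.

```python
import numpy as np


def orthogonal_kernel(shape, gain=1.0, seed=0):
    """QR-based orthogonal init for a kernel shaped kernel_size + [in_ch, out_ch]."""
    rng = np.random.default_rng(seed)
    rows = int(np.prod(shape[:-1]))  # kh * kw * in_ch
    cols = shape[-1]                 # out_ch
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)           # reduced QR: q has orthonormal columns
    q *= np.sign(np.diag(r))         # remove the sign ambiguity of the QR factors
    if rows < cols:
        q = q.T                      # orthonormal rows instead of columns
    return gain * q.reshape(shape)


# Hypothetical sizes: 3x3 kernel, 8 input channels, 16 output channels.
w = orthogonal_kernel([3, 3, 8, 16])

flat = w.reshape(-1, w.shape[-1])        # shape (72, 16)
gram_flat = flat.T @ flat                # identity: flattened filters are orthonormal

center_tap = w[1, 1]                     # shape (8, 16): one spatial offset
gram_tap = center_tap.T @ center_tap     # rank <= 8, so it cannot be the identity
```

Under these assumptions, orthonormality holds for whole flattened filters (all spatial offsets and input channels together), not for the channel-mixing matrix at any single spatial offset. That is a weaker property than an orthogonal recurrent matrix in a dense LSTM.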

JohnMBrandt commented 4 years ago

A few thoughts --

The original layer norm paper advises against applying it to CNNs:

"With fully connected layers, all the hidden units in a layer tend to make similar contributions to the final prediction and re-centering and re-scaling the summed inputs to a layer works well. However, the assumption of similar contributions is no longer true for convolutional neural networks. The large number of the hidden units whose receptive fields lie near the boundary of the image are rarely turned on and thus have very different statistics from the rest of the hidden units within the same layer."

I understand that layer normalization works differently in the convolutional recurrent case, but it still implicitly assumes that all time steps are equally important by setting their mean and standard deviation to be the same. This makes sense for temporal sequences where objects do not appear or disappear. However, I have found that applying layer norm to a ConvLSTM on a segmentation task reduces performance, which I think is due to this recalibration of the feature map statistics.
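The recalibration effect can be seen in a small NumPy sketch. It assumes normalization over all height, width, and channel positions of a single sample and omits the learned gain and bias. The two hidden states are hypothetical stand-ins for a time step where an object is present and one where it is absent.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalize one sample's feature map to zero mean and unit variance."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)


rng = np.random.default_rng(0)

# Hypothetical hidden states of shape (height, width, channels).
h_present = 5.0 + 3.0 * rng.standard_normal((8, 8, 4))  # strong activations
h_absent = 0.1 * rng.standard_normal((8, 8, 4))         # weak activations

for name, h in (("present", h_present), ("absent", h_absent)):
    y = layer_norm(h)
    print(name, "before:", h.mean(), h.std(), "after:", y.mean(), y.std())
```

After normalization both time steps have roughly zero mean and unit standard deviation, so the difference in overall activation strength between them is no longer visible to later layers unless the learned gain and bias restore it.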

However, it did reduce gradient explosion, and removing layer norm caused my network to have gradient explosion with 30 time steps -- to the point where I had to clip the norm of the gradients to 0.1 for stable learning. After swapping the initialization to orthogonal, I have been able to relax the gradient clipping to 5.0 with no gradient explosion, and segmentation works as expected without layer norm.
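For reference, here is a NumPy sketch of clipping by global norm, the scheme implemented by `tf.clip_by_global_norm`. The comment above does not say which clipping variant was used, so global-norm clipping is an assumption here, and the gradient values are made up for illustration.

```python
import numpy as np


def clip_by_global_norm(grads, clip_norm):
    """Scale all gradients so their joint L2 norm is at most clip_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)  # equals 1 when already small enough
    return [g * scale for g in grads], global_norm


# Hypothetical gradients for two variables; joint norm is sqrt(40).
grads = [np.full((3, 3), 2.0), np.full((4,), 1.0)]

clipped, norm_before = clip_by_global_norm(grads, clip_norm=5.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, norm_after)
```

A threshold of 0.1 rescales almost every update to a very small step, while 5.0 only intervenes on unusually large gradients, which is why being able to relax the threshold suggests the raw gradients have become better behaved.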

Again this may just be because I have a dataset that responds poorly to layer norm (objects disappear and reappear within one sample).

There are also some newer networks that seem to corroborate the notion that layer norm is bad for ConvLSTM because it incorrectly assumes that all hidden units are equally important. Attention ConvLSTM (https://papers.nips.cc/paper/7465-attention-in-convolutional-lstm-for-gesture-recognition.pdf) explicitly reweights the recurrent feature maps at each time step with something that resembles squeeze and excitation. This encourages the feature maps to have very different statistics at each time step so that the network can identify, based on the input, how significant the time step is.
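As a rough illustration of what squeeze-and-excitation-style reweighting means, here is a generic NumPy sketch. It is not the architecture from the linked paper. The function name `channel_reweight`, the bottleneck size, and the untrained random weights are all hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def channel_reweight(h, w1, w2):
    """Squeeze-and-excitation-style gating of a feature map h of shape (H, W, C)."""
    squeezed = h.mean(axis=(0, 1))            # squeeze: global average pool, shape (C,)
    hidden = np.maximum(0.0, squeezed @ w1)   # excitation: bottleneck with ReLU
    gates = sigmoid(hidden @ w2)              # per-channel weights in (0, 1)
    return h * gates, gates                   # gates broadcast over H and W


rng = np.random.default_rng(0)
h = rng.standard_normal((8, 8, 16))           # hypothetical hidden state at one time step
w1 = 0.1 * rng.standard_normal((16, 4))       # hypothetical, untrained weights
w2 = 0.1 * rng.standard_normal((4, 16))

h_scaled, gates = channel_reweight(h, w1, w2)
print(gates.round(3))
```

Because the gates are computed from the input at each step, the scale of each channel can differ from step to step, which is the opposite of forcing every step to the same mean and standard deviation.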