Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License
938 stars 314 forks

How can I learn meaningful steps? #407

Open ChenSongle opened 7 years ago

ChenSongle commented 7 years ago

How can I learn a meaningful action model? I have applied the recurrent attention model (RAM) to video classification. In this application, the input image is replaced by a video. Each video contains 50 frames, and there are about 5000 training samples. Before training, I use VGG16 to extract a feature for each frame, which is a 4096-dimensional vector. Based on the theory of RAM, I expected it to select several frames and then emit a prediction. However, I found that RAM can't select meaningful glimpses (i.e., frames). After training, the training accuracy reaches 99.96%, but when I apply the trained model to the validation data, at every step except the first it always selects the first (or last) frame, regardless of the class the input video belongs to. As a result, the validation accuracy is poor. I traced the training procedure and found two factors that lead to this result:
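To make the symptom concrete, here is a minimal sketch, assuming the 1-D location emitted by the locator is mapped linearly onto the 50 frames. The helper `location_to_frame` and this mapping are hypothetical, not taken from the rnn library or from the author's code. The sketch shows why a location that has drifted to or past the boundary always yields the first or last frame.

```python
NUM_FRAMES = 50  # each video in the issue has 50 frames

def location_to_frame(loc, num_frames=NUM_FRAMES):
    """Map a 1-D glimpse location to a frame index (hypothetical mapping).

    The location is clamped to [-1, 1]; -1 maps to the first frame,
    +1 to the last frame, and values in between are spaced linearly.
    """
    loc = max(-1.0, min(1.0, loc))
    return int(round((loc + 1.0) / 2.0 * (num_frames - 1)))

# Any location at or beyond the boundary lands on the same frame, so a
# locator whose output sits outside [-1, 1] always picks frame 0 or 49.
for loc in (-2.5, -1.0, 0.5, 1.0, 3.7):
    print(loc, "->", location_to_frame(loc))
```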

  1. The gradient of the RNN hidden state comes from both the action reward (also treated as a loss) and the classification loss. In my training, the classification task overfits easily, so the action model has no chance to learn a good policy before the classifier overfits the data. I changed the reward scale and other parameters, but the problem still exists. Following others' experience, during the first 100 epochs I reset the parameters of the fully connected layer before the softmax classification layer every 10 epochs, but this strategy still doesn't solve my problem.
  2. Another problem: the HardTanh layer is used to constrain the input of ReinforceNormal to [-1, 1]. Unfortunately, if the input of HardTanh falls outside [-1, 1], the gradient is lost, so the weights of the action model cannot be learned. I replaced HardTanh with a sigmoid, and also tried removing the HardTanh layer, but the problem still exists and the action model still can't learn the right weights. This could explain why, on the validation data, the trained model selects the first (or last) frame at every step except the first, regardless of the class of the input video. I also observed the training procedure of RAM on the MNIST dataset: in the first several epochs the action gradient exists; after a number of epochs it is lost as well; but after a while it comes back. Therefore, on MNIST the action model is trained correctly.
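The interaction between HardTanh and the location gradient described in both points can be sketched in plain Python. This is a generic REINFORCE (score-function) gradient for a Gaussian policy with a fixed standard deviation, written for illustration only; the function names, the `reward_scale` factor and the baseline value are assumptions, and this is not the source of the library's `ReinforceNormal`.

```python
def hardtanh(x):
    """Forward pass: clamp x to [-1, 1]."""
    return max(-1.0, min(1.0, x))

def hardtanh_grad(x, grad_out):
    """Backward pass: pass grad_out only where -1 < x < 1, otherwise return 0."""
    return grad_out if -1.0 < x < 1.0 else 0.0

def gaussian_reinforce_grad_mean(sample, mean, std, reward, baseline, reward_scale=1.0):
    """d(loss)/d(mean) for loss = -reward_scale * (reward - baseline) * log N(sample; mean, std^2).

    Uses d log N / d mean = (sample - mean) / std**2 (generic REINFORCE, assumed form).
    """
    advantage = reward_scale * (reward - baseline)
    return -advantage * (sample - mean) / std ** 2

def locator_grad(pre_activation, sample, std, reward, baseline, reward_scale=1.0):
    """Gradient reaching the (hypothetical) linear layer feeding HardTanh -> Gaussian policy."""
    mean = hardtanh(pre_activation)
    g_mean = gaussian_reinforce_grad_mean(sample, mean, std, reward, baseline, reward_scale)
    return hardtanh_grad(pre_activation, g_mean)

# Inside [-1, 1]: the reward signal flows back to the locator weights.
print(locator_grad(0.3, sample=0.5, std=0.5, reward=1.0, baseline=0.5))
# Saturated: even a large reward scale produces no gradient at all.
print(locator_grad(1.7, sample=0.8, std=0.5, reward=1.0, baseline=0.5, reward_scale=100.0))
```

With the pre-activation at 0.3 the gradient is nonzero; at 1.7 the same kind of sample and reward give exactly zero, so in this sketch no amount of reward scaling reaches the weights below HardTanh. A smooth squashing function such as tanh avoids the exact zero, but its derivative still shrinks as the input grows.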

    I guess:
    1) Using a whole frame as a glimpse may be too much, making the classification model easy to overfit, considering that the RAM glimpse on MNIST is only 8×8 while the whole image is 28×28.
    2) The features extracted by VGG16 may not be discriminative or suitable.
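Guess 1 compares the glimpse to the whole input. A quick back-of-the-envelope calculation, using only the numbers given in this issue (8×8 glimpses on 28×28 MNIST images; 50 frames per video with a 4096-dimensional VGG16 feature per frame), puts the two settings side by side; the helper name `fraction` is illustrative.

```python
def fraction(part, whole):
    """Share of the input covered by one glimpse."""
    return part / whole

mnist_pixels_per_glimpse = 8 * 8                     # 64 raw pixel values
mnist_share = fraction(mnist_pixels_per_glimpse, 28 * 28)
video_share = fraction(1, 50)                        # one frame out of 50
frame_feature_dim = 4096                             # VGG16 feature per frame

print(f"MNIST glimpse: {mnist_pixels_per_glimpse} values, {mnist_share:.1%} of the image")
print(f"Video glimpse: {frame_feature_dim} values, {video_share:.1%} of the frames")
print(f"values per glimpse, video vs. MNIST: {frame_feature_dim // mnist_pixels_per_glimpse}x")
```

By share of the input, one frame is smaller than an MNIST glimpse; by the number of values handed to the RNN, it is 64 times larger, and each value is a feature computed over the whole frame rather than a raw pixel.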

    How can I learn meaningful steps? Please give me some suggestions. Thank you very much.