craffel / rnntools

RNN layers for use with nntools

LSTM training dynamics #2

Open Shumarci opened 9 years ago

Shumarci commented 9 years ago

First off, thanks for putting this together; it's a great starting point for something I'm trying to do with LSTMs. I also felt like the training dynamics are a little odd. Is it possible that this is a result of not resetting h (and c for LSTMs) for each new sequence? I couldn't find where in the code these would be reset, but keeping the hidden state across sequences seems like it would cause undesired effects: the network would basically have to learn to identify the start of a new sequence and reset itself via the forget gate, which could explain why it works but the training dynamics are a little off. I'm pretty new to Theano, so I may very well be misreading the code; I just thought I'd mention it in case it is the problem.

craffel commented 9 years ago

Hey, you're right that the training dynamics are odd. This code is out of date, however; please refer to https://github.com/craffel/nntools/blob/master/nntools/layers/base.py#L345 instead. In the most recent code, c and h are both reset to c_init and h_init at the beginning of each sequence (because of the semantics of scan). They aren't learned, which is what Alex Graves recommends in his thesis.
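
For reference, here's a minimal sketch of why scan gives you the reset for free (made-up names, not the actual layer code): whatever is passed via outputs_info is used as the state at the first time step, and it isn't carried over between calls to the compiled function, so every sequence starts from c_init/h_init.

    import numpy as np
    import theano
    import theano.tensor as T

    # Toy recurrence standing in for the LSTM update; the point is only the
    # handling of the initial state, not the recurrence itself.
    X = T.matrix('X')  # one sequence, shape (n_steps, n_features)
    h_init = theano.shared(np.zeros(5, dtype=theano.config.floatX), name='h_init')
    c_init = theano.shared(np.zeros(5, dtype=theano.config.floatX), name='c_init')
    W = theano.shared(np.random.randn(3, 5).astype(theano.config.floatX), name='W')

    def step(x_t, c_tm1, h_tm1):
        c_t = c_tm1 + T.dot(x_t, W)
        h_t = T.tanh(c_t + h_tm1)
        return c_t, h_t

    # outputs_info supplies the state used at the first step; it is used fresh
    # on every call of the compiled function rather than being carried across
    # sequences.
    (c, h), _ = theano.scan(step, sequences=X, outputs_info=[c_init, h_init])
    f = theano.function([X], h[-1])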

Shumarci commented 9 years ago

I think I might have found the bug. Is it possible that:

c_t = f_t c_{t-1} + i_t \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)

        cell = (forget_gate*previous_cell +
                input_gate*self.nonlinearity_cell(
                    T.dot(layer_input, W_input_to_cell) +
                    T.dot(previous_cell, W_hidden_to_cell) +
                    b_cell))

should be:

        cell = (forget_gate*previous_cell +
                input_gate*self.nonlinearity_cell(
                    T.dot(layer_input, W_input_to_cell) +
                    T.dot(previous_output, W_hidden_to_cell) +
                    b_cell))

That seems to match the description in the comment (and the Graves paper), and it also seems to fix some of the oddities I was seeing.

There's also a typo in one of the comments:

o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_{t-1} + b_o)

The c_{t-1} should be c_t. It's correct in the code.
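
Putting the two together, this is how I read the Graves formulation. It's just a sketch with made-up variable names (not the actual rnntools layer), and I'm treating the peephole weights as diagonal, hence the elementwise products with the cell state:

    import theano.tensor as T

    def lstm_step(x_t, c_tm1, h_tm1,
                  W_xi, W_hi, W_ci, b_i, W_xf, W_hf, W_cf, b_f,
                  W_xc, W_hc, b_c, W_xo, W_ho, W_co, b_o):
        i_t = T.nnet.sigmoid(T.dot(x_t, W_xi) + T.dot(h_tm1, W_hi) +
                             W_ci * c_tm1 + b_i)
        f_t = T.nnet.sigmoid(T.dot(x_t, W_xf) + T.dot(h_tm1, W_hf) +
                             W_cf * c_tm1 + b_f)
        # Cell update: the recurrent term uses the previous output h_{t-1},
        # not the previous cell c_{t-1}
        c_t = (f_t * c_tm1 +
               i_t * T.tanh(T.dot(x_t, W_xc) + T.dot(h_tm1, W_hc) + b_c))
        # Output gate peeks at the *current* cell state c_t
        o_t = T.nnet.sigmoid(T.dot(x_t, W_xo) + T.dot(h_tm1, W_ho) +
                             W_co * c_t + b_o)
        h_t = o_t * T.tanh(c_t)
        return c_t, h_t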

-Carsten


craffel commented 9 years ago

Oh man, you're definitely right! Thanks for spotting this. I feel like I went over this code a dozen times, but I guess it's hard to spot something like this in your own code. Training still seems slow to converge, though, with a large learning rate on the problem I'm trying it on (compared to rnnlib and currennt). So there may be other issues... or I may just need to tweak learning rates/etc. more.

Shumarci commented 9 years ago

My guess was that the gradient clipping that's mentioned everywhere is also important, so I tried adding that. I'm having a little more luck now, but I'm also trying it on a tough problem, so it's hard to judge. I'm going to keep digging :)
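
Roughly what I added looks like this (made-up helper name, not an existing nntools API): rescale the gradients whenever their global norm goes above a threshold, then do plain SGD on the rescaled gradients.

    import theano.tensor as T

    def sgd_with_grad_clipping(cost, params, learning_rate=1.0, max_norm=5.0):
        grads = T.grad(cost, params)
        # Global L2 norm across all parameter gradients
        grad_norm = T.sqrt(sum(T.sum(g ** 2) for g in grads))
        # Rescale so the norm never exceeds max_norm (the threshold is a
        # guess and probably needs tuning per problem)
        scale = T.minimum(1.0, max_norm / (grad_norm + 1e-7))
        return [(p, p - learning_rate * scale * g)
                for p, g in zip(params, grads)]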


craffel commented 9 years ago

The problem I'm trying it on is here: https://github.com/craffel/lstm_benchmarks. It includes code for doing what should be exactly the same thing in rnnlib, currennt, and nntools. The training dynamics are now actually much better; the main glaring difference is that rnnlib/currennt use a learning rate of 1e-5, but I can only get decently fast convergence with a learning rate of 1 or so, which is huge. I'm also not using a BLSTM (yet) in nntools... I need to implement that before actually making a comparison. If you find anything else, please let me know! Thanks again!