lmjohns3 / theanets

Neural network toolkit for Python
http://theanets.rtfd.org
MIT License

lstm example and some changes to softmax #56

Closed hknerdgn closed 9 years ago

lmjohns3 commented 9 years ago

This looks nice! I hope you don't mind if I make a few style changes and update the example to work with py3k.

hknerdgn commented 9 years ago

Definitely, feel free to do so. I was hoping you would make some style changes!

Currently, I have trouble getting a multi-layer LSTM to converge as fast as the currennt software does on the same example. One issue is that I do not handle the sequences correctly: I concatenate all sequences together, which is not the right thing to do, although it should not affect the results too much. Also, during validation I do not separate the sequences but concatenate them all together.

I think it would help to be able to handle different sequence lengths. I need to write code that ignores the frames (or steps) at the end of a sequence that do not correspond to real data. This can be done with a mask parameter for the targets. I think this warrants special handling in theanets, since unequal sequence lengths are always going to occur.
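A minimal NumPy sketch of the mask idea, assuming padded predictions of shape (time, batch, classes), integer targets of shape (time, batch), and a 0/1 mask of the same shape marking real frames. The function name is hypothetical; this is not a theanets API.

```python
import numpy as np

def masked_cross_entropy(probs, targets, mask, eps=1e-12):
    """Mean negative log-likelihood over real (unpadded) frames only.

    probs:   (time, batch, classes) predicted class probabilities
    targets: (time, batch) integer class labels (values at padded steps are ignored)
    mask:    (time, batch) 1.0 for real frames, 0.0 for padding
    """
    T, B, _ = probs.shape
    # Pick the probability assigned to the target class at every (time, batch) frame.
    t_idx, b_idx = np.meshgrid(np.arange(T), np.arange(B), indexing="ij")
    picked = probs[t_idx, b_idx, targets]
    nll = -np.log(np.clip(picked, eps, 1.0))
    # Zero out padded frames, then average over the real ones only.
    return float((nll * mask).sum() / mask.sum())

# Two sequences of lengths 3 and 1, padded to 3 time steps.
probs = np.full((3, 2, 2), 0.5)
probs[:, 1, :] = [0.9, 0.1]
targets = np.zeros((3, 2), dtype=int)
mask = np.array([[1, 1], [1, 0], [1, 0]], dtype=float)
loss = masked_cross_entropy(probs, targets, mask)
```

Because padded frames are multiplied by zero and excluded from the denominator, changing the network's output on padding steps leaves the loss (and hence the gradient) unchanged.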

currennt converges in fewer than 100 epochs to about 20% error on the validation data using a 3-layer bidirectional LSTM setup.

I also tried currennt with a unidirectional LSTM and layers (39, 100, 200, 76, 51); it converges to 28% error in fewer than 70 epochs.

currennt uses NAG (Nesterov accelerated gradient) for training, and its objective function uses the sum of the negative log-likelihoods instead of their mean, as theanets does. So I tried to adjust the learning rate in theanets accordingly.
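The learning-rate adjustment follows from linearity: if the summed loss over N terms (e.g. frames per minibatch) is L_sum = N * L_mean, its gradient is N times larger, so a learning rate lr under the summed loss corresponds to lr * N under the mean loss. This holds for plain SGD and for NAG, since both updates are linear in the gradient. A NumPy sketch on a toy least-squares problem (the data and loss are assumptions for illustration, not the CHiME setup):

```python
import numpy as np

def nag_step(w, v, grad_fn, lr, mu=0.9):
    """One Nesterov accelerated gradient step (look-ahead formulation)."""
    g = grad_fn(w + mu * v)          # gradient evaluated at the look-ahead point
    v_new = mu * v - lr * g
    return w + v_new, v_new

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))         # 50 frames, 3 features (toy data)
y = rng.normal(size=50)
N = len(y)

def grad_mean(w):                    # gradient of 0.5 * mean squared error
    return X.T @ (X @ w - y) / N

def grad_sum(w):                     # gradient of 0.5 * summed squared error
    return X.T @ (X @ w - y)

lr_sum = 1e-3
w_a, v_a = np.zeros(3), np.zeros(3)
w_b, v_b = np.zeros(3), np.zeros(3)
for _ in range(20):
    w_a, v_a = nag_step(w_a, v_a, grad_sum, lr_sum)
    w_b, v_b = nag_step(w_b, v_b, grad_mean, lr_sum * N)   # scale lr by N
```

Both runs follow the same trajectory up to floating-point rounding, so the two objectives differ only in how the learning rate has to be chosen.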

However, in theanets I only reached 45% error after 400+ epochs, which is worse than a single-layer LSTM. Next, I will try layerwise training, which I think should work much better, since a single layer works.

Any ideas about why this may be happening?

Hakan

lmjohns3 commented 9 years ago

Well, I checked in a rather large pruning-down of the code in your example. Again, I hope you don't mind.

I expanded the size of the LSTM layer to 500 but didn't see a huge improvement in performance. I'm not sure why this might be happening, but it might be a good thing to discuss on the mailing list?

Also, I see what you mean about supporting masks for network outputs. I've filed an issue to keep track of this: #58.

lmjohns3 commented 9 years ago

Also, when you say that you tried currennt with layers (39, 100, 200, 76, 51), what does that mean? Is that the number of units in each layer of a network with an input, 3 hidden layers, and an output? Are all of those layers LSTM, or just one of them?

You can create a network with many layers in theanets by giving the layers tuple (39, 100, 200, 76, 51). You can then specify which of those layers should be recurrent by providing recurrent_layers=(1, 3) (for example, this would make the 100- and 76-unit layers recurrent).
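A plain-Python sketch of the indexing convention described above; it does not call theanets. Index 0 is the input, the last entry is the output, and recurrent_layers selects hidden layers by their position in the tuple. The describe helper is hypothetical.

```python
def describe(layers, recurrent_layers):
    """Return (index, size, role) for each entry of a layers tuple."""
    roles = []
    for i, size in enumerate(layers):
        if i == 0:
            kind = "input"
        elif i == len(layers) - 1:
            kind = "output"
        elif i in recurrent_layers:
            kind = "recurrent hidden"
        else:
            kind = "feedforward hidden"
        roles.append((i, size, kind))
    return roles

layers = (39, 100, 200, 76, 51)
one_and_three = describe(layers, recurrent_layers=(1, 3))
all_hidden = describe(layers, recurrent_layers=(1, 2, 3))
```

For a stack in which every hidden layer is recurrent, the tuple would be (1, 2, 3).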

hknerdgn commented 9 years ago

Yes, I realized yesterday that I was doing it wrong. In currennt, there are actually 3 hidden LSTM layers stacked on top of each other.

In theanets, I was using only one LSTM layer in the middle. I had assumed the two setups were equivalent, but yesterday I found out they were not.

I need to correct this and re-run it later. I guess that may be the reason for the discrepancy in the error rates.

Also, a reverse LSTM can be implemented using the go_backwards argument of scan. The outputs of one forward and one backward LSTM layer can then be concatenated in another layer on top of them to obtain a bidirectional LSTM. Alternatively, one could write a dedicated bidirectional LSTM layer. The mask is also very important for the bidirectional case.
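A NumPy sketch of the bidirectional construction, using a plain tanh RNN instead of an LSTM to keep it short (a simplification; the gating does not change the wiring). The backward pass runs over the time-reversed input, as scan(..., go_backwards=True) would, and its outputs are flipped back so that row t of both directions refers to the same time step before concatenation.

```python
import numpy as np

def rnn(x, W, U, b):
    """Simple tanh RNN over x of shape (time, dim_in); returns (time, dim_hid)."""
    h = np.zeros(U.shape[0])
    out = []
    for x_t in x:
        h = np.tanh(x_t @ W + h @ U + b)
        out.append(h)
    return np.stack(out)

def bidirectional(x, fwd_params, bwd_params):
    h_fwd = rnn(x, *fwd_params)
    # Run over the reversed sequence, then reverse the outputs again so that
    # row t holds the backward state for time step t.
    h_bwd = rnn(x[::-1], *bwd_params)[::-1]
    return np.concatenate([h_fwd, h_bwd], axis=-1)

rng = np.random.default_rng(1)
dim_in, dim_hid, T = 4, 3, 6

def make_params():
    return (rng.normal(size=(dim_in, dim_hid)),
            rng.normal(size=(dim_hid, dim_hid)) * 0.1,
            np.zeros(dim_hid))

fwd, bwd = make_params(), make_params()
x = rng.normal(size=(T, dim_in))
h = bidirectional(x, fwd, bwd)
```

With padded batches, the backward pass would start on the padding frames of shorter sequences, which is why a mask matters even more here than in the unidirectional case.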

Thanks.

Hakan

hknerdgn commented 9 years ago

There seems to be a problem with training deep networks with more than one hidden layer, including feedforward ones. I experimented with the mnist_classifier and lstm_chime examples. This may be expected due to vanishing gradients, but even layerwise training does not seem to help. For the lstm_chime example, we know currennt works with multiple hidden layers, and it uses plain stochastic gradient descent with momentum without further tricks. In theanets, I almost always get better results with shallow networks than with deep ones. It is very hard to see what the problem may be, though. I am not sure whether the problem is in Theano or in theanets. Any ideas?

Hakan

lmjohns3 commented 9 years ago

Yeah, not sure what might be the problem here. As far as I can tell the optimization code in theanets is correct (for what it's worth, the optimization code in benanne/Lasagne is basically the same). I unfortunately have a hard time figuring out how currennt works -- I've pored over the lstm code there and it looks to me like it basically shares the same implementation as theanets.

Also, I've noticed that the mnist-autoencoder.py example reaches about the same level of performance with anywhere between 1 and 4 hidden layers (using the rmsprop optimizer), so it does seem like the trainers are working.

But now I am starting to wonder: maybe the problem is in the cross entropy cost or the softmax that defines the output of the classifier networks? I'll see if I can poke a little at that tonight.

kastnerkyle commented 9 years ago

These optimizers are pretty hard to get right. In my implementation of LSTMs I needed gradient clipping, max_col_norm, and a bunch of other small tricks. It is a work in progress here, if you are curious: https://github.com/kastnerkyle/net/blob/master/net.py

Most of it is in the TrainingMixin class, but there are some hacks in there right now while I try to figure out the best point at which to clip. We have some small disagreement between code bases in our lab about "where to clip" :)
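In pylearn2-style code, max_col_norm typically means a weight constraint: after each update, any column of a weight matrix whose L2 norm exceeds a threshold is rescaled down to that threshold. A NumPy sketch, assuming columns correspond to output units; this is a generic version of the constraint, not the code in the linked net.py.

```python
import numpy as np

def apply_max_col_norm(W, max_norm):
    """Rescale columns of W whose L2 norm exceeds max_norm; leave the others untouched."""
    norms = np.linalg.norm(W, axis=0)
    # Scale factor is 1 for columns already within the limit.
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return W * scale

W = np.array([[3.0, 0.1],
              [4.0, 0.2]])
W_c = apply_max_col_norm(W, max_norm=1.0)
```

The first column (norm 5) is shrunk to unit norm while keeping its direction; the second (norm about 0.22) is unchanged.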

lmjohns3 commented 9 years ago

Yeah, that's some dark magic right there! :)

Hakan, if you want to give it a try, you can add gradient clipping with gradient_clip=10 (or whatever you want the max gradient norm to be) when you train your model. I haven't messed around with that myself, but it's in theanets because I read about it in a paper. :) Kyle, FWIW I've implemented the "clip the gradient before adding to momentum" method.
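A NumPy sketch of norm-based gradient clipping applied before the gradient enters the momentum buffer, the ordering described above. The threshold of 10 mirrors gradient_clip=10; whether theanets clips the joint norm over all parameters or each parameter separately is not stated here, so the joint-norm version below is an assumption, and the function names are hypothetical.

```python
import numpy as np

def clip_by_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    if total > max_norm:
        grads = [g * (max_norm / total) for g in grads]
    return grads

def momentum_step(params, velocities, grads, lr=0.01, mu=0.9, max_norm=10.0):
    """Classical momentum, with the gradient clipped *before* it enters the velocity."""
    grads = clip_by_norm(grads, max_norm)
    new_v = [mu * v - lr * g for v, g in zip(velocities, grads)]
    new_p = [p + v for p, v in zip(params, new_v)]
    return new_p, new_v

params = [np.zeros(2)]
vel = [np.zeros(2)]
big_grad = [np.array([30.0, 40.0])]     # norm 50, clipped to norm 10 -> [6, 8]
params, vel = momentum_step(params, vel, big_grad)
```

Clipping before the momentum update bounds what each step adds to the velocity; clipping after momentum would instead rescale the accumulated velocity, which is the variant Kyle describes as hacky.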

kastnerkyle commented 9 years ago

Cool, I'm moving there too. Clipping after momentum felt hacky; hopefully this will work better. I can't wait for LSTM to be in theanets; it should give lots of people a chance to try interesting stuff. Maybe I can contribute some demos once I find something interesting that is fairly easy.

hknerdgn commented 9 years ago

Yes, I found out that currennt also does gradient clipping (the limitedError function) for the LSTM gradients, which I had not paid much attention to before.

Still, I could not get the MNIST classifier example to work with deeper nets in theanets. That one should work with layerwise training. I suspected the softmax or the cross-entropy cost for that as well. Maybe it would help to add a small number before taking the logarithm in the cross-entropy cost.

Hakan

kastnerkyle commented 9 years ago

I prefer to clip the values before the logarithm instead of always adding a constant, but it's the same idea: just keep the log from blowing up. I use T.clip(X, 1E-12, 1E12) and it seems OK, though in theory the probabilities should lie between 0 and 1, so T.clip(X, 1E-12, 1.) might be reasonable too.
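The two guards discussed, adding a small constant versus clipping, written in NumPy (np.clip behaves elementwise like Theano's T.clip). A sketch assuming one-hot targets and softmax probabilities; the function names are hypothetical.

```python
import numpy as np

def xent_eps(probs, onehot, eps=1e-12):
    # Always shift every probability by a small constant before the log.
    return float(-np.sum(onehot * np.log(probs + eps), axis=-1).mean())

def xent_clip(probs, onehot, lo=1e-12, hi=1.0):
    # Only values outside [lo, hi] are changed; valid probabilities pass through exactly.
    return float(-np.sum(onehot * np.log(np.clip(probs, lo, hi)), axis=-1).mean())

probs = np.array([[1.0, 0.0],      # a saturated softmax output
                  [0.7, 0.3]])
onehot = np.array([[0.0, 1.0],     # the target is the class predicted with probability 0
                   [1.0, 0.0]])
```

Without either guard, np.log(0.0) returns -inf and the loss becomes infinite; with either one, the worst-case per-example cost is bounded by -log(1e-12), roughly 27.6.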

hknerdgn commented 9 years ago

OK. Could the multi-layer problem be related to this subtle issue here:

https://github.com/benanne/Lasagne/issues/97

What I mean is: could it be that, when computing multiple outputs of a network (each hidden layer's output is a separate output, right?), theanets uses a different batch for each one because of randomness in input batch selection? When the input is a callable, it may keep calling the function, producing a different batch for each output separately. The gradient calculation may be doing something odd as well.

If this is the case, it would also slow down the computation, since for each output it would run a separate forward pass from the input on a different batch.

I am probably wrong, but this may be worth investigating.
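The concern can be illustrated without Theano: if a callable batch source is invoked once per requested output, each output sees different data, whereas drawing one batch per training step and computing every output from it keeps them consistent. A plain-Python sketch with a hypothetical sampler and toy "layers"; the sampler steps through the data deterministically as a stand-in for random selection so the example is reproducible, and none of this reproduces theanets' internals.

```python
def make_sampler(data, size=2):
    """Hypothetical callable batch source: returns a different batch on every call."""
    pos = 0
    def sample_batch():
        nonlocal pos
        batch = data[pos:pos + size]
        pos = (pos + size) % len(data)
        return batch
    return sample_batch

def layer1(batch):
    return [x * 2 for x in batch]

def layer2(batch):
    # Recomputes layer 1 from the raw input, like an output graph that starts at the input.
    return [x + 1 for x in layer1(batch)]

data = list(range(100))

# Problematic pattern: each output pulls its own batch from the callable.
sampler = make_sampler(data)
h1_bad = layer1(sampler())
h2_bad = layer2(sampler())

# Consistent pattern: one batch per step, and layer 2 reuses layer 1's output.
sampler = make_sampler(data)
batch = sampler()
h1 = layer1(batch)
h2 = [h + 1 for h in h1]
```

In the first pattern the layer-2 output no longer corresponds to the layer-1 output it should be built on; the second pattern also avoids the repeated forward passes mentioned above.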

lmjohns3 commented 9 years ago

This is an interesting thought. I've spent a little time looking into it, but I'm not sure it applies to theanets, because we only construct the outputs graph once. I will keep looking at it, though, since it would be nice to figure out what's up with this slow training issue.

I noticed one more "trick" in the currennt code: it looks like they bound the cross-entropy error to the range +/- 100: http://sourceforge.net/p/currennt/code/ci/master/tree/currennt_lib/src/layers/CePostOutputLayer.cu#l95 -- this would act like another form of gradient clipping, I think.
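Bounding an error signal elementwise differs from norm-based gradient clipping: each component is clamped independently to [-limit, limit], so the direction of the vector can change, whereas norm clipping rescales the whole vector and keeps its direction. A NumPy sketch of the general idea with limit=100; it is not a transcription of the linked currennt kernel.

```python
import numpy as np

def bound_errors(err, limit=100.0):
    """Clamp every component of an error/gradient array to [-limit, limit]."""
    return np.clip(err, -limit, limit)

def clip_norm(err, max_norm):
    """For contrast: rescale the whole array so its L2 norm is at most max_norm."""
    n = np.linalg.norm(err)
    return err * (max_norm / n) if n > max_norm else err

err = np.array([250.0, -3.0, 0.5])
bounded = bound_errors(err)          # only the large component changes
rescaled = clip_norm(err, 100.0)     # every component shrinks by the same factor
```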

There might also be some differences in how the weights of the model are being initialized that could have a big impact on things. I'm about to check in a change to the LSTM forget biases that I just read about, maybe that will help.
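A commonly suggested version of the forget-bias change initializes the forget-gate bias to a positive value such as 1, so that the cell state tends to be retained early in training. A NumPy sketch, assuming the four gate biases are stored in one stacked vector in the order input, forget, cell, output; that layout is an assumption, since implementations differ, and this is not the change checked into theanets.

```python
import numpy as np

def init_lstm_bias(n_hidden, forget_bias=1.0, gate_order=("i", "f", "c", "o")):
    """Stacked bias vector for the four LSTM gates, with the forget block set to forget_bias."""
    b = np.zeros(4 * n_hidden)
    k = gate_order.index("f")
    b[k * n_hidden:(k + 1) * n_hidden] = forget_bias
    return b

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

b = init_lstm_bias(3)
# With zero input contribution, the forget gate starts at sigmoid(1) ~= 0.73 instead of 0.5.
f_gate = sigmoid(b[3:6])
```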

kastnerkyle commented 9 years ago

Alex Graves uses this with different thresholds for different tasks, sometimes 10 or 100. I believe there is a PR for Theano to make it an option of grad; I saw it on the mailing list today.

hknerdgn commented 9 years ago

Nice discussions guys.

Some comments/ideas:

  1. In a backwards recurrent layer, I think the output should be reversed in time, since scan produces its outputs in backwards order when go_backwards is True. I saw this in the lasagne/nntools implementation of RNN/LSTM. Interestingly, the network still learns something even if the output is not reversed, but it is not the best one can do, since it uses the output at one time step to predict the target at a different time step along the sequence.
  2. The Theano function used for learning (trainer.f_learn) is very complex right now, with many outputs. Maybe there could be an option to have as few outputs as possible (e.g. no monitors), which might speed up compilation and runtime thanks to better optimization. I am not sure whether it is possible to have no outputs at all (just updates that depend on the gradient of the cost, for example).
  3. The last line of feedforward.py, return self.predict(x).argmax(axis=1), could be return self.predict(x).argmax(axis=-1) so that it works for 3D sequence inputs; otherwise it is not correct for batched sequence data.
  4. Any plans to implement the mask idea for sequence data soon? Then we could compare directly with currennt, and even load currennt network parameters into theanets to see whether both behave the same.
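A quick NumPy check of item 3, assuming predictions for sequence data are laid out as (time, batch, classes) with the class dimension last:

```python
import numpy as np

rng = np.random.default_rng(2)
probs = rng.random((5, 4, 3))          # (time, batch, classes)

labels_last = probs.argmax(axis=-1)    # class index per (time, batch): shape (5, 4)
labels_axis1 = probs.argmax(axis=1)    # argmax over the batch axis instead: shape (5, 3)

# For 2D predictions (batch, classes) the two choices agree.
flat = probs[0]
```

axis=-1 always selects the trailing class dimension, so the same line works for both 2D and 3D predictions.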
hknerdgn commented 9 years ago

This comment from craffel (https://github.com/benanne/Lasagne/issues/17#issuecomment-70204547) seems to confirm my intuition about higher layers recomputing the earlier layers' outputs again and again. If each layer's computation goes all the way back to the input (when the input is a callable that picks data at random), each layer could be computing its output on different data, and we would be in for a mess!