lmjohns3 / theanets

Neural network toolkit for Python
http://theanets.rtfd.org
MIT License

Preserve recurrent hidden state across training examples #132

Open feynmanliang opened 7 years ago

feynmanliang commented 7 years ago

If I understand correctly, recurrent network training in theanets currently uses self.h_0 as the initial state (which defaults to None, implying all zeros) for every training example. This fits well with the Text.classifier_batches implementation, which returns batches of (previous N characters, (N+1)th character) input/output pairs drawn randomly from the training text.
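For reference, the random-window batching described above can be sketched as follows. This is a hypothetical helper written for illustration only (`random_window_batch` is not part of theanets and is not the actual `Text.classifier_batches` code). It samples independent (context, next character) pairs at random offsets, which is why no hidden state can usefully carry over from one example to the next.

```python
import random

def random_window_batch(text, n, batch_size, rng):
    """Sample (previous n characters, next character) pairs at random offsets.

    Each pair is independent of the others, so a recurrent model has no
    earlier context to carry over: every window starts from the initial state.
    """
    pairs = []
    for _ in range(batch_size):
        # Largest start is len(text) - n - 1, leaving room for the target.
        start = rng.randrange(len(text) - n)
        pairs.append((text[start:start + n], text[start + n]))
    return pairs

text = "hello world, hello theanets"
rng = random.Random(0)  # seeded for reproducibility
batch = random_window_batch(text, n=5, batch_size=3, rng=rng)
for context, target in batch:
    print(repr(context), "->", repr(target))
```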

The problem is that this is a stronger assumption than truncated BPTT, which only truncates gradient propagation after a fixed number of time steps. Here, not only is backpropagation truncated, but the hidden state is also reset to zero.
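A minimal pure-Python sketch of the difference, assuming a scalar tanh RNN with made-up weights `w` and `u` (none of these names come from theanets). With truncated BPTT the forward pass still covers the whole sequence, so the state entering the last window reflects the earlier inputs; with independent windows, the same inputs are replayed from a zero state.

```python
import math

def rnn_states(xs, w=0.9, u=0.5, h0=0.0):
    """Forward pass of a scalar tanh RNN, h_t = tanh(w * h_{t-1} + u * x_t)."""
    h, states = h0, []
    for x in xs:
        h = math.tanh(w * h + u * x)
        states.append(h)
    return states

xs = [1.0, 1.0, 1.0, -1.0, 0.5, 0.5, 0.5, 0.5]
k = 4  # window length

# Truncated BPTT: the forward pass runs over the whole sequence, so the
# states in the last window still depend on the first inputs.
carried = rnn_states(xs)[-k:]

# Independent windows: the last k inputs are replayed from a zero state.
zeroed = rnn_states(xs[-k:])

print([round(h, 3) for h in carried])
print([round(h, 3) for h in zeroed])
```

With these weights the two trajectories start apart and drift closer over the window, as the zero-initialized run "warms up" toward the carried state.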

I would like to preserve the hidden state across training examples. Then I could iterate through the training text in sequence and learn dynamics that do not assume a zero initial state at the start of every sequence. It should still be possible to reset the state during iter_train (i.e. when an epoch has elapsed). This is what torch-rnn does (see https://github.com/jcjohnson/torch-rnn/blob/master/train.lua#L176).
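A sketch of the sequential alternative, under stated assumptions: the data is split into `batch_size` contiguous streams, and row i of each batch continues exactly where row i of the previous batch stopped, so the final hidden state of one batch is a valid initial state for the next. `sequential_batches`, `run_epochs` and `train_step` are hypothetical names; `train_step` stands for whatever update takes an initial state and returns the final one.

```python
def sequential_batches(data, batch_size, seq_len):
    """Yield (inputs, targets) chunks from batch_size contiguous streams.

    Targets are the inputs shifted by one position.
    """
    stream_len = (len(data) - 1) // batch_size  # keep one item for the last target
    n_chunks = stream_len // seq_len
    for c in range(n_chunks):
        inputs, targets = [], []
        for b in range(batch_size):
            start = b * stream_len + c * seq_len
            inputs.append(data[start:start + seq_len])
            targets.append(data[start + 1:start + seq_len + 1])
        yield inputs, targets

def run_epochs(data, epochs, batch_size, seq_len, train_step):
    """Hypothetical stateful loop: carry the state between batches."""
    for _ in range(epochs):
        state = None  # reset the hidden state when a new epoch starts
        for inputs, targets in sequential_batches(data, batch_size, seq_len):
            state = train_step(inputs, targets, state)

for inputs, targets in sequential_batches(list(range(21)), batch_size=2, seq_len=3):
    print(inputs, targets)
```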

As a consequence, I am getting training/validation cross-entropy losses of 0.5/10, respectively, in theanets, while the equivalent model in torch-rnn gets 0.09/0.7.

Have I understood theanets' implementation correctly? Is it possible to accomplish what I'm suggesting with what's currently in theanets?

lmjohns3 commented 7 years ago

There is an option to use "true" BPTT with each recurrent layer in theanets (see https://github.com/lmjohns3/theanets/blob/master/theanets/layers/recurrent.py#L61):

```python
import theanets

net = theanets.recurrent.Classifier([
    20, dict(form='rnn', size=20, bptt_limit=5), 10])
```

If you then feed this model sequences of, say, length 10, the gradients are propagated back through only the last 5 time steps, but the forward pass still covers all 10 time steps.
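To make these semantics concrete, here is a pure-Python sketch for a scalar tanh RNN. It is an illustration under simplified assumptions, not how theanets computes gradients: the forward pass runs over every step, while the backward loop stops after `limit` steps (analogous to `bptt_limit`). With no limit, the result should match a finite-difference estimate of the full gradient.

```python
import math

def forward(xs, w, u, h0=0.0):
    """Scalar tanh RNN; hs[t] is the state after reading t inputs (hs[0] = h0)."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def grad_w(xs, w, u, limit=None):
    """d h_T / d w by backpropagation through time, stopped after `limit` steps.

    The forward pass always covers the full sequence; only the backward loop
    is truncated.
    """
    hs = forward(xs, w, u)
    T = len(xs)
    steps = T if limit is None else min(limit, T)
    grad, upstream = 0.0, 1.0  # upstream holds d h_T / d h_t
    for t in range(T, T - steps, -1):
        local = upstream * (1.0 - hs[t] ** 2)  # back through the tanh at step t
        grad += local * hs[t - 1]              # w's direct contribution at step t
        upstream = local * w                   # propagate to h_{t-1}
    return grad

xs = [0.5, -1.0, 0.3, 0.8, -0.2, 1.0, 0.1, -0.5, 0.7, 0.4]  # length-10 sequence
w, u = 0.8, 0.6
full = grad_w(xs, w, u)                # gradient through all 10 steps
truncated = grad_w(xs, w, u, limit=5)  # only the last 5 steps contribute
print(full, truncated)
```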

I'd always thought of truncated BPTT purely as an optimization trick for reducing computation time (fewer backward-pass steps means faster training). But it does make sense that the model itself might improve with truncated BPTT: the zero initial hidden state is probably not correct, so the model needs a few time steps to settle into the right dynamics for the input, and the gradients from those first steps would be misleading. I'd be curious whether just using this in your model lets the losses improve.

Even if it does improve, however, the computation-time issue remains. As far as I can tell, theanets doesn't have an easy mechanism for capturing the hidden state of a network at some time step and resetting it later. I think you might be able to do something like this by calling feed_forward() on your training data, capturing the hidden state, and then passing it in with the next optimization step. I haven't had time to try this myself, though, and even this approach would require twice as many forward passes.
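The workaround can be sketched under strong simplifying assumptions: a scalar tanh RNN with fixed weights, where one `forward()` call stands in for the optimization step and a second stands in for a separate `feed_forward()` call used only to capture the state (both roles and `train_on_chunks` are hypothetical). Seeding each chunk with the previous chunk's last state reproduces the states of a single pass over the whole sequence, at the cost of two forward passes per chunk.

```python
import math

def forward(xs, h0, w=0.9, u=0.5):
    """Scalar tanh RNN forward pass from initial state h0; returns all states."""
    hs, h = [], h0
    for x in xs:
        h = math.tanh(w * h + u * x)
        hs.append(h)
    return hs

def train_on_chunks(xs, chunk_len):
    """Hypothetical stateful loop over consecutive chunks of one sequence."""
    h0, states, n_passes = 0.0, [], 0
    for start in range(0, len(xs), chunk_len):
        chunk = xs[start:start + chunk_len]
        states.extend(forward(chunk, h0))  # stands in for the optimization step
        h0 = forward(chunk, h0)[-1]        # extra pass just to capture the state
        n_passes += 2
    return states, n_passes

xs = [0.2, -0.4, 1.0, 0.5, -0.3, 0.8, 0.1]
states, n_passes = train_on_chunks(xs, chunk_len=3)
print(states == forward(xs, 0.0), n_passes)  # prints: True 6
```

With fixed weights the second pass is pure overhead. In real training the weights change after every step, so whether the state is captured before or after the update is an additional choice that this sketch glosses over.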

feynmanliang commented 7 years ago

Does using bptt_limit preserve the hidden state across epochs, or is it reset at the end of each epoch? Ideally I'd like to reset the hidden state every time a start or stop delimiter (e.g. <S> or </S>) is seen, but that might be asking too much with varying sequence lengths and batching >.<
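Resetting at delimiters can in principle coexist with batching if each row carries its own reset mask. A pure-Python sketch, assuming a scalar tanh RNN, a hypothetical `<S>` start token and a toy embedding table (this is not a theanets feature): each row's state is zeroed independently right before that row reads `<S>`, so sequences of different lengths can be packed into fixed-length rows and still start from a clean state.

```python
import math

START = "<S>"  # hypothetical start-of-sequence token

def forward_with_resets(batch, embed, w=0.9, u=0.5):
    """Scalar tanh RNN over a batch of equal-length token rows.

    A row's hidden state is zeroed right before it reads START; other rows
    are unaffected.
    """
    n_rows, n_steps = len(batch), len(batch[0])
    h = [0.0] * n_rows
    out = []  # out[t][i] is row i's state after step t
    for t in range(n_steps):
        for i in range(n_rows):
            tok = batch[i][t]
            keep = 0.0 if tok == START else 1.0  # per-row reset mask
            h[i] = math.tanh(w * keep * h[i] + u * embed[tok])
        out.append(list(h))
    return out

embed = {START: 0.0, "a": 1.0, "b": -1.0}  # toy one-dimensional embeddings
batch = [
    ["a", "b", START, "a", "b"],
    [START, "a", "b", START, "a"],
]
out = forward_with_resets(batch, embed)
for step in out:
    print([round(h, 3) for h in step])
```

In a vectorized implementation the same idea becomes an elementwise multiply of the batch's hidden states by a 0/1 mask before each step.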