Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Multiple forward calls followed by multiple backward calls #396

Closed: mohitsharma0690 closed this issue 7 years ago

mohitsharma0690 commented 7 years ago

I want to implement a seq2seq architecture in which I can run several forward calls, average their losses, and then backpropagate through each call individually. I've been trying to modify the seq2seq architecture for this using nn.SeqLSTM. Here is the code I think is the culprit:

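```lua
-- Per-timestep projection from the 1024-dim hidden state to vocabulary scores;
-- MaskZero(..., 1) zeroes the output for zero-padded (masked) timesteps.
self.decoder:add(nn.Sequencer(nn.MaskZero(nn.Linear(1024, self.vocab_size), 1)))
-- Per-timestep log-probabilities over the vocabulary.
self.decoder:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1)))
```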

This doesn't seem possible with the current architecture, because the nn.Sequencer modules above seem to store per-timestep state internally, and that state (the timestep counter) is reset to 0 after one backward pass. As a result, the code crashes when I try to backpropagate the second time.
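The underlying problem can be illustrated without Torch at all. In the plain-Lua toy below, `Scale` is a hypothetical module (not part of nn or rnn) that caches the input of its most recent forward call, roughly the way a Sequencer keeps per-timestep state for its backward pass. Running two forward calls before the first backward makes that backward use the wrong cached input:

```lua
-- Hypothetical toy module: y = weight * x, caching x for the backward pass,
-- similar in spirit to the per-timestep state a Sequencer keeps.
local Scale = {}
Scale.__index = Scale

function Scale.new(w)
  return setmetatable({ weight = w, gradWeight = 0, input = nil }, Scale)
end

function Scale:forward(x)
  self.input = x  -- cached state, overwritten on every forward call
  return self.weight * x
end

function Scale:backward(gradOutput)
  -- d(weight * x)/d(weight) = x, so this relies on the cached input
  self.gradWeight = self.gradWeight + self.input * gradOutput
  return self.weight * gradOutput
end

local m = Scale.new(2)
m:forward(3)         -- batch A
m:forward(5)         -- batch B overwrites the input cached for A
m:backward(1)        -- intended for A, but uses B's cached input
print(m.gradWeight)  -- 5, whereas the correct gradient for A is 3
```

The toy silently returns a wrong gradient; in rnn the same mismatch between stored state and backward calls presumably surfaces as the crash described above.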

Is there some way to achieve this with rnn?
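Since the gradient of an averaged loss equals the average of the per-batch gradients, one possible workaround (an assumption, not something the rnn documentation prescribes) is to backpropagate each batch immediately after its own forward call, scaling the loss gradient by 1/n and letting the parameter gradients accumulate. The sketch below shows this in plain Lua with a hypothetical `Scale` toy module, a squared-error loss, and made-up data; it does not use the rnn API:

```lua
-- Hypothetical toy module standing in for the decoder: caches its input.
local Scale = {}
Scale.__index = Scale

function Scale.new(w)
  return setmetatable({ weight = w, gradWeight = 0, input = nil }, Scale)
end

function Scale:forward(x)
  self.input = x
  return self.weight * x
end

function Scale:backward(gradOutput)
  self.gradWeight = self.gradWeight + self.input * gradOutput
  return self.weight * gradOutput
end

-- Squared-error loss and its derivative with respect to the prediction.
local function loss(pred, target) return (pred - target) ^ 2 end
local function dloss(pred, target) return 2 * (pred - target) end

-- Made-up data: three (input, target) pairs treated as separate batches.
local batches = { {1, 3}, {2, 3}, {4, 9} }
local n = #batches

local m = Scale.new(2)
m.gradWeight = 0  -- zero the gradients once, before the loop
local avgLoss = 0
for _, b in ipairs(batches) do
  local x, y = b[1], b[2]
  local pred = m:forward(x)
  avgLoss = avgLoss + loss(pred, y) / n
  -- Backward right away, while this batch's cached state is still valid.
  -- Dividing gradOutput by n makes the accumulated gradient equal to the
  -- gradient of the averaged loss.
  m:backward(dloss(pred, y) / n)
end
print(avgLoss, m.gradWeight)
```

With real nn modules the analogous pattern would be to call `zeroGradParameters()` once, run forward and backward for each batch (dividing the criterion's gradient by n), and call `updateParameters(learningRate)` once at the end. This assumes the batches are independent; if hidden state must flow from one forward call into the next, this sketch does not cover that case.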

PS: I don't think this is really a GitHub issue, but I couldn't find a mailing list for the rnn library. Let me know if there is a better place to ask this question.