Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

[Clarification] Recurrence module description #393

Closed hashbangCoder closed 7 years ago

hashbangCoder commented 7 years ago

Hi,

The description in the README is:

"Instead, it only manages a single recurrentModule, which should output a Tensor or table : output(t) given an input table : {input(t), output(t-1)}"

To be clear, should this work when I create an arbitrary module whose output is a table of tensors? It doesn't appear to work in this case. In particular, during the forward call at step == 0, the output is returned as a single tensor instead of a table of tensors.

I'm trying to create a modified LSTM cell with an additional output - {cell, hidden, new_output} and I'm implementing this using nn.Recurrence. Is this currently possible? Any suggestions welcome. Thanks!

hashbangCoder commented 7 years ago

Okay, I've figured out the Recurrence format for multiple outputs. The README wasn't fully clear on this but going through the code helped.

If you have two outputs:

require 'rnn'

rm = custom_rm(10,20)   -- custom recurrent module
batch_size = 8
o1 = 20   -- hidden output size
o2 = 30   -- size of the second output
rnn = nn.Recurrence(rm,{{o1},{o2}},1)   -- one size table per output

inp = torch.randn(batch_size,10)
out = rnn:forward(inp)   -- table of two tensors

But now, if the output is a table of the form {out1, out2}, you cannot wrap it inside a Sequencer, because it expects only a single output per step from the rm (tensor/table), giving a result of shape time_steps x batch x feats. I'm still trying to figure out how to modify it so that both the forward and backward calls work.

One possible way is to concatenate all outputs into one tensor and break it up into its constituents inside the rm.

hashbangCoder commented 7 years ago

Alright. Concatenating all outputs into one tensor seems to work, i.e. it does not throw errors. In my case the outputs are next_c, next_h and third_output (all of the same dimension), concatenated into one big tensor. I then break these up inside my recurrent module graph and use them.
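
For anyone reading later, here is a minimal sketch of what that concatenation approach could look like. It is not my actual module: the sizes, the `branch` helper and the use of nn.Narrow / nn.JoinTable instead of a graph are assumptions made for illustration, and each branch is a plain Linear + Tanh placeholder rather than real LSTM gating. The recurrent module takes {x(t), packed(t-1)}, splits the packed tensor into its three parts, and re-packs the three new outputs into one tensor so that nn.Recurrence and nn.Sequencer only ever see a single tensor.

```lua
require 'rnn'  -- Element-Research rnn; pulls in torch and nn

local inputSize, hiddenSize = 10, 20   -- illustrative sizes
local batchSize, seqLen = 8, 5
local packedSize = 3 * hiddenSize      -- [next_c | next_h | third_output]

-- Hypothetical placeholder for one output branch (NOT real LSTM gating):
-- joins {x, c, h, third} along the feature dim, then Linear + Tanh.
local function branch()
   return nn.Sequential()
      :add(nn.JoinTable(2))
      :add(nn.Linear(inputSize + packedSize, hiddenSize))
      :add(nn.Tanh())
end

-- Recurrent module: input {x(t), packed(t-1)}, output packed(t).
local rm = nn.Sequential()
   :add(nn.ParallelTable()
      :add(nn.Identity())                                    -- x(t) passes through
      :add(nn.ConcatTable()                                  -- unpack packed(t-1)
         :add(nn.Narrow(2, 1, hiddenSize))                   -- c
         :add(nn.Narrow(2, hiddenSize + 1, hiddenSize))      -- h
         :add(nn.Narrow(2, 2 * hiddenSize + 1, hiddenSize)))) -- third
   :add(nn.FlattenTable())                                   -- {x, c, h, third}
   :add(nn.ConcatTable():add(branch()):add(branch()):add(branch()))
   :add(nn.JoinTable(2))                                     -- re-pack into one tensor

-- A single packed tensor per step, so wrapping in a Sequencer is fine.
local seq = nn.Sequencer(nn.Recurrence(rm, packedSize, 1))

local inputs, gradOutputs = {}, {}
for t = 1, seqLen do
   inputs[t] = torch.randn(batchSize, inputSize)
   gradOutputs[t] = torch.randn(batchSize, packedSize)
end

local outputs = seq:forward(inputs)                 -- table of seqLen tensors
local gradInputs = seq:backward(inputs, gradOutputs)

-- Outside the module, unpack any step's output the same way.
local last = outputs[seqLen]
local next_c = last:narrow(2, 1, hiddenSize)
local next_h = last:narrow(2, hiddenSize + 1, hiddenSize)
local third  = last:narrow(2, 2 * hiddenSize + 1, hiddenSize)
```

This only shows that the forward and backward calls run with the expected shapes; it says nothing about whether the gradients are what a true multi-output formulation would give.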

But I'm unsure about the BPTT correctness. I'll train my model and see if this works. Closing this for now.