Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Different ways to use LSTM #244

Closed: LeonardKnuth closed this issue 8 years ago

LeonardKnuth commented 8 years ago

Hi everyone,

I am trying to build a recurrent neural network with nn.LSTM in two ways: one uses nn.Sequencer, and the other uses nn.ConcatTable. Although I fixed the input table and the parameters of the two nets, their outputs and errors are different. I am confused: is there a difference between these two implementations? If so, what is it? Thanks a lot.

The two different implementations:

require 'nn'
require 'rnn'
require 'torch'

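-----Basic LSTM Module
-- nn.LSTM(inputSize, outputSize, rho): 100 inputs, 10 outputs, rho = 2 (max BPTT steps)
local lstm_basic_para=nn.LSTM(100, 10, 2)
-- save and reload to get an independent copy with identical initial weights
torch.save('lstm.t7', lstm_basic_para)

local lstm_basic_sqr=torch.load('lstm.t7')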

------implementation by ConcatTable
-- the same LSTM module is applied once per time step; its hidden state lives inside the LSTM
local para_seq = nn.Sequential()
-- step 1: {x1, x2, x3} -> {lstm(x1), x2, x3}
local concat_net_1 = nn.ConcatTable()
concat_net_1:add(nn.Sequential():add(nn.SelectTable(1)):add(lstm_basic_para))
concat_net_1:add(nn.SelectTable(2))
concat_net_1:add(nn.SelectTable(3))
para_seq:add(concat_net_1)
-- step 2: {h1, x2, x3} -> {lstm(x2), x3}
local concat_net_2 = nn.ConcatTable()
concat_net_2:add(nn.Sequential():add(nn.SelectTable(2)):add(lstm_basic_para))
concat_net_2:add(nn.SelectTable(3))
para_seq:add(concat_net_2)

-- step 3: {h2, x3} -> lstm(x3), the output of the last time step
para_seq:add(nn.Sequential():add(nn.SelectTable(2)):add(lstm_basic_para))
para_seq:training()

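------implementation by nn.Sequencer
local sqr_seq = nn.Sequential()
-- Sequencer applies the LSTM to each element of the input table in order
local sequencer_net = nn.Sequencer(lstm_basic_sqr)
-- keep the hidden state between successive forward calls (training and evaluation)
sequencer_net:remember('both')
sequencer_net:training()
sqr_seq:add(sequencer_net)
-- keep only the output of the last (third) time step
sqr_seq:add(nn.SelectTable(3))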

local criterion= nn.MSECriterion()

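-----input and output
-- torch.Tensor(n) leaves its contents uninitialized, so fill the tensors explicitly
local input_table ={torch.randn(100), torch.randn(100), torch.randn(100)}
-- an all-zero target
local groundtruth = torch.zeros(10)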

for i = 1, 2 do   

   local sqr_output = sqr_seq:forward(input_table)
   local sqr_err = criterion:forward(sqr_output, groundtruth)

   print('sqr_seq output:')
   print(sqr_output)
   print('sqr_err:')
   print(sqr_err)

   local para_output = para_seq:forward(input_table)
   local para_err = criterion:forward(para_output, groundtruth)

   print('para_seq output:')
   print(para_output)
   print('para_err:')
   print(para_err)

   -----update parameters
   local para_gradCriterion = criterion:backward(para_output, groundtruth)
   para_seq:zeroGradParameters()
   para_seq:backward(input_table, para_gradCriterion)
   para_seq:updateParameters(0.05)

   local sqr_gradCriterion = criterion:backward(sqr_output, groundtruth)
   sqr_seq:zeroGradParameters()
   sqr_seq:backward(input_table, sqr_gradCriterion)
   sqr_seq:updateParameters(0.05)

end

The corresponding outputs (Torch prints a common scale factor first, so every value in each tensor is multiplied by 0.01):

Iteration 1
sqr_seq output: 0.01 * 1.5458 2.8944 -2.6372 1.5481 1.8520 -0.7645 -2.3867 -0.4645 -3.4330 -1.0323 [torch.DoubleTensor of size 10]
sqr_err: 0.00042896711636675
para_seq output: 0.01 * 1.5458 2.8944 -2.6372 1.5481 1.8520 -0.7645 -2.3867 -0.4645 -3.4330 -1.0323 [torch.DoubleTensor of size 10]
para_err: 0.00042896711636675

Iteration 2
sqr_seq output: 0.01 * 1.8912 3.3510 -3.0471 1.3787 2.2088 -0.5411 -2.8534 -0.4141 -4.0167 -1.1782 [torch.DoubleTensor of size 10]
sqr_err: 0.00056998107217014
para_seq output: 0.01 * 1.8891 3.3471 -3.0448 1.3777 2.2072 -0.5398 -2.8506 -0.4141 -4.0118 -1.1765 [torch.DoubleTensor of size 10]
para_err: 0.00056879631643435

After several trials, here is my conclusion and guess:

The forward passes of the two implementations are the same when the Sequencer version turns on remember('both'), but their backward passes seem to differ slightly. Is there any randomness in the backward pass, or do they use completely different backward methods?

Thanks.

nicholas-leonard commented 8 years ago

@LeonardKnuth The backward will not work with ConcatTable, as AbstractRecurrent instances require backward to be called in the reverse order of forward. Basically, don't use ConcatTable :)

LeonardKnuth commented 8 years ago

@nicholas-leonard Thanks for your explanation. However, in my program above, each ConcatTable contains only one LSTM, and a Sequential connects all the ConcatTables. In that case, the order of the forward and backward calls should match what AbstractRecurrent requires, right? Thanks.
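The ordering rule can be illustrated with a toy model in plain Lua. Everything in it (ToyRecurrent, runSequential, runConcat) is hypothetical and is not how the rnn library is implemented; it only assumes, as stated in the thread, that a recurrent module must receive its backward calls in the reverse order of its forward calls. A Sequential calls its children's backward from last to first, which gives that reverse order when each child runs one time step, whereas a ConcatTable holding several time steps of the same LSTM as sibling branches would call backward in forward order.

```lua
-- Hypothetical toy recurrent module (NOT the rnn library's code): each forward
-- call pushes a time step, and backward must pop them last-in, first-out.
local ToyRecurrent = {}
ToyRecurrent.__index = ToyRecurrent

function ToyRecurrent.new()
  return setmetatable({ stack = {}, order = {} }, ToyRecurrent)
end

function ToyRecurrent:forward(step)
  table.insert(self.stack, step)
end

function ToyRecurrent:backward(step)
  local top = table.remove(self.stack)
  if top ~= step then
    error(("backward(%d) called while step %s is on top"):format(step, tostring(top)))
  end
  table.insert(self.order, step)
end

-- Models a Sequential whose children each run one time step:
-- backward visits the children from last to first.
local function runSequential(rnn, nSteps)
  for t = 1, nSteps do rnn:forward(t) end
  for t = nSteps, 1, -1 do rnn:backward(t) end
end

-- Models a ConcatTable whose branches are successive time steps of the same
-- LSTM: backward visits the branches in the same order as forward.
local function runConcat(rnn, nSteps)
  for t = 1, nSteps do rnn:forward(t) end
  for t = 1, nSteps do rnn:backward(t) end
end

local a = ToyRecurrent.new()
runSequential(a, 3)
print(table.concat(a.order, ","))  --> 3,2,1

local b = ToyRecurrent.new()
local ok, err = pcall(runConcat, b, 3)
print(ok)   --> false
print(err)  -- reports that backward(1) was called while step 3 was on top
```

The user's para_seq falls into the first case: one LSTM step per Sequential child, so the order is fine.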

nicholas-leonard commented 8 years ago

@LeonardKnuth Sorry, I hadn't caught that. OK, then the order of the forward/backward calls is correct. I think the problem is then caused by updateParameters. In the ConcatTable implementation, calling parameters() returns the LSTM's params and gradParams 3 times (once for each place the LSTM appears in the network). So the call to updateParameters effectively applies those gradients 3 times instead of once. You can avoid this by using getParameters() to obtain a single consolidated tensor of params and one of gradParams, which should not contain duplicates. To update, call params:add(-lr, gradParams).
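The duplicated-update effect can be reproduced with a toy model in plain Lua (no Torch). The helpers makeShared, updateEveryOccurrence and updateDistinct are hypothetical stand-ins: the first builds one shared parameter set, the second mimics updating a module once per place it appears in the network, and the third mimics updating each distinct parameter set once, which is what the consolidated tensors from getParameters() are meant to achieve. Parameters and gradients are plain numbers here rather than tensors.

```lua
-- Hypothetical toy model (not Torch): one shared "LSTM" with a single weight
-- and its accumulated gradient.
local function makeShared(w, g)
  return { params = { w }, gradParams = { g } }
end

-- Visits every occurrence of a module, so a module added N times is updated N times.
local function updateEveryOccurrence(occurrences, lr)
  for _, m in ipairs(occurrences) do
    for i = 1, #m.params do
      m.params[i] = m.params[i] - lr * m.gradParams[i]
    end
  end
end

-- Updates each distinct module only once, skipping repeated occurrences.
local function updateDistinct(occurrences, lr)
  local seen = {}
  for _, m in ipairs(occurrences) do
    if not seen[m] then
      seen[m] = true
      for i = 1, #m.params do
        m.params[i] = m.params[i] - lr * m.gradParams[i]
      end
    end
  end
end

-- the same module added 3 times, as lstm_basic_para is in para_seq
local lstmA = makeShared(1.0, 0.5)
updateEveryOccurrence({ lstmA, lstmA, lstmA }, 0.05)
print(lstmA.params[1])  -- about 0.925, i.e. 1.0 - 3 * 0.05 * 0.5

local lstmB = makeShared(1.0, 0.5)
updateDistinct({ lstmB, lstmB, lstmB }, 0.05)
print(lstmB.params[1])  -- about 0.975, i.e. 1.0 - 0.05 * 0.5
```

In Torch itself, the fix suggested in the reply amounts to calling local params, gradParams = para_seq:getParameters() once before the training loop and replacing para_seq:updateParameters(0.05) with params:add(-0.05, gradParams). The nn documentation advises calling getParameters() only once per network, because it reallocates the parameter storage.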

LeonardKnuth commented 8 years ago

@nicholas-leonard Thank you very much for your detailed explanation. One more thing: is there any fundamental difference between adding the gradients 3 times and adding them once? Can we adjust the learning rate (e.g., 0.0001 for the ConcatTable method and 0.0003 for the Sequencer method) to make the two methods consistent? Thanks a lot.

nicholas-leonard commented 8 years ago

@LeonardKnuth That trick will only work if the LSTM parameters are the only parameters returned by a call to parameters(). This seems to be the case for yours, so yes, that could make them consistent.
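The learning-rate adjustment can be written out as one update step. This assumes that both nets accumulate the same gradient G for the LSTM parameters θ over the three time steps (their forward passes match), and, as the reply notes, that the LSTM parameters are the only parameters in the network:

$$\Delta\theta_{\text{concat}} = -3\,\eta_{\text{concat}}\,G, \qquad \Delta\theta_{\text{seq}} = -\eta_{\text{seq}}\,G,$$

$$\Delta\theta_{\text{concat}} = \Delta\theta_{\text{seq}} \iff \eta_{\text{seq}} = 3\,\eta_{\text{concat}}, \quad \text{e.g. } \eta_{\text{concat}} = 0.0001,\ \eta_{\text{seq}} = 0.0003.$$

Any parameter that appears only once in parameters() would be updated with η_concat rather than 3η_concat, which is why the trick breaks down when the network contains other parameters.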