Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Error when using BiSequencer with trimZero (code to reproduce) #376

Closed Cadene closed 7 years ago

Cadene commented 7 years ago

Hello,

I am encountering the following error:

/home/cadene/torch-pascal/install/bin/luajit: /home/cadene/.luarocks/share/lua/5.1/nn/Container.lua:67: 
In 2 module of nn.Sequential:
In 1 module of nn.Sequential:
In 1 module of nn.ConcatTable:
In 1 module of nn.Sequential:
In 2 module of nn.ConcatTable:
In 1 module of nn.Sequential:
In 1 module of nn.ParallelTable:
In 1 module of nn.Sequential:
In 1 module of nn.ConcatTable:
/home/cadene/.luarocks/share/lua/5.1/torch/Tensor.lua:322: incorrect size: only supporting singleton expansion (size=1)
stack traceback:
    [C]: in function 'error'
    /home/cadene/.luarocks/share/lua/5.1/torch/Tensor.lua:322: in function 'expandAs'
    ...adene/torch-pascal/install/share/lua/5.1/rnn/Dropout.lua:47: in function <...adene/torch-pascal/install/share/lua/5.1/rnn/Dropout.lua:28>
    [C]: in function 'xpcall'
    /home/cadene/.luarocks/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /home/cadene/.luarocks/share/lua/5.1/nn/ConcatTable.lua:11: in function </home/cadene/.luarocks/share/lua/5.1/nn/ConcatTable.lua:9>
    [C]: in function 'xpcall'
    /home/cadene/.luarocks/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /home/cadene/.luarocks/share/lua/5.1/nn/Sequential.lua:44: in function </home/cadene/.luarocks/share/lua/5.1/nn/Sequential.lua:41>
    [C]: in function 'xpcall'
    ...
    /home/cadene/.luarocks/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /home/cadene/.luarocks/share/lua/5.1/nn/Sequential.lua:44: in function 'updateOutput'
    ...ne/torch-pascal/install/share/lua/5.1/dpnn/Decorator.lua:11: in function <...ne/torch-pascal/install/share/lua/5.1/dpnn/Decorator.lua:10>
    [C]: in function 'xpcall'
    /home/cadene/.luarocks/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /home/cadene/.luarocks/share/lua/5.1/nn/Sequential.lua:44: in function 'forward'
    issue_bgru.lua:31: in main chunk
    [C]: in function 'dofile'
    ...rch-pascal/install/lib/luarocks/rocks/trepl/scm-1/bin/th:150: in main chunk
    [C]: at 0x00405b60

How can I solve this? It seems to come from using a Bayesian GRU (dropout > 0) as the backward RNN in BiSequencer.

Please reproduce the error with this code:

require 'nn'
require 'rnn'

local batchSize = 20
local seqSize = 26
local inputSize = 620
local outputSize = 1200
local dropout = 0.25

local gru_fwd = nn.GRU(inputSize, outputSize)
local gru_bwd
if dropout == 0 then
   gru_bwd = nn.GRU(inputSize, outputSize)
else
   gru_bwd = nn.GRU(inputSize, outputSize, false, dropout, true)
end
gru_fwd:trimZero(1)
gru_bwd:trimZero(1)

local net = nn.Sequential()
net:add(nn.SplitTable(2))
net:add(nn.BiSequencer(gru_fwd, gru_bwd))
net:add(nn.SelectTable(-1))

local inputs = torch.randn(batchSize, seqSize, inputSize)

-- Zero the first time step of the first sample to simulate left zero-padding.
inputs[1][1]:zero()

local outputs = net:forward(inputs)
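
The trace above ends in `expandAs`, called from rnn/Dropout.lua, with the message "only supporting singleton expansion". As a minimal sketch of that Torch error in isolation, the snippet below expands a tensor whose first dimension is neither 1 nor the target size. The tensor names and sizes are made up for illustration, and the thread does not confirm that this is the exact mismatch inside Dropout.lua.

```lua
require 'torch'

-- Hypothetical sizes: a mask built for 2 rows, an input batch with 4 rows.
local noise = torch.ones(2, 3)
local input = torch.randn(4, 3)

-- expandAs can only stretch dimensions of size 1, so 2 -> 4 raises an error.
local ok, err = pcall(function() return noise:expandAs(input) end)
print(ok, err)

-- A singleton first dimension expands without error.
local rowNoise = torch.ones(1, 3)
local expanded = rowNoise:expandAs(input)
print(expanded:size(1), expanded:size(2))
```

If this reading is right, the error would appear whenever a stored mask and the current batch disagree on the number of rows, which is the kind of change a batch-trimming module could introduce.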

jnhwkim commented 7 years ago

@Cadene Thank you for reporting the issue. I have sent a pull request to resolve the error. However, note that if the zero padding is left-aligned (so the backward RNN receives flipped inputs, with the padding on the right), the output of TrimZero may not be what you intend: in the current implementation, a zero input vector resets the hidden state of the GRU instead of preserving the previous hidden state.
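
To make the difference between resetting and preserving the hidden state concrete, here is a minimal sketch in plain Lua. It uses a toy scalar recurrence in place of the GRU, and the `step` and `run` functions and the `'reset'`/`'keep'` policy names are hypothetical, so it only illustrates the behaviour described above and not the library's code.

```lua
-- Toy scalar recurrence standing in for one GRU step (an assumption for
-- illustration only; it is not the library's GRU).
local function step(h, x)
   return 0.5 * h + x
end

-- Runs the recurrence over seq, treating x == 0 as a padded step.
-- policy 'reset': padding clears the hidden state.
-- policy 'keep' : padding leaves the hidden state untouched.
local function run(seq, policy)
   local h, outputs = 0, {}
   for t, x in ipairs(seq) do
      if x == 0 then
         if policy == 'reset' then h = 0 end
         outputs[t] = 0      -- a padded step emits zero under both policies
      else
         h = step(h, x)
         outputs[t] = h
      end
   end
   return outputs, h
end

-- The left-padded sequence {0, 0, 1, 2, 3} as the backward RNN sees it after flipping.
local flipped = {3, 2, 1, 0, 0}
local _, hReset = run(flipped, 'reset')
local _, hKeep = run(flipped, 'keep')
print(hReset, hKeep)   -- hReset is 0, hKeep is 2.75
```

In this toy, the outputs at the real steps are the same under both policies. Only the state left after the trailing padding differs, which matters if the last step of the backward RNN is used as a summary of the sequence.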

Cadene commented 7 years ago

Thank you so much.

If I understand correctly, you are saying that ReverseTable (the module in BiSequencer) shifts the left padding to the right, and trimZero doesn't remove the right padding, so the outputs will not be correct.

What would be the best way to remove the zero padding on the right?

jnhwkim commented 7 years ago

@Cadene The best fix would be to update the TrimZero module so that it keeps the hidden state when zero vectors come in. An ad-hoc remedy depends on your application. Please refer to the example below.

require 'nn'
require 'rnn'

local batchSize = 20
local seqSize = 26
local inputSize = 620
local outputSize = 1200
local dropout = 0.25

local gru_fwd = nn.GRU(inputSize, outputSize)
local gru_bwd
if dropout == 0 then
   gru_bwd = nn.GRU(inputSize, outputSize)
else
   gru_bwd = nn.GRU(inputSize, outputSize, false, dropout, true)
end
gru_fwd:trimZero(1)
gru_bwd:trimZero(1)

local net = nn.Sequential()
net:add(nn.SplitTable(2))
net:add(nn.BiSequencer(gru_fwd, gru_bwd))

local inputs = torch.randn(batchSize, seqSize, inputSize)

-- Left-pad sample 1 with two zero steps and sample 2 with one zero step.
inputs[1][1]:zero()
inputs[2][1]:zero()
inputs[1][2]:zero()
local outputs = net:forward(inputs)
print(outputs)

for i=1,seqSize do
   print(outputs[i]:sub(1,2,1,3), outputs[i]:sub(1,2,outputSize+1,outputSize+5))
end

This outputs:

{
  1 : DoubleTensor - size: 20x2400
  2 : DoubleTensor - size: 20x2400
  3 : DoubleTensor - size: 20x2400
  4 : DoubleTensor - size: 20x2400
  5 : DoubleTensor - size: 20x2400
  6 : DoubleTensor - size: 20x2400
  7 : DoubleTensor - size: 20x2400
  8 : DoubleTensor - size: 20x2400
  9 : DoubleTensor - size: 20x2400
  10 : DoubleTensor - size: 20x2400
  11 : DoubleTensor - size: 20x2400
  12 : DoubleTensor - size: 20x2400
  13 : DoubleTensor - size: 20x2400
  14 : DoubleTensor - size: 20x2400
  15 : DoubleTensor - size: 20x2400
  16 : DoubleTensor - size: 20x2400
  17 : DoubleTensor - size: 20x2400
  18 : DoubleTensor - size: 20x2400
  19 : DoubleTensor - size: 20x2400
  20 : DoubleTensor - size: 20x2400
  21 : DoubleTensor - size: 20x2400
  22 : DoubleTensor - size: 20x2400
  23 : DoubleTensor - size: 20x2400
  24 : DoubleTensor - size: 20x2400
  25 : DoubleTensor - size: 20x2400
  26 : DoubleTensor - size: 20x2400
}
 0  0  0
 0  0  0
[torch.DoubleTensor of size 2x3]

 0  0  0  0  0
 0  0  0  0  0
[torch.DoubleTensor of size 2x5]

 0.0000  0.0000  0.0000
-0.2314  0.2889  0.1110
[torch.DoubleTensor of size 2x3]

 0.0000  0.0000  0.0000  0.0000  0.0000
 0.2005  0.0745 -0.5170 -0.5513  0.1038
[torch.DoubleTensor of size 2x5]

-0.3682  0.2846  0.4928
-0.0459 -0.0644 -0.1465
[torch.DoubleTensor of size 2x3]

-0.0419  0.2299  0.1828 -0.2765  0.0665
-0.1278  0.0979 -0.4719 -0.2073 -0.1182
[torch.DoubleTensor of size 2x5]

-0.6310 -0.3065  0.4660
-0.3260  0.1424  0.1357
[torch.DoubleTensor of size 2x3]

-0.1298  0.2234 -0.3248 -0.0538 -0.0608
-0.1118 -0.1232 -0.3420 -0.1803 -0.0018
[torch.DoubleTensor of size 2x5]

-0.5874  0.2692 -0.0054
-0.3141  0.4560 -0.1736
[torch.DoubleTensor of size 2x3]

-0.1084  0.0590 -0.1988 -0.0014 -0.3763
 0.2608 -0.3695 -0.1406 -0.0943  0.1540
[torch.DoubleTensor of size 2x5]

-0.1993  0.4975 -0.0387
-0.6763  0.6320  0.1011
[torch.DoubleTensor of size 2x3]

 0.3780  0.0864 -0.5836  0.0947 -0.1529
 0.1496 -0.2904  0.0476 -0.5972 -0.1439
[torch.DoubleTensor of size 2x5]

-0.1433 -0.0558  0.6019
-0.3933  0.7040 -0.1899
[torch.DoubleTensor of size 2x3]

 0.2434  0.0799 -0.3362 -0.3551 -0.1262
 0.1483 -0.3163  0.0456 -0.3795 -0.3136
[torch.DoubleTensor of size 2x5]

-0.1033  0.1587  0.2700
-0.5352 -0.0013  0.3395
[torch.DoubleTensor of size 2x3]

 0.4260  0.0395  0.1887  0.2243  0.1099
 0.1495  0.1180 -0.1659 -0.7148 -0.1288
[torch.DoubleTensor of size 2x5]

-0.3407 -0.0160  0.3832
-0.4963 -0.2790  0.2117
[torch.DoubleTensor of size 2x3]

 0.3621  0.0185  0.1800  0.2579  0.2518
 0.1930  0.2976 -0.1816 -0.2651  0.3301
[torch.DoubleTensor of size 2x5]

 0.3250 -0.1291  0.3249
-0.4013  0.0726  0.3880
[torch.DoubleTensor of size 2x3]

-0.2130 -0.2124  0.1872 -0.3821 -0.3469
-0.5527  0.2327 -0.5262 -0.1274 -0.0059
[torch.DoubleTensor of size 2x5]

 0.5586 -0.2415  0.5686
-0.0006 -0.2407  0.5380
[torch.DoubleTensor of size 2x3]

-0.3237 -0.5344  0.1779 -0.2382 -0.0358
-0.4356  0.0809 -0.6418 -0.2181  0.1829
[torch.DoubleTensor of size 2x5]

-0.4458 -0.1198  0.0943
 0.5822  0.0415  0.5501
[torch.DoubleTensor of size 2x3]

-0.2935 -0.2521  0.5369  0.1733 -0.1888
-0.0270 -0.0589 -0.2644  0.0524 -0.1389
[torch.DoubleTensor of size 2x5]

-0.3021 -0.3249  0.1086
 0.1490  0.1119  0.3938
[torch.DoubleTensor of size 2x3]

-0.3738  0.0554  0.4747 -0.2750 -0.1973
 0.4034  0.0160 -0.5589  0.2649 -0.1786
[torch.DoubleTensor of size 2x5]

 0.1880 -0.4116  0.1566
-0.0892  0.5089  0.6448
[torch.DoubleTensor of size 2x3]

-0.3624  0.0344  0.2751 -0.0957 -0.3211
 0.4781  0.0767 -0.8000  0.0850 -0.0172
[torch.DoubleTensor of size 2x5]

-0.1199 -0.0583  0.6793
-0.4225  0.0372  0.3179
[torch.DoubleTensor of size 2x3]

 0.3011 -0.1914  0.1877  0.0703 -0.4776
 0.3669  0.4287 -0.3780  0.3919  0.0931
[torch.DoubleTensor of size 2x5]

 0.1869  0.1210  0.4755
-0.2576  0.2814  0.4576
[torch.DoubleTensor of size 2x3]

-0.1762 -0.7379  0.0657 -0.7108  0.1077
-0.3548  0.3957 -0.3825  0.4042 -0.1735
[torch.DoubleTensor of size 2x5]

-0.3007  0.3032  0.3034
-0.2902  0.2836 -0.0533
[torch.DoubleTensor of size 2x3]

-0.2070 -0.4167  0.2865 -0.3633 -0.3970
 0.0053  0.3027 -0.6343  0.2082 -0.1921
[torch.DoubleTensor of size 2x5]

 0.1410  0.1667  0.0919
-0.0561  0.2451 -0.5299
[torch.DoubleTensor of size 2x3]

-0.1984 -0.1154  0.1830  0.0989 -0.5274
-0.0118  0.4943 -0.3121 -0.0503  0.3910
[torch.DoubleTensor of size 2x5]

-0.2592  0.0511 -0.2117
 0.0030  0.3341 -0.2256
[torch.DoubleTensor of size 2x3]

 0.0874  0.0286 -0.4705 -0.0910 -0.5454
-0.5317 -0.0399 -0.3728  0.2094  0.1260
[torch.DoubleTensor of size 2x5]

 0.0385 -0.3266 -0.5036
 0.0914 -0.1621  0.3965
[torch.DoubleTensor of size 2x3]

 0.6463  0.3097 -0.2800  0.2159 -0.4020
-0.6361 -0.1606 -0.2040  0.8354  0.0967
[torch.DoubleTensor of size 2x5]

-0.2623 -0.6139 -0.0957
-0.1591  0.2838  0.2031
[torch.DoubleTensor of size 2x3]

-0.2910  0.5071 -0.0108  0.2422 -0.5573
-0.0710  0.1230 -0.3950  0.7169 -0.3956
[torch.DoubleTensor of size 2x5]

-0.3162 -0.5456 -0.4040
-0.2262  0.3883 -0.0922
[torch.DoubleTensor of size 2x3]

-0.3001  0.1368 -0.1088 -0.0572 -0.4437
-0.0884 -0.4149  0.0622  0.6426 -0.2400
[torch.DoubleTensor of size 2x5]

-0.1150 -0.2338 -0.1643
 0.1071  0.4287  0.0405
[torch.DoubleTensor of size 2x3]

-0.1686  0.2319  0.1944  0.2425 -0.0391
 0.1140 -0.1558 -0.0551  0.2319  0.0755
[torch.DoubleTensor of size 2x5]

 0.3728  0.0051 -0.1302
 0.2224  0.0767  0.6261
[torch.DoubleTensor of size 2x3]

 0.0225  0.3192  0.2712  0.2872  0.1136
 0.0709 -0.4577  0.5015 -0.0403  0.2527
[torch.DoubleTensor of size 2x5]

-0.0232 -0.1762 -0.3414
-0.0700 -0.0839  0.6069
[torch.DoubleTensor of size 2x3]

 0.3170 -0.2756  0.3263 -0.2625  0.1348
-0.4522 -0.3917  0.4454  0.2736 -0.2093
[torch.DoubleTensor of size 2x5]

-0.3033 -0.4015 -0.3508
 0.4013 -0.3485  0.4995
[torch.DoubleTensor of size 2x3]

-0.1359 -0.0783  0.2307 -0.1892  0.3136
-0.2148 -0.0115  0.0193  0.3637  0.0200
[torch.DoubleTensor of size 2x5]
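
One possible ad-hoc remedy, sketched here under assumptions and not taken from the thread, is to build a per-sample summary by hand. The forward half is read from the last time step, and the backward half from the first non-padded step, where the backward RNN has seen every real input. The helpers `firstNonZeroStep` and `summarize` are hypothetical and not part of rnn. The sketch assumes left-padded inputs and an output layout of forward features followed by backward features, as in the 20x2400 tensors above. It works on plain tensors after the forward pass, so it does not take part in backpropagation.

```lua
require 'torch'

-- Hypothetical helper: index of the first non-zero time step per sample.
-- inputs: batchSize x seqSize x inputSize, left-padded with zero vectors.
local function firstNonZeroStep(inputs)
   local batchSize, seqSize = inputs:size(1), inputs:size(2)
   local first = torch.LongTensor(batchSize):fill(seqSize + 1)
   for b = 1, batchSize do
      for t = 1, seqSize do
         if inputs[b][t]:norm() > 0 then
            first[b] = t
            break
         end
      end
   end
   return first
end

-- Hypothetical helper: forward half from the last step, backward half from
-- the first real step. outputs is a table of batchSize x (2*outputSize) tensors.
local function summarize(outputs, first, outputSize)
   local batchSize, seqSize = first:size(1), #outputs
   local summary = torch.zeros(batchSize, 2 * outputSize)
   for b = 1, batchSize do
      local t = first[b]
      if t <= seqSize then   -- fully padded samples keep a zero summary
         summary[b]:narrow(1, 1, outputSize)
            :copy(outputs[seqSize][b]:narrow(1, 1, outputSize))
         summary[b]:narrow(1, outputSize + 1, outputSize)
            :copy(outputs[t][b]:narrow(1, outputSize + 1, outputSize))
      end
   end
   return summary
end

-- Usage with small made-up tensors standing in for the network's inputs/outputs.
local batchSize, seqSize, inputSize, outputSize = 3, 4, 2, 2
local inputs = torch.randn(batchSize, seqSize, inputSize)
inputs[1][1]:zero(); inputs[1][2]:zero(); inputs[2][1]:zero()
local outputs = {}
for t = 1, seqSize do outputs[t] = torch.randn(batchSize, 2 * outputSize) end
local first = firstNonZeroStep(inputs)
local summary = summarize(outputs, first, outputSize)
print(first, summary)
```

Whether this selection suits a given model depends on the application, as noted earlier in the thread. A trainable version would need the same indexing expressed as nn modules.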