Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

ConcatTable problem #372

Open shamangary opened 7 years ago

shamangary commented 7 years ago

I'm using the GRU as just one component of an nn.ConcatTable(). The input is a Lua table such as {128 (batch size) x 16 (sequence length) x 512 (feature dim)}. However, I can't see what I'm doing wrong in the following model. Please help me out. Thanks!

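```lua
require 'nn'
require 'rnn'  -- provides the GRU used as net_gru (net_gru is defined elsewhere, not shown)

net_score = nn.Sequential()
net_score:add(nn.JoinTable(2))          -- join the table of outputs along dimension 2
net_score:add(nn.Reshape(16*512))       -- flatten to 16*512 = 8192 features per sample
net_score:add(nn.Linear(16*512, 10))    -- map to 10 output scores
```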
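`net_score` joins the per-step outputs and flattens them, so its `nn.Linear` layer has to accept 16 x 512 = 8192 inputs per sample. The pure-Lua sketch below only does that shape bookkeeping. It assumes (the post does not say this) that `net_gru` returns a table of 16 tensors of size 128x512, and `joinShape` is a hypothetical helper, not part of nn.

```lua
-- Hypothetical helper: output shape of concatenating tensors along `dim`,
-- following the nn.JoinTable(dim) rule (sizes add up along `dim`,
-- all other dimensions must agree).
local function joinShape(shapes, dim)
  local out = {}
  for i, size in ipairs(shapes[1]) do out[i] = size end
  out[dim] = 0
  for _, shape in ipairs(shapes) do
    assert(#shape == #out, "all inputs must have the same number of dims")
    for i, size in ipairs(shape) do
      if i == dim then
        out[dim] = out[dim] + size
      else
        assert(size == out[i], "size mismatch in non-joined dimension " .. i)
      end
    end
  end
  return out
end

-- Assumed GRU output: 16 time steps, each a 128 (batch) x 512 (feature) tensor.
local steps = {}
for t = 1, 16 do steps[t] = {128, 512} end

local joined = joinShape(steps, 2)
print(joined[1], joined[2])  --> 128	8192
```

Under that assumption, `nn.Linear(16*512, 10)` receives 8192 features for each of the 128 samples.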


- Case 1: Not working

```lua
net_combine = nn.Sequential()
cat_out = nn.ConcatTable()

net_out_1 = nn.Sequential()
net_out_1:add(net_gru)
net_out_1:add(net_score)

net_out_2 = nn.Sequential()
net_out_2:add(net_gru)    -- same net_gru instance as in net_out_1
net_out_2:add(net_score)  -- same net_score instance as in net_out_1

cat_out:add(net_out_1)
cat_out:add(net_out_2)

net_combine:add(cat_out)
net_combine:add(nn.SelectTable(1))
```


- Error

```
...hamangary/torch/install/share/lua/5.1/nn/ConcatTable.lua:46: table size mismatch: 1 ~= 2
stack traceback:
	[C]: in function 'error'
	...hamangary/torch/install/share/lua/5.1/nn/ConcatTable.lua:46: in function <...hamangary/torch/install/share/lua/5.1/nn/ConcatTable.lua:30>
	[C]: in function 'xpcall'
	.../shamangary/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
	...shamangary/torch/install/share/lua/5.1/nn/Sequential.lua:58: in function 'updateGradInput'
	/home/shamangary/torch/install/share/lua/5.1/rnn/GRU.lua:188: in function '_updateGradInput'
	...ry/torch/install/share/lua/5.1/rnn/AbstractRecurrent.lua:59: in function <...ry/torch/install/share/lua/5.1/rnn/AbstractRecurrent.lua:54>
	[C]: in function 'xpcall'
	.../shamangary/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
	...shamangary/torch/install/share/lua/5.1/nn/Sequential.lua:58: in function 'updateGradInput'
	...
	[C]: in function 'xpcall'
	.../shamangary/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
	...shamangary/torch/install/share/lua/5.1/nn/Sequential.lua:84: in function 'backward'
	./TYY_train.lua:120: in function 'opfunc'
	/home/shamangary/torch/install/share/lua/5.1/optim/sgd.lua:44: in function 'sgd'
	./TYY_train.lua:135: in function 'train'
	TYY_main.lua:150: in main chunk
	[C]: in function 'dofile'
	...gary/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
	[C]: at 0x00406670
```
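The error is raised by a consistency check in `nn.ConcatTable`'s backward pass. Judging from the message, it compares the number of entries in a table input with the number of entries in the gradInput table returned by one branch. The traceback also suggests the failing ConcatTable is one inside the GRU (reached from `GRU.lua:188`), not `cat_out`. The pure-Lua function below is a simplified re-creation for illustration only, not the nn source; `checkBranchGradInputs` is a hypothetical name, and strings stand in for tensors.

```lua
-- Simplified illustration (not the actual nn source) of the check behind
-- "table size mismatch": for a table input, every branch's gradInput must
-- be a table with one entry per input entry.
local function checkBranchGradInputs(input, branchGradInputs)
  for i, gradInput in ipairs(branchGradInputs) do
    if type(gradInput) ~= "table" then
      error("branch " .. i .. ": gradInput is not a table", 0)
    end
    if #input ~= #gradInput then
      error("table size mismatch: " .. #input .. " ~= " .. #gradInput, 0)
    end
  end
  return true
end

-- Consistent case: 1-entry input, each branch returns a 1-entry gradInput.
print(checkBranchGradInputs({"x"}, {{"dx"}, {"dx"}}))  --> true

-- Inconsistent case: one branch returns 2 entries for a 1-entry input.
local ok, err = pcall(checkBranchGradInputs, {"x"}, {{"dx"}, {"dx", "dh"}})
print(ok, err)  --> false	table size mismatch: 1 ~= 2
```

Read this way, `1 ~= 2` means some branch returned two gradient entries for a one-entry input table.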


- Case 2: Working

```lua
net_combine = nn.Sequential()
cat_out = nn.ConcatTable()

net_out_1 = nn.Sequential()
--net_out_1:add(net_gru)  -- GRU removed from this branch
net_out_1:add(net_score)

net_out_2 = nn.Sequential()
--net_out_2:add(net_gru)  -- GRU removed from this branch
net_out_2:add(net_score)

cat_out:add(net_out_1)
cat_out:add(net_out_2)

net_combine:add(cat_out)
net_combine:add(nn.SelectTable(1))
```

- Case 3: Working

```lua
net_combine = nn.Sequential()
cat_out = nn.ConcatTable()

-- only one branch, so net_gru is used once
net_out_1 = nn.Sequential()
net_out_1:add(net_gru)
net_out_1:add(net_score)

cat_out:add(net_out_1)

net_combine:add(cat_out)
net_combine:add(nn.SelectTable(1))
```
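Comparing the three cases, the error appears only when the same `net_gru` object is added to both branches (Case 1); using it in one branch (Case 3) or not at all (Case 2) works. A plausible explanation, not confirmed in this thread, is that rnn's recurrent modules keep per-step state on the module object itself, so one instance shared by two ConcatTable branches has its bookkeeping disturbed by the other branch. The pure-Lua mock below illustrates that general effect only: `SeqModule` and `concatForwardBackward` are hypothetical, and the reset-on-forward and consume-on-backward behavior is an assumed simplification, not rnn's actual implementation.

```lua
-- Hypothetical, heavily simplified stand-in for a sequence module that keeps
-- per-step state on the object (NOT rnn's real implementation).
local SeqModule = {}
SeqModule.__index = SeqModule

function SeqModule.new()
  return setmetatable({ saved = {} }, SeqModule)
end

-- Forward a whole sequence: reset state, then save one entry per step.
function SeqModule:forward(seq)
  self.saved = {}
  for t = 1, #seq do self.saved[t] = seq[t] end
  return #seq
end

-- Backward a whole sequence: consume the saved entries in reverse order.
function SeqModule:backward(seq)
  for t = #seq, 1, -1 do
    local entry = table.remove(self.saved)
    if entry == nil then
      error("no saved state for step " .. t, 0)
    end
  end
  return true
end

-- Mimic nn.ConcatTable: forward every branch, then backward every branch.
local function concatForwardBackward(branches, seq)
  for _, m in ipairs(branches) do m:forward(seq) end
  for _, m in ipairs(branches) do m:backward(seq) end
  return true
end

local seq = {}
for t = 1, 16 do seq[t] = t end

-- Two distinct instances: each branch has its own saved state.
print(pcall(concatForwardBackward, { SeqModule.new(), SeqModule.new() }, seq))
--> true	true

-- The same instance in both branches (the Case 1 pattern): the first
-- backward consumes all saved state, so the second one finds none.
local shared = SeqModule.new()
print(pcall(concatForwardBackward, { shared, shared }, seq))
--> false	no saved state for step 16
```

If this is the cause, one thing to try is giving the second branch its own GRU instance that shares parameters with the first, for example via nn's `Module:clone(...)` with `'weight', 'bias', 'gradWeight', 'gradBias'`; whether that is sufficient for rnn's recurrent modules has not been checked here.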