Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Using ConcatTable in Recurrent #291

Closed: ethanabrooks closed this issue 8 years ago

ethanabrooks commented 8 years ago

Recurrent appears to mutate ConcatTable.modules into {}. As a result, the following code:

require 'rnn'
require 'nn'

-- hyper-parameters
local batchSize = 3
local hiddenSize = 2
local x = torch.ones(batchSize, hiddenSize)

local feedback = nn.Sequential()
local concat = nn.ConcatTable()
feedback:add(concat)
concat:add(nn.Identity)
concat:add(nn.Identity)
feedback:add(nn.CAddTable())

local r = nn.Recurrent(
    nn.Identity(), -- start
    nn.Identity(), -- input
    feedback,
    nn.Identity() --transfer
)
print(r)
r:updateOutput(x)
r:updateOutput(x)
r:updateOutput(x)

throws this error:

/Users/Ethan/torch/install/bin/luajit: /Users/Ethan/torch/install/share/lua/5.1/nn/Container.lua:67: 
In 1 module of nn.Sequential:
In 2 module of nn.ParallelTable:
In 1 module of nn.Sequential:
In 1 module of nn.ConcatTable:
attempt to call a nil value -- this, I believe, refers to the modules of the ConcatTable
stack traceback:
    [C]: in function 'xpcall'
    /Users/Ethan/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /Users/Ethan/torch/install/share/lua/5.1/nn/ConcatTable.lua:11: in function </Users/Ethan/torch/install/share/lua/5.1/nn/ConcatTable.lua:9>
    [C]: in function 'xpcall'
    /Users/Ethan/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /Users/Ethan/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function </Users/Ethan/torch/install/share/lua/5.1/nn/Sequential.lua:41>
    [C]: in function 'xpcall'
    /Users/Ethan/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    ...s/Ethan/torch/install/share/lua/5.1/nn/ParallelTable.lua:12: in function <...s/Ethan/torch/install/share/lua/5.1/nn/ParallelTable.lua:10>
    [C]: in function 'xpcall'
    /Users/Ethan/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    /Users/Ethan/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function 'updateOutput'
    /Users/Ethan/torch/install/share/lua/5.1/rnn/Recurrent.lua:75: in function 'updateOutput'
    issue.lua:32: in main chunk
    [C]: at 0x010923fbd0

nn.Recurrent {
  [{input(t), output(t-1)} -> (1) -> (2) -> (3) -> output(t)]
  (1):  {
    input(t)
      |`-> (t==0): nn.Identity
      |`-> (t~=0): nn.Identity
    output(t-1)
      |`-> nn.CAddTable
  }
  (2): nn.CAddTable
  (3): nn.ConcatTable {
    input
      |`-> (1): table: 0x0bffa7c0
      |`-> (2): table: 0x0bffa7c0
       ... -> output
  }
}

A few observations: Recurrent only throws the error on the third call, and feedback does not throw the error when used outside of Recurrent. Thanks for your help!

nicholas-leonard commented 8 years ago

Use nn.Copy to force a copy:

local feedback = nn.Sequential()
local concat = nn.ConcatTable()
feedback:add(concat)
concat:add(nn.Copy(nil, nil, true))
concat:add(nn.Copy(nil, nil, true))
feedback:add(nn.CAddTable())
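
As a note on the fix: to my understanding the constructor here is nn.Copy(inputtype, outputtype, forcecopy), so passing nil for both types keeps the tensor type unchanged, and setting the third argument to true forces an actual copy even when no type conversion is needed, which is what breaks the sharing.
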
ethanabrooks commented 8 years ago

Thanks! That fixed it. Could you help me understand what's going on here? nn.Recurrent is a module that I plan to use extensively.

nicholas-leonard commented 8 years ago

The problem is that nn.Recurrent assumes that the output and input are different tensors. That is why you must explicitly copy the input to the output when they would otherwise be the same tensor.
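
A minimal sketch of the aliasing difference being described, assuming standard nn semantics (nn.Identity passes its input through by reference, while nn.Copy with forcecopy = true writes into its own output tensor):

require 'nn'

local x = torch.ones(3, 2)

-- Identity returns the very same tensor it was given, so anything that
-- feeds its output back in as input ends up with input and output aliased.
local id = nn.Identity()
local y1 = id:forward(x)
print(torch.pointer(y1) == torch.pointer(x)) -- expected: true (same tensor)

-- Copy with forcecopy = true writes into a separate output tensor,
-- so the feedback path no longer aliases the input.
local cp = nn.Copy(nil, nil, true)
local y2 = cp:forward(x)
print(torch.pointer(y2) == torch.pointer(x)) -- expected: false (distinct tensor)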

ethanabrooks commented 8 years ago

Thanks. That's very clarifying.