torch / nngraph

Graph Computation for nn

Recurrence with split #109

Open nicholas-leonard opened 8 years ago

nicholas-leonard commented 8 years ago

Hi guys,

This crashes:

require 'nngraph';

n1 = 3
n2 = 4
n3 = 3

-- Two inputs: a single tensor (x1) and a table of two tensors (x23),
-- which is split into its elements inside the graph.
x1 = nn.Identity()()
x23 = nn.Identity()()
x2, x3 = x23:split(2)
z = nn.JoinTable(1)({x1, x2, x3})
y1 = nn.Linear(n1+n2+n3, n2)(z)
y2 = nn.Linear(n1+n2+n3, n3)(z)
m = nn.gModule({x1, x23}, {y1, y2})

input = {torch.randn(n1), {torch.randn(n2), torch.randn(n3)}}
output = m:forward(input)
print(output)
print(input)
-- Feed the previous outputs back in as the table input.
input[2] = output
print(input)
-- This second forward crashes.
m:forward(input)

The error:

/usr/local/bin/luajit: /usr/local/share/lua/5.1/nngraph/gmodule.lua:314: split(2) cannot split 0 outputs
stack traceback:
    [C]: in function 'error'
    /usr/local/share/lua/5.1/nngraph/gmodule.lua:314: in function 'neteval'
    /usr/local/share/lua/5.1/nngraph/gmodule.lua:346: in function 'forward'
    issues/issue172.lua:22: in main chunk
    [C]: in function 'dofile'
    /usr/local/lib/luarocks/rocks/trepl/scm-1/bin/th:131: in main chunk
    [C]: at 0x00405e60

Basically, the issue occurs when the outputs of a previous forward are fed back as inputs to the next forward of a gModule that uses split.
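
The assignment input[2] = output makes input[2] the very same Lua table as output, so the next forward's input aliases the module's own output buffers. A quick way to check the aliasing, using the variables from the snippet above:

-- After `input[2] = output`, both names refer to the same table,
-- so the tensors share storage with the gModule's output buffers.
print(torch.pointer(input[2][1]) == torch.pointer(output[1])) -- true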

fidlej commented 8 years ago

It is dangerous to feed the output back as the next input: the network can zero its output buffers before reading the input. I suggest using a deep copy of the output as the next input.
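
A minimal sketch of that workaround, assuming the variables from the first snippet; deepCopy is a hypothetical helper, not part of nn or nngraph:

-- Hypothetical helper: recursively clone every tensor in a nested table,
-- so the fed-back input no longer aliases the module's output buffers.
local function deepCopy(x)
   if torch.isTensor(x) then
      return x:clone()
   end
   local copy = {}
   for k, v in pairs(x) do
      copy[k] = deepCopy(v)
   end
   return copy
end

input[2] = deepCopy(output) -- instead of input[2] = output
m:forward(input)            -- the split node now sees stable tensors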

vbkbmqj commented 7 years ago

Hi, I ran into a similar problem to the one above when testing an RNN in evaluate mode: the second time I call mRNN:forward, it just crashes. In training mode everything is fine. Can you help me?

require 'rnn'
require 'nngraph'
th = torch

inputSize, hiddenSize, outputSize = 5, 5, 5

local mX = nn.Identity()()
local mS = nn.Identity()()

local mH, mA = (mS):split(2)

local mAN = mA - nn.Sigmoid()

local mHN = {
   mH,
   mX - nn.Linear(inputSize, hiddenSize),
}

mCore = nn.gModule({mX, mS}, {mHN, mAN})

mRNN = nn.Recurrence(mCore, {{hiddenSize}, {hiddenSize}}, 1)

mRNN:evaluate()
------------------------------------
out = mRNN:forward(th.randn(inputSize))
out = mRNN:forward(th.randn(inputSize)) ------- this will crash
out = mRNN:forward(th.randn(inputSize))