torch / nn


Problem backpropagating from LookupTable to JoinTable #640

Open · boknilev opened 8 years ago

boknilev commented 8 years ago

There seems to be a problem when trying to backpropagate from nn.LookupTable to nn.JoinTable. Here's a minimal working example:

require 'nn'

model = nn.Sequential()
model:add(nn.JoinTable(1))
model:add(nn.LookupTable(10,2))
input = {torch.Tensor{1,2,3}, torch.Tensor{4,5,6}}
out = model:forward(input)
model:backward(input, out)

This fails with the following error:

/path/to/JoinTable.lua:59: bad argument #1 to 'narrow' (out of range at path/to/torch/pkg/torch/lib/TH/generic/THTensor.c:349)

The problem seems to be that LookupTable:updateGradInput returns a torch.DoubleTensor with no dimension. JoinTable:updateGradInput then fails when it tries to narrow this dimensionless tensor.
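
Running LookupTable in isolation confirms the empty gradInput (a minimal check against the same nn version):

lt = nn.LookupTable(10, 2)
ltInput = torch.Tensor{1, 2, 3}
ltOut = lt:forward(ltInput)
gradInput = lt:backward(ltInput, ltOut)
print(gradInput:dim())  -- prints 0: the indices have no gradient, so gradInput stays empty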

Any ideas what might be going on?

I'm getting this on Ubuntu 14.04 with an nn version that I just updated today.

fmassa commented 8 years ago

Have a look at https://github.com/torch/nn/issues/568#issuecomment-172016604
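
For reference, the workaround in that comment amounts to overriding updateGradInput on the module that precedes the LookupTable, so it stops trying to process the empty gradient (a sketch, assuming it is the same override suggested further down in this thread; nn.Module.updateGradInput just returns self.gradInput unchanged):

model = nn.Sequential()
model:add(nn.JoinTable(1))
model:add(nn.LookupTable(10, 2))
model:get(1).updateGradInput = nn.Module.updateGradInput  -- JoinTable now skips gradient processing
input = {torch.Tensor{1, 2, 3}, torch.Tensor{4, 5, 6}}
out = model:forward(input)
model:backward(input, out)  -- no longer errors; the gradInput is simply left empty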

boknilev commented 8 years ago

Thanks, this works for the minimal example. In my real network, however, these layers are embedded in a larger model with a ConcatTable in front of them. When I apply the same fix to the ConcatTable, I get other errors, and I think it interferes with the other modules inside the ConcatTable. In any case, overriding ConcatTable:updateGradInput doesn't seem like the right way to go.

Is there a better way to deal with this? Shouldn't it be fixed in LookupTable itself, without needing to override the preceding module's functions?

boknilev commented 8 years ago

@fmassa I ended up restructuring my network so I don't have to deal with this problem. But perhaps it's worth considering a more general solution, in case other people run into the same problem when they place a LookupTable after other modules.

kbullaughey commented 8 years ago

Here's another example of what I believe is the same problem: I can't seem to put an nn.Reshape upstream of an nn.LookupTable. Getting the data into the right shape (without introducing parameters) seems like a fairly natural thing to want to do, particularly since nn.LookupTable is quite strict about the shapes it accepts.

require 'nn'

x = torch.Tensor({1,2,3,4}):reshape(2,2,1)

-- Because x is a 3D tensor, we want to reshape it:
node1 = nn.Reshape(2,2)
node2 = nn.LookupTable(4,5)

-- However, this doesn't work because for some reason we can't have
-- nn.Reshape upstream of nn.LookupTable
node1:forward(x)
node2:forward(node1.output)
gradInput = torch.Tensor():resizeAs(node2.output):random():mul(0.01)
node2:backward(node1.output, gradInput)
node1:backward(x, node2.gradInput)  -- fails here: node2.gradInput has no dimension

-- If we reshape it first, circumventing node1, it works.
xr = x:reshape(2,2)
node2:forward(xr)
gradInput = torch.Tensor():resizeAs(node2.output):random():mul(0.01)
node2:backward(xr, gradInput)

-- Alternatively, instead of reshaping, we can split the tensor along the last
-- dimension and select from the resulting table. This achieves the same effect
-- as the nn.Reshape above, yet for some reason it works, in contrast to
-- nn.Reshape, which failed.
node1a = nn.SplitTable(3)
node1b = nn.SelectTable(1)

-- No error
node1a:forward(x)
node1b:forward(node1a.output)
node2:forward(node1b.output)
gradInput = torch.Tensor():resizeAs(node2.output):random():mul(0.01)
node2:backward(node1b.output, gradInput)
node1b:backward(node1a.output, node2.gradInput)
node1a:backward(x, node1b.gradInput)

colesbury commented 8 years ago

How about the following?

node1 = nn.Reshape(2,2)
node1.updateGradInput = nn.Module.updateGradInput
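
With that override in place, the failing sequence from the earlier example runs through; nn.Module.updateGradInput simply returns self.gradInput unchanged, so Reshape no longer tries to reshape the empty gradient (a sketch, reusing x from above):

node1 = nn.Reshape(2, 2)
node1.updateGradInput = nn.Module.updateGradInput
node2 = nn.LookupTable(4, 5)

node1:forward(x)
node2:forward(node1.output)
gradInput = torch.Tensor():resizeAs(node2.output):random():mul(0.01)
node2:backward(node1.output, gradInput)
node1:backward(x, node2.gradInput)  -- no error; node1.gradInput is just left empty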

kbullaughey commented 8 years ago

Hm... that works. It seems a bit hacky, though. I wonder whether this should be the standard behavior whenever updateGradInput receives a dimensionless gradient during backpropagation, or whether it makes more sense to handle it case by case, depending on the network topology?
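
If it were made standard, I imagine it would look something like a guard at the top of updateGradInput (purely a hypothetical sketch of the idea, not current nn behavior):

-- Hypothetical guard: pass a dimensionless gradient through untouched.
local origUpdateGradInput = nn.Reshape.updateGradInput
function nn.Reshape:updateGradInput(input, gradOutput)
   if gradOutput:dim() == 0 then
      self.gradInput:resize(0)  -- leave our own gradInput empty as well
      return self.gradInput
   end
   return origUpdateGradInput(self, input, gradOutput)
end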