torch / nn

crash in MM.backward after :clearState() #781

Closed · cvondrick closed this 8 years ago

cvondrick commented 8 years ago

We are experiencing a crash in the backward pass of nn.MM after calling :clearState() on the network.

/data/vision/torralba/commonsense/torch/distro/install/bin/luajit: ...commonsense/torch/distro/install/share/lua/5.1/nn/MM.lua:51: attempt to index a nil value
stack traceback:
        ...commonsense/torch/distro/install/share/lua/5.1/nn/MM.lua:51: in function 'updateGradInput'
        ...onsense/torch/distro/install/share/lua/5.1/nn/Module.lua:30: in function 'backward'
        debug_clearstate.lua:17: in main chunk

Minimal code to reproduce:

require 'nn'
require 'cunn'

mod = nn.MM()
mod:cuda()

a = torch.randn(10,5,5):cuda()
b = torch.randn(10,5,5):cuda()
grad = torch.randn(10,5,5):cuda()

out = mod:forward({a,b})
mod:backward({a,b}, grad)

mod:clearState()

out = mod:forward({a,b})
mod:backward({a,b}, grad)  --- CRASH!

If you change nn.MM to another module, such as nn.CAddTable(), there is no crash.

The code works fine until :clearState() is called, and the crash happens with or without cunn/CUDA.

Thank you very much!

fmassa commented 8 years ago

The problem is here: https://github.com/torch/nn/blob/master/MM.lua#L51. gradInput is an empty table after clearState, but at that point it is expected to be a table with two tensors. The simplest fix is to lazily initialize gradInput in the backward pass.
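
For reference, a minimal sketch of that lazy initialization inside MM.lua (an illustration of the suggestion, not necessarily the exact patch that landed): at the top of MM:updateGradInput, re-create the two gradient tensors when clearState has emptied gradInput.

-- Sketch of the lazy initialization suggested above (illustrative, not the actual patch).
-- clearState() leaves self.gradInput as an empty table, so restore its two-tensor
-- structure before the rest of updateGradInput indexes gradInput[1] and gradInput[2].
function MM:updateGradInput(input, gradOutput)
   self.gradInput[1] = self.gradInput[1] or input[1].new()
   self.gradInput[2] = self.gradInput[2] or input[2].new()
   -- ... existing gradient computation follows, resizing and filling
   -- self.gradInput[1] and self.gradInput[2] ...
   return self.gradInput
end

With that guard in place, the reproduction above runs both forward/backward passes without error, since the second backward no longer indexes a nil gradInput[1].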