Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Calling :backward() more than rho times should error #400

Open achalddave opened 7 years ago

achalddave commented 7 years ago

Calling backward() more than rho times on an nn.Recurrent module can lead to undesired, undocumented behavior. It should instead be explicitly guarded against and raise an error. See the following code, which

  1. Sets rho = 1
  2. Calls forward() 4 times
  3. Calls backward() 3 times

After 4 calls to forward(), sharedClones holds only 2 clones of the recurrent module. The first 2 calls to backward() operate on these 2 clones, but a third call to backward() silently creates a new clone of recurrentModule and calls backward on it instead of raising an error.

require 'nn'
require 'rnn'
local _ = require 'moses'

local model = nn.Recurrent(nn.Identity(), nn.Identity(), nil, nil, 1 --[[rho]])

local input = torch.rand(1)

model:training()
print('Step 1')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2

print('Step 2')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2

print('Step 3')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2, 3

print('Step 4')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 3, 4

print('Step 1 backward')
model:backward(input, input) -- Calls backward on module 4.
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 3, 4

print('Step 2 backward')
model:backward(input, input) -- Calls backward on module 3.
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 3, 4

print('Step 3 backward')
model:backward(input, input) -- Creates new module 2, and calls backward?!
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2, 3, 4
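
For illustration, the guard requested here could look roughly like the sketch below. This is not code from the rnn library: `guardBackward` and the stub module are hypothetical names, and the wrapper only counts calls rather than inspecting sharedClones. It assumes that at most rho backward steps are valid after any number of forward calls, which is the behavior the issue title asks for.

```lua
-- Hypothetical sketch, not part of the rnn library.
-- Wraps any object exposing forward/backward and raises an error once
-- backward() is called more often than there are remembered steps.
local function guardBackward(module, rho)
   local stepsAvailable = 0  -- backward steps that still have stored state
   local guarded = {}

   function guarded:forward(input)
      -- Each forward call stores one more step, but only the last rho are kept.
      stepsAvailable = math.min(stepsAvailable + 1, rho)
      return module:forward(input)
   end

   function guarded:backward(input, gradOutput)
      if stepsAvailable == 0 then
         error('backward() called more times than there are stored steps (rho = '
               .. rho .. ')')
      end
      stepsAvailable = stepsAvailable - 1
      return module:backward(input, gradOutput)
   end

   return guarded
end

-- Usage with a stand-in module (an assumption, so the sketch runs without Torch).
local stub = {
   forward = function(self, input) return input end,
   backward = function(self, input, gradOutput) return gradOutput end,
}
local model = guardBackward(stub, 1)
for _ = 1, 4 do model:forward(0.5) end
model:backward(0.5, 0.5)                       -- allowed: one stored step
print(pcall(model.backward, model, 0.5, 0.5))  -- false, plus an error message
```

A real fix would presumably live inside the library's own backward path and check the stored steps directly. The wrapper above only shows the intended contract: the second backward() fails loudly instead of creating a new clone.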