lukan94 opened this issue 8 years ago
Do note that the way you use getParameters() is not a supported feature: you intend to share the parameters between model and model_cloned, but you call getParameters() on model only. getParameters() changes the memory locations of the parameters and preserves references to that memory only inside the model on which the call is performed; it does not update the references in model_cloned, so the sharing of parameters is lost. A better way to call getParameters() would be:
param, gradparam = nn.Sequential():add(model):add(model_cloned):getParameters()
Additionally, I think you should also share the gradients between the models if you share the parameters, so you should probably use this instead:
model:clone('weight', 'bias','gradWeight','gradBias')
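Putting those two suggestions together, a minimal sketch of the pattern (using a placeholder nn.Linear as a stand-in for the real model) might look like this:

require 'nn'

-- Placeholder stand-in for the real model.
local model = nn.Linear(10, 10)

-- Share both the parameters and the gradient buffers with the clone.
local model_cloned = model:clone('weight', 'bias', 'gradWeight', 'gradBias')

-- Flatten through a container holding both modules so the original and the
-- clone both keep valid references into the new flattened storage.
local param, gradParam = nn.Sequential()
   :add(model)
   :add(model_cloned)
   :getParameters()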
Also I have hit this problem a bunch of times myself, and I am glad I understand it now...
Hi,
I should have been clearer earlier. I was actually trying to focus attention on the main issue, but since you've mentioned the above, I'll tell you exactly what I'm trying to do.
I basically want to build a siamese network using the pre-trained model mentioned above. I use the following code:
siamese = nn.ParallelTable()
siamese:add(model)
siamese:add(model:clone('weight', 'bias', 'gradWeight', 'gradBias'))
finalmodel = nn.Sequential()
finalmodel:add(nn.SplitTable(1))
finalmodel:add(siamese)
finalmodel:add(nn.PairwiseDistance(2))
parameters, gradParameters = finalmodel:getParameters() -- gives same error as above!
The main issue is that getParameters() throws the error mentioned above when dealing with the cloned model of a pre-trained network. When I don't use a pre-trained network and instead build a new model with the same architecture, it works fine.
I see, I can reproduce the error here: https://gist.github.com/JoostvDoorn/9b8e267305dac4452add0328a2f9c724 It seems one of the parameters is nil for the cloned model, but I don't know why this happens.
You could simply use SeqLSTM for now, which is faster anyway, and does not have this problem.
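For reference, a rough sketch of what a SeqLSTM-based siamese branch could look like; the sizes are placeholders and nn.Select is just one way of keeping the last time step:

require 'rnn'

-- Placeholder sizes, not the ones from your pre-trained model.
local inputSize, hiddenSize = 100, 128

-- One branch built with SeqLSTM instead of a Sequencer-wrapped LSTM.
-- SeqLSTM expects a seqlen x batch x inputSize tensor as input.
local branch = nn.Sequential()
   :add(nn.SeqLSTM(inputSize, hiddenSize))
   :add(nn.Select(1, -1)) -- keep only the last time step

local siamese = nn.ParallelTable()
   :add(branch)
   :add(branch:clone('weight', 'bias', 'gradWeight', 'gradBias'))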
Yeah, that's exactly what's happening. Some of the parameters of the cloned model are being assigned a nil value. I believe this happens in the lines of code below, in torch/install/share/lua/5.1/dpnn/Module.lua:
-- recursive get all modules (modules, sharedclones, etc.)
local function recursiveGetModules(tbl)
   for k,m in pairs(tbl) do
      if torch.isTypeOf(m, 'nn.Module') then
         if not m.dpnn_getParameters_found then
            con:add(m)
            m.dpnn_getParameters_found = true
            recursiveGetModules(m)
         end
      elseif torch.type(m) == 'table' then
         recursiveGetModules(m)
      end
   end
end
recursiveGetModules(self)
for i,m in ipairs(con.modules) do
   m.dpnn_getParameters_found = nil
end
What I don't understand is why this problem only occurs with a pre-trained model. When I build a new model with the same architecture, then clone it and call getParameters(), it works fine. I wanted to use peephole LSTMs, which is why I went with LSTM rather than SeqLSTM (which uses FastLSTM). However, are you sure this problem does not occur when SeqLSTM is used in the pre-trained model?
Could you try resetting the sharedClones in the LSTM before cloning the pretrained model?
You should probably do something like this:
model.modules[1].modules[1].modules[1].sharedClones = {model.modules[1].modules[1].modules[1].recurrentModule}
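If indexing into modules[1].modules[1]... is too fragile for your model, here is a rough sketch of the same idea applied to every recurrent module (assuming listModules() and nn.AbstractRecurrent behave as usual):

-- Reset sharedClones on every recurrent module before cloning, keeping
-- only the recurrentModule itself.
local function resetSharedClones(model)
   for _, m in ipairs(model:listModules()) do
      if torch.isTypeOf(m, 'nn.AbstractRecurrent') then
         m.sharedClones = {m.recurrentModule}
      end
   end
end

resetSharedClones(model)
local model_cloned = model:clone('weight', 'bias', 'gradWeight', 'gradBias')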
I tried. It doesn't work; it gives the same error.
Just checked. SeqLSTM works! But still, any help would be appreciated in figuring out why peephole LSTM wrapped in a Sequencer does not work when cloning a pre-trained model.
I hit the same problem; maybe you can try requiring cunn and cudnn.
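In other words, something along these lines before loading the saved model (the file name is a placeholder; this only matters if the model was saved with CUDA/cuDNN modules inside):

-- Load the CUDA backends before torch.load-ing the pre-trained model.
require 'cunn'
require 'cudnn'

local model = torch.load('pretrained_lstm.t7')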
Keep in mind that Element-Research/rnn was deprecated and is superseded by torch/rnn if you reinstalled Torch.
Hi,

I have a simple LSTM architecture, as shown below. I have already pre-trained it and saved it. Next I load the model and perform a clone() on it. If I call getParameters() on the original model, it executes as expected, but if I do the same on the cloned model, it throws the error shown below.

Any idea why this is happening?

Thanks in advance!
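For anyone trying to reproduce this, a minimal sketch of the kind of setup described here, assuming a peephole nn.LSTM wrapped in nn.Sequencer; the sizes and file name are placeholders, not the ones from the original model:

require 'rnn'

-- Hypothetical stand-in for the pre-trained model; sizes are placeholders.
local inputSize, hiddenSize = 100, 128
local model = nn.Sequential()
   :add(nn.Sequencer(nn.LSTM(inputSize, hiddenSize)))

-- Training omitted; save and reload to mimic the pre-trained model.
torch.save('pretrained_lstm.t7', model)
model = torch.load('pretrained_lstm.t7')

-- getParameters() works on the original but fails on the clone.
local model_cloned = model:clone('weight', 'bias', 'gradWeight', 'gradBias')
local p, gp = model:getParameters()
local pc, gpc = model_cloned:getParameters() -- throws the error described above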