Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Possible bug on save model #353

Open sudongqi opened 7 years ago

sudongqi commented 7 years ago

Update: The bug is caused by the Serial option (light, medium, or heavy); saving without Serial resolves it. Further testing suggests the bug may be related to the "remember" attribute failing to turn on in a model saved with Serial.

@nicholas-leonard This problem has bothered me for 2 weeks now. The training loss indicates that learning is happening, but whenever I load the pre-trained model from another script, sampling from the model (seq2seq using SeqLSTM) returns gibberish.

So I ran a test on an overfitting example, and sampling still fails when the model is loaded from another script.

It turned out that if I insert this piece of code right after the training script, sampling works again. So the error must come from the model-saving function (or maybe the saving code I wrote is wrong? I have tried both the Serial and the non-Serial option for saving).

save file code:

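-- save model
file = {}
-- wrap the encoder and decoder in nn.Serial before saving
file.enc = nn.Serial(enc)
file.dec = nn.Serial(dec)

file.enc:mediumSerial()
file.dec:mediumSerial()

-- also keep the tables of individual LSTM layers
file.enc_lstm = enc_lstm
file.dec_lstm = dec_lstm
torch.save('model.t7', file)
print('save file')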

model definition:

-- Encoder
local enc = nn.Sequential()
enc:add(nn.LookupTableMaskZero(opt.vocabSize, opt.hiddenSize))
enc_lstm = {}
for i=1,opt.numLayers do
    enc_lstm[i] = nn.SeqLSTM(opt.hiddenSize, opt.hiddenSize)
    enc_lstm[i]:maskZero()
    enc:add(enc_lstm[i])
end
enc:add(nn.Select(1, -1))

-- Decoder
local dec = nn.Sequential()
dec:add(nn.LookupTableMaskZero(opt.vocabSize, opt.hiddenSize))
dec_lstm = {}
for i=1,opt.numLayers do
    dec_lstm[i] = nn.SeqLSTM(opt.hiddenSize, opt.hiddenSize)
    dec_lstm[i]:maskZero()
    dec:add(dec_lstm[i])
end
dec:add(nn.Sequencer(nn.MaskZero(nn.Linear(opt.hiddenSize, opt.vocabSize), 1)))
dec:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1)))
local criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(), 1))
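
For reference, the non-Serial workaround mentioned in the update could look roughly like the sketch below. It is only an illustration, not the code from the actual scripts: `saveModel` and `loadModel` are hypothetical helper names, and calling `remember('both')` on each decoder LSTM after loading is an assumption based on the "remember" observation in the update, not a confirmed fix.

```lua
-- Sketch only: assumes torch, nn and rnn are installed, and that enc, dec,
-- enc_lstm and dec_lstm are built as in the model definition above.
require 'torch'
require 'nn'
require 'rnn'

-- Hypothetical helper: save the plain modules, without nn.Serial wrappers.
local function saveModel(path, enc, dec, enc_lstm, dec_lstm)
    local file = {
        enc = enc,
        dec = dec,
        enc_lstm = enc_lstm,
        dec_lstm = dec_lstm,
    }
    torch.save(path, file)
end

-- Hypothetical helper: load the model in the sampling script.
local function loadModel(path)
    local file = torch.load(path)
    -- switch to evaluation mode before sampling
    file.enc:evaluate()
    file.dec:evaluate()
    -- Assumption: turn "remember" back on explicitly for the decoder LSTMs,
    -- since the update suggests this attribute is what goes wrong.
    for _, lstm in ipairs(file.dec_lstm) do
        lstm:remember('both')
    end
    return file
end
```

In the training script this would be called as `saveModel('model.t7', enc, dec, enc_lstm, dec_lstm)`, and in the sampling script as `local file = loadModel('model.t7')`. Comparing samples from this path against the Serial path on the same overfitting example should show whether Serial is really the cause.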