Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Sampling from Sequence to Sequence Network. #258

Open khushigupta opened 8 years ago

khushigupta commented 8 years ago

I was trying to sample decoded responses from the sequence-to-sequence network (https://github.com/Element-Research/rnn/blob/master/examples/encoder-decoder-coupling-seqLSTM.lua).

Can you tell me if the code is correct?

```lua
require 'rnn'
require 'hdf5'
require 'Dataloader'
local checkpoint = require 'checkpoints'

local dl = dataloader()
local opt = {}
opt.learningRate = 0.1
opt.hiddenSize = 100
opt.vocabSize = 15307
opt.seqLen = 15
opt.niter = 1000
opt.maxSampleLength = 20

-- Copy the encoder's last output and cell state into the decoder LSTM
local function forwardConnect(encLSTM, decLSTM)
   decLSTM.userPrevOutput = nn.rnn.recursiveCopy(decLSTM.userPrevOutput, encLSTM.outputs[opt.seqLen])
   decLSTM.userPrevCell = nn.rnn.recursiveCopy(decLSTM.userPrevCell, encLSTM.cells[opt.seqLen])
end

-- Copy the decoder's gradients back into the encoder LSTM (only needed during training)
local function backwardConnect(encLSTM, decLSTM)
   encLSTM.userNextGradCell = nn.rnn.recursiveCopy(encLSTM.userNextGradCell, decLSTM.userGradPrevCell)
   encLSTM.gradPrevOutput = nn.rnn.recursiveCopy(encLSTM.gradPrevOutput, decLSTM.userGradPrevOutput)
end

-- Load the trained encoder/decoder from a checkpoint
local model, optimState, epoch = checkpoint.load('EncDec')

local enc = model.enc
local dec = model.dec
local encLSTM = model.encLSTM
local decLSTM = model.decLSTM

local encInSeq, decInSeq, decOutSeq = dl:nextBatch(dl, 'test')
-- The decoder starts from a single step of zeros
decInSeq = torch.zeros(encInSeq:size(1), 1)
local decOut = torch.zeros(encInSeq:size(1), 1)

local finalOutput = torch.Tensor()
local buffer = torch.DoubleTensor()

-- Switch both networks to evaluation mode (e.g. disables dropout)
enc:evaluate()
dec:evaluate()

-- Encode the input, hand the final encoder state to the decoder, run the first step
local encOut = enc:forward(encInSeq)
forwardConnect(encLSTM, decLSTM)
decOut = dec:forward(decInSeq)

-- Greedy decoding: feed the most likely token back in as the next input
for i=1, opt.maxSampleLength do
    buffer:resize(decOut[1]:size()):copy(decOut[1])
    buffer:exp()
    local val, idx = torch.max(buffer, 2)
    decInSeq:resize(idx:size()):copy(idx)
    decOut = dec:forward(decInSeq)
    finalOutput.cat(decOut[1])
end
```
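
The torch.max(buffer, 2) call in the loop takes the maximum along dimension 2 (the vocabulary), returning one value and one index per batch row. Below is a minimal plain-Lua sketch of that greedy choice. It assumes the decoder's output for one time step is a batchSize x vocabSize table of log-probabilities (the numbers are made up), and it also shows that the preceding buffer:exp() does not change which index wins, because exp is monotonic.

```lua
-- Greedy (argmax) choice per batch row, mimicking what
-- torch.max(buffer, 2) returns: the max value and its 1-based index.
-- Assumption: `scores` is a batchSize x vocabSize nested table of
-- log-probabilities (toy values, not real model output).

local function rowArgmax(scores)
  local vals, idxs = {}, {}
  for b, row in ipairs(scores) do
    local best, bestIdx = row[1], 1
    for v = 2, #row do
      if row[v] > best then
        best, bestIdx = row[v], v
      end
    end
    vals[b], idxs[b] = best, bestIdx
  end
  return vals, idxs
end

-- Apply exp element-wise (like buffer:exp()), returning a new table
local function expAll(scores)
  local out = {}
  for b, row in ipairs(scores) do
    out[b] = {}
    for v, x in ipairs(row) do
      out[b][v] = math.exp(x)
    end
  end
  return out
end

local logProbs = {
  {math.log(0.1), math.log(0.7), math.log(0.2)}, -- batch row 1
  {math.log(0.5), math.log(0.3), math.log(0.2)}, -- batch row 2
}

local _, idxFromLog = rowArgmax(logProbs)
local _, idxFromProb = rowArgmax(expAll(logProbs))
print(idxFromLog[1], idxFromLog[2])   -- 2  1
print(idxFromProb[1], idxFromProb[2]) -- 2  1
```

Because exp preserves ordering, the exponentiation only matters if you go on to sample with torch.multinomial, which expects non-negative weights rather than log-probabilities.
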
nicholas-leonard commented 8 years ago

@khushigupta You can use torch.max, or you can use torch.multinomial as in the evaluate-rnnlm.lua script.
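
torch.multinomial picks an index with probability proportional to the non-negative weight at each position, which is why log-probabilities have to be exponentiated first. The sketch below reproduces a single draw in plain Lua with an inverse-CDF scan. sampleIndex is a hypothetical helper for illustration, not part of torch or rnn, and the uniform number can be passed in explicitly so the behavior is easy to check.

```lua
-- Draw one index from a row of log-probabilities, conceptually like
-- torch.multinomial(torch.exp(row), 1): exponentiate, normalize,
-- then walk the cumulative sum until it passes a uniform draw.
-- `sampleIndex` is a hypothetical helper for illustration only.

local function sampleIndex(logRow, u)
  -- u is a uniform number in [0, 1); default to math.random()
  u = u or math.random()
  local weights, total = {}, 0
  for v, x in ipairs(logRow) do
    weights[v] = math.exp(x)
    total = total + weights[v]
  end
  local target, cum = u * total, 0
  for v, w in ipairs(weights) do
    cum = cum + w
    if target < cum then
      return v
    end
  end
  return #weights -- guard against floating-point round-off
end

local row = {math.log(0.1), math.log(0.7), math.log(0.2)}
print(sampleIndex(row, 0.05)) -- falls in the first 10%: 1
print(sampleIndex(row, 0.50)) -- falls in the next 70%: 2
print(sampleIndex(row, 0.95)) -- falls in the last 20%: 3
```

Called without u, the helper uses math.random(), so repeated calls can return different indices.
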

I don't think finalOutput.cat(decOut[1]) will work. Maybe:

```lua
local finalOutput = torch.Tensor(opt.maxSampleLength)
...
for i=1, opt.maxSampleLength do
   ...
   finalOutput[i] = decOut[1]
end
```
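
The idea behind the suggestion is to preallocate a buffer of length opt.maxSampleLength and fill one slot per step instead of calling cat. Since each decoder step produces a distribution over the vocabulary, one common choice is to store the selected token index in that slot. Here is a plain-Lua sketch under loud assumptions: toyStep is a made-up stand-in for dec:forward, decoding is greedy, and there is a single example (batch size 1).

```lua
-- Greedy decoding loop that stores one token index per step in a
-- preallocated array, in the spirit of finalOutput[i] = ... above.
-- Assumption: toyStep is a hypothetical stand-in for dec:forward;
-- it is NOT part of the rnn library.

local maxSampleLength = 5
local vocabSize = 4

-- Toy "decoder": strongly prefers token (prev % vocabSize) + 1
local function toyStep(prevToken)
  local nextFavorite = (prevToken % vocabSize) + 1
  local logProbs = {}
  for v = 1, vocabSize do
    logProbs[v] = (v == nextFavorite) and math.log(0.7) or math.log(0.1)
  end
  return logProbs
end

-- Index of the largest entry (greedy choice, like torch.max)
local function argmax(row)
  local bestIdx = 1
  for v = 2, #row do
    if row[v] > row[bestIdx] then bestIdx = v end
  end
  return bestIdx
end

local finalOutput = {}
local token = 1 -- assumed start token (indices are 1-based in Lua)
for i = 1, maxSampleLength do
  local logProbs = toyStep(token)
  token = argmax(logProbs)  -- pick the most likely next token
  finalOutput[i] = token    -- keep the index, not the whole distribution
end
print(table.concat(finalOutput, " ")) -- 2 3 4 1 2
```
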
khushigupta commented 8 years ago

Thanks. The code worked after your suggestion. I have another question.

  1. Just to confirm: for batched implementations, the tensor sizes are seqLen x batchSize x dim, correct? (http://nbviewer.jupyter.org/github/CS287/Lectures/blob/gh-pages/notebooks/ElementRNNTutorial.ipynb)
  2. Is there an advantage to using torch.multinomial over torch.max?
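
To make question 1 concrete: in a seqLen x batchSize x dim layout, x[t] is the batchSize x dim slice for time step t, and x[t][b] is the dim-sized vector for example b at that step. A small nested-table sketch of the layout, with arbitrary sizes and values that encode their own position:

```lua
-- Build a seqLen x batchSize x dim nested table and show what each
-- level of indexing selects. Sizes and fill values are arbitrary.

local seqLen, batchSize, dim = 3, 2, 4

local x = {}
for t = 1, seqLen do
  x[t] = {}
  for b = 1, batchSize do
    x[t][b] = {}
    for d = 1, dim do
      -- encode (t, b, d) in the value so the position is easy to read
      x[t][b][d] = t * 100 + b * 10 + d
    end
  end
end

local step1 = x[1]          -- batchSize x dim slice for time step 1
print(#step1, #step1[1])    -- 2  4
print(x[2][1][3])           -- time step 2, example 1, feature 3: 213
```
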
nicholas-leonard commented 8 years ago
  1. Yes.
  2. torch.max is deterministic, so the generated sequence will always be the same for a given input. torch.multinomial is stochastic, so the generated sequence can differ from run to run.
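
To illustrate answer 2, the sketch below repeatedly applies a greedy pick and a sampled pick to the same made-up probability row. The greedy result never changes. The sampled result varies from draw to draw, but seeding the generator makes a sampled run reproducible. Plain Lua's math.random stands in for torch's generator here (in Torch you would seed with torch.manualSeed).

```lua
-- Repeated greedy vs. sampled picks from one fixed distribution.
-- `probs` is a made-up example row (already normalized).

local probs = {0.1, 0.7, 0.2}

-- Deterministic: always the index of the largest probability
local function greedy(p)
  local best = 1
  for v = 2, #p do
    if p[v] > p[best] then best = v end
  end
  return best
end

-- Stochastic: index v is chosen with probability p[v]
local function sample(p)
  local u, cum = math.random(), 0
  for v = 1, #p do
    cum = cum + p[v]
    if u < cum then return v end
  end
  return #p -- round-off guard
end

-- Apply a picker n times and join the results into a string
local function run(picker, n)
  local out = {}
  for i = 1, n do out[i] = picker(probs) end
  return table.concat(out, " ")
end

print(run(greedy, 10))  -- always: 2 2 2 2 2 2 2 2 2 2

math.randomseed(42)
local a = run(sample, 10)
math.randomseed(42)
local b = run(sample, 10)
print(a == b)           -- true: same seed, same sampled sequence
```
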