Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Sampling from RNN #304

Open: yashmaverick opened this issue 8 years ago

yashmaverick commented 8 years ago

@nicholas-leonard I am trying to predict the next number in a sequence, where each sequence has length, say, 5.

For example: the input is {1,2,3,4,5} and the target is {2,3,4,5,6}.

The training set has 1000 such sequences and the validation set has 100 sequences.
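A minimal plain-Lua sketch of how such input/target pairs could be built, assuming (as in the example) that each sequence is a run of consecutive integers. `make_pair`, `make_set` and the random start range 1..100 are hypothetical; in Torch each number would additionally be wrapped in a 1x1 tensor (batchSize x inputSize) before being fed to the Sequencer.

```lua
-- Hypothetical helper: one next-step-prediction pair. The target is the
-- input shifted one step ahead in time.
local function make_pair(start, seqLen)
  local input, target = {}, {}
  for t = 1, seqLen do
    input[t] = start + t - 1   -- t-th element of the sequence
    target[t] = start + t      -- the element that follows it
  end
  return input, target
end

-- Hypothetical dataset builder: n sequences with random integer start values.
local function make_set(n, seqLen)
  local set = {}
  for i = 1, n do
    local input, target = make_pair(math.random(1, 100), seqLen)
    set[i] = {input = input, target = target}
  end
  return set
end

local train = make_set(1000, 5)   -- 1000 training sequences
local valid = make_set(100, 5)    -- 100 validation sequences

local input, target = make_pair(1, 5)
print(table.concat(input, ","))    -- 1,2,3,4,5
print(table.concat(target, ","))   -- 2,3,4,5,6
```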

The model is shown below:

require 'rnn'

SeqLen = 5
rho = 5            -- number of BPTT steps
batchSize = 1
hiddenSize = 20
inputSize = 1
outputSize = 1
no_sampling = 10   -- number of sampling trials

model = nn.Sequential()
   :add(nn.Sequencer(nn.FastLSTM(inputSize, hiddenSize)))
   :add(nn.Sequencer(nn.Linear(hiddenSize, outputSize)))
   :add(nn.Sequencer(nn.ReLU()))

During inference, how do I sample from the model?

I wish to sample from the model 10 times (say 10 trials).

When sampling the first time, the inputs are {t1,t2,t3,t4,t5}, the true output is {t2,t3,t4,t5,t6}, and suppose the model predicts {t2',t3',t4',t5',t6'}.

The next time I sample, what should my inputs be?

Case 1: {t2,t3,t4,t5,t6'}, or

Case 2: {t2',t3',t4',t5',t6'}, or

Case 3: only {t6'}. If I keep sampling like this indefinitely, is there a chance that the predictions after rho trials (here, the 5th trial) are the same?

But in either Case 1 or Case 2, when sampling for the 5th time my inputs will consist entirely of predictions, i.e. {t6',t7',t8',t9',t10'}. The issue is only with the first four sampling trials.
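A plain-Lua sketch of the sliding-window scheme in Case 1, assuming a hypothetical `predict_last` function that stands in for running the trained model over the window and reading off its last output (here a toy "add one" rule, not the trained network). Each trial drops the oldest value and appends the newest prediction; the sketch counts how many window entries are predictions after each trial, showing that once `rho` = 5 predictions have been made the window holds no true values at all.

```lua
-- Toy stand-in for the trained model (hypothetical): predicts the value that
-- follows the last element of the window. A real model would run the
-- Sequencer over the window and read off its last output.
local function predict_last(values)
  return values[#values] + 1
end

-- Case 1: slide a fixed-length window, appending each new prediction and
-- dropping the oldest entry. Each entry records whether the model produced it.
local function sample_sliding(seed, trials)
  local window = {}
  for i, v in ipairs(seed) do
    window[i] = {value = v, predicted = false}
  end
  local counts = {}   -- counts[k]: predicted entries in the window after trial k
  for trial = 1, trials do
    local values = {}
    for i, e in ipairs(window) do values[i] = e.value end
    local pred = predict_last(values)
    table.remove(window, 1)                                 -- drop the oldest entry
    table.insert(window, {value = pred, predicted = true})  -- append the prediction
    local n = 0
    for _, e in ipairs(window) do
      if e.predicted then n = n + 1 end
    end
    counts[trial] = n
  end
  return window, counts
end

local window, counts = sample_sliding({1, 2, 3, 4, 5}, 6)
print(table.concat(counts, " "))   -- 1 2 3 4 5 5
```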

Also, would it be better to treat this problem as 'Sequence to One' prediction, where during training the inputs are {1,2,3,4,5} and the target is {6}?

nicholas-leonard commented 8 years ago

@yashmaverick If you want to sample from the model, you should feed inputs one at a time. So suppose I want to condition the model on t1 and generate a sequence of n samples; I can do something like:

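rnn:evaluate()
local input = t1   -- conditioning input, e.g. a batchSize x inputSize tensor
local samples = {}
for i=1,n do
   -- feed a one-step sequence; the Sequencer takes a table of time-steps
   local output = rnn:forward( {input} )
   -- clone, since the module may reuse its output tensor on the next call
   table.insert(samples, output[1]:clone())
   input = output[1]
end
print("generated sequence: ", samples)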

Does this make sense to you?
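Torch `nn` modules write their result into a `self.output` buffer that is typically reused across calls to `forward`, so appending `output[1]` to a list without copying can leave every entry referring to the same tensor. Whether that happens in a sampling loop like the one above depends on how the Sequencer manages its per-step clones, so copying is best seen as a precaution. The plain-Lua toy below (a hypothetical `ToyModule`, with a Lua table standing in for a tensor) shows the effect, and why storing a copy of each sample (in Torch, `Tensor:clone()`) avoids it.

```lua
-- Hypothetical toy module: like an nn module, it writes its result into a
-- single output buffer that is reused on every call to forward.
local ToyModule = {}
ToyModule.__index = ToyModule

function ToyModule.new()
  return setmetatable({output = {0}}, ToyModule)
end

function ToyModule:forward(input)
  self.output[1] = input[1] + 1   -- overwrite the shared buffer in place
  return self.output
end

local function copy(t)
  local c = {}
  for i, v in ipairs(t) do c[i] = v end
  return c
end

-- Condition on t1, then feed each output back in as the next input.
local function generate(model, t1, n, store_copy)
  local input, samples = t1, {}
  for _ = 1, n do
    local output = model:forward(input)
    table.insert(samples, store_copy and copy(output) or output)
    input = output
  end
  return samples
end

-- Without copying, all three entries are the same buffer (holding 4).
local aliased = generate(ToyModule.new(), {1}, 3, false)
-- With copying, the entries hold 2, 3 and 4.
local copied = generate(ToyModule.new(), {1}, 3, true)
```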

"Also, would it be better to treat this problem as 'Sequence to One' prediction, where during training the inputs are {1,2,3,4,5} and the target is {6}?"

No, I think for training it is best to have, as you say, inputs = {t1, t2, ...} and targets = {t2, t3, ...}. That way you get gradients from the output at each time-step.
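A small plain-Lua illustration of the difference, using squared error on hypothetical predictions: with per-step targets every time step contributes its own error term, while a sequence-to-one loss only sees the last step, so a mistake at step 1 produces no error signal at all. As far as we understand the rnn library, the per-step version corresponds to wrapping a criterion in `nn.SequencerCriterion` (e.g. `nn.SequencerCriterion(nn.MSECriterion())`), which applies it at each time step and sums the results; the functions below are only a toy model of that idea.

```lua
-- Sequence-to-sequence: squared error summed over every time step, so the
-- output at each step receives its own error signal (and gradient).
local function per_step_loss(pred, target)
  local loss = 0
  for t = 1, #target do
    local d = pred[t] - target[t]
    loss = loss + d * d
  end
  return loss
end

-- Sequence-to-one: only the final prediction is compared with the target.
local function last_step_loss(pred, target)
  local d = pred[#pred] - target
  return d * d
end

-- Hypothetical predictions for input {1,2,3,4,5}: only the first step is wrong.
local pred = {2.5, 3, 4, 5, 6}
print(per_step_loss(pred, {2, 3, 4, 5, 6}))  -- 0.25
print(last_step_loss(pred, 6))               -- 0
```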

yashmaverick commented 8 years ago

@nicholas-leonard Thanks! In the sample code above, can I use rnn:remember('eval') in the for loop, i.e. while sampling? The outputs with and without rnn:remember('eval') are different.

nicholas-leonard commented 8 years ago

@yashmaverick Yes, use rnn:remember('eval')!
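The different outputs reported with and without `remember('eval')` fit the picture of an RNN whose hidden state is cleared between single-step calls unless it is told to remember. A plain-Lua toy of that idea, assuming a hypothetical `ToyCell` with a one-number state update `h = 0.5*h + x` in place of an LSTM (it is not the rnn library's implementation): when the state is carried, feeding the same input three times gives three different outputs; when it is cleared before each call, identical inputs always give identical outputs.

```lua
-- Hypothetical toy recurrent cell with state update h_t = 0.5 * h_{t-1} + x_t.
-- It stands in for an LSTM only to show what carrying state between calls means.
local ToyCell = {}
ToyCell.__index = ToyCell

function ToyCell.new()
  return setmetatable({h = 0}, ToyCell)
end

function ToyCell:reset()   -- clear the hidden state (similar in spirit to forget())
  self.h = 0
end

function ToyCell:step(x)
  self.h = 0.5 * self.h + x
  return self.h
end

-- Feed inputs one step per call. With carry = false the state is cleared
-- before every call, so each output depends only on the current input.
local function run(inputs, carry)
  local cell, outputs = ToyCell.new(), {}
  for i, x in ipairs(inputs) do
    if not carry then cell:reset() end
    outputs[i] = cell:step(x)
  end
  return outputs
end

local kept = run({1, 1, 1}, true)      -- 1, 1.5, 1.75
local cleared = run({1, 1, 1}, false)  -- 1, 1, 1
```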

hashbangCoder commented 8 years ago

@nicholas-leonard @yashmaverick Should rnn:remember('eval') be used inside the loop? In #247 it is called only once, outside the loop, and that works fine. When I call it inside the loop, the RNN samples the same word every time; calling it outside the loop gives a different sample at each time step.

hashbangCoder commented 8 years ago

@nicholas-leonard I'm seeing some interesting interactions between remember() and forget(). I'm building a captioning model, and during training I sample the RNN on a test image every 1000 or so iterations. Here is a high-level overview of my code:

-- <training code above>
if iter % 1000 == 0 then
    evaluate_rnn(rnn, test_image)
end

-- [test code snippet]
function evaluate_rnn(rnn, image)
    rnn:evaluate()        -- for the dropout modules
    rnn:forget()
    rnn:remember('eval')
    for i = 1, maxSample do
        -- <sample an output and feed it back>
    end
    rnn:forget()
end

If I leave out the first forget() call, my network samples the same word each time. If I move remember('eval') inside the loop, it again samples the same word. Running it as shown above (mostly) produces a different sample at each time step, although I haven't trained it long enough to see much variation.

Could you tell me whether the remember call should be placed inside or outside the loop? Am I missing something else? Thanks!