Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Unsuccessful training on multivariate timeseries toy dataset #113

Closed zencoding closed 8 years ago

zencoding commented 8 years ago

This is really not an issue, so please feel free to close it; I will try asking the question on Google Groups or Stack Overflow.

I have been playing with LSTM for multivariate time series, and it is failing to learn a very simple toy example. My dataset is an alternating time series: in odd rows the third column is the sum of the first two columns, and in even rows it is their product.

Sample input data

       1        2        3
       2        3        6
       3        4        7
       4        5       20
       5        6       11
       6        7       42
       7        8       15
       8        9       72
       9       10       19
      10       11      110
      11       12       23

I modified recurrent-time-series.lua from rnn/examples to change the sequence dataset, leaving all other code as is. The default dataset works fine, but my dataset produces very high error rates. I tried all the recommended settings (changing the number of hidden units, the learning rate, etc.), and I even tried cloning the output of the forward pass as mentioned in another forum post. Nothing fixes the error rates. Below is the code; it is identical to the standard example except for the sequence dataset. Any help would be greatly appreciated.

-- Multi-variate time-series example

require 'rnn'

cmd = torch.CmdLine()
cmd:text()
cmd:text('Train a multivariate time-series model using RNN')
cmd:option('--rho', 5, 'maximum number of time steps for back-propagate through time (BPTT)')
cmd:option('--multiSize', 3, 'number of random variables as input and output')
cmd:option('--hiddenSize', 100, 'number of hidden units used at output of the recurrent layer')
cmd:option('--dataSize', 1000, 'total number of time-steps in dataset')
cmd:option('--batchSize', 50, 'number of training samples per batch')
cmd:option('--nIterations', 1000, 'max number of training iterations')
cmd:option('--learningRate', 0.0001, 'learning rate')
cmd:text()
local opt = cmd:parse(arg or {})

-- For simplicity, the multi-variate dataset in this example is independently distributed.
-- Toy dataset (task is to predict next vector, given previous vectors) following the normal distribution.
-- Generated by sampling a separate normal distribution for each random variable.
-- note: vX is used as both input X and output Y to save memory
local function evalPDF(vMean, vSigma, vX)
   for i=1,vMean:size(1) do
      local b = (vX[i]-vMean[i])/vSigma[i]
      vX[i] = math.exp(-b*b/2)/(vSigma[i]*math.sqrt(2*math.pi))
   end
   return vX
end

assert(opt.multiSize > 1, "Multi-variate time-series")

-- vBias = torch.randn(opt.multiSize)
-- vMean = torch.Tensor(opt.multiSize):fill(5)
-- vSigma = torch.linspace(1,opt.multiSize,opt.multiSize)
-- sequence = torch.Tensor(opt.dataSize, opt.multiSize)
--
-- j = 0
-- for i=1,opt.dataSize do
--   sequence[{i,{}}]:fill(j)
--   evalPDF(vMean, vSigma, sequence[{i,{}}])
--   sequence[{i,{}}]:add(vBias)
--   j = j + 1
--   if j>10 then j = 0 end
-- end
sequence = torch.Tensor(opt.dataSize, opt.multiSize)
for j = 1, opt.dataSize do
  sequence[j][1] = j
  sequence[j][2] = j + 1
  if j % 2 > 0 then
    -- odd rows: third column is the sum of the first two
    sequence[j][3] = sequence[j][1] + sequence[j][2]
  else
    -- even rows: third column is the product of the first two
    sequence[j][3] = sequence[j][1] * sequence[j][2]
  end
end
print('Sequence:'); print(sequence)

-- batch mode

offsets = torch.LongTensor(opt.batchSize):random(1,opt.dataSize)

-- RNN
r = nn.Recurrent(
   opt.hiddenSize, -- size of output
   nn.Linear(opt.multiSize, opt.hiddenSize), -- input layer
   nn.Linear(opt.hiddenSize, opt.hiddenSize), -- recurrent layer
   nn.Sigmoid(), -- transfer function
   opt.rho
)

rnn = nn.Sequential()
   :add(r)
   :add(nn.Linear(opt.hiddenSize, opt.multiSize))

criterion = nn.MSECriterion()

-- use Sequencer for better data handling
rnn = nn.Sequencer(rnn)

criterion = nn.SequencerCriterion(criterion)
print("Model :")
print(rnn)

-- train rnn model
minErr = opt.multiSize -- report min error
minK = 0
avgErrs = torch.Tensor(opt.nIterations):fill(0)
for k = 1, opt.nIterations do

   -- 1. create a sequence of rho time-steps

   local inputs, targets = {}, {}
   for step = 1, opt.rho do
      -- batch of inputs
      inputs[step] = inputs[step] or sequence.new()
      inputs[step]:index(sequence, 1, offsets)
      -- batch of targets
      offsets:add(1) -- increase indices by 1
      offsets[offsets:gt(opt.dataSize)] = 1
      targets[step] = targets[step] or sequence.new()
      targets[step]:index(sequence, 1, offsets)
   end

   -- 2. forward sequence through rnn

   local outputs = rnn:forward(inputs)
   local err = criterion:forward(outputs, targets)

   -- report errors

   print('Iter: ' .. k .. '   Err: ' .. err)
   avgErrs[k] = err
   if avgErrs[k] < minErr then
      minErr = avgErrs[k]
      minK = k
   end

   -- 3. backward sequence through rnn (i.e. backprop through time)

   rnn:zeroGradParameters()

   local gradOutputs = criterion:backward(outputs, targets)
   local gradInputs = rnn:backward(inputs, gradOutputs)

   -- 4. update parameters

   rnn:updateParameters(opt.learningRate)
end

print('min err: ' .. minErr .. ' on iteration ' .. minK)
rracinskij commented 8 years ago

I'm testing this example with similar toy data; maybe some of my first findings will be useful for you:

zencoding commented 8 years ago

@rracinskij thanks for the suggestions; unfortunately I have already tried all of them and nothing seems to work. I have been reading posts on other blogs about the intricacies of LSTMs and trying out different implementations, but nothing works for my simple toy example. I am sure I am missing something very basic, I just don't know what :)

Things I have already tried:

  1. Training with the input containing only the first two numbers and the target as the third.
  2. Training with batches (using Sequencer) and without. Interestingly, the batch method produces the same forward-pass output for every input[step]; I am debugging to find out why, it looks like some reference issue (see the sketch after this list).
  3. Different optimizers and hyperparameters (rho, hidden size, learning rate, momentum, etc.).
  4. Changing the number of epochs.
  5. Training with only addition (instead of alternating multiplication and addition) produces a lower error, but it is still very high. The model just doesn't seem to be learning.

I have been at this for a few days already, so I am going to try building an LSTM from scratch, maybe using char-rnn as a baseline.
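
A minimal sketch of the cloning workaround mentioned in point 2, assuming rnn here is the un-wrapped nn.Sequential (i.e. the model is stepped manually, without nn.Sequencer). If the recurrent module reuses its output tensor between steps, every table entry ends up referencing the same storage and all steps appear to hold the last step's value; cloning keeps an independent copy per step.

-- sketch: clone each step's output when stepping the model manually
local outputs = {}
for step = 1, opt.rho do
   outputs[step] = rnn:forward(inputs[step]):clone()
end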

rracinskij commented 8 years ago

Here is a quick modification of your example: https://github.com/rracinskij/rnntest01/blob/master/rnn-multiple-time-series01.lua It seems to be learning, although there is still much room for improvement. Please note that the target is selected from the sequence by increasing the offset, so the target for row i is row i+1. I'll be able to get deeper into the issue no earlier than tomorrow.

zencoding commented 8 years ago

Thanks for the quick modification; I tested it and it does indeed seem to be learning. I reviewed your code: the changes are rescaling the input values to be between 0 and 1, making the offsets linear (instead of random), switching to an LSTM, and changing the dimensions of the inputs/outputs. If these changes make the model learn, then for recurrent models the way the inputs are provided and their range are very important.

It seems counterintuitive, since in most machine learning setups you typically randomize the inputs, but it looks like recurrent nets require sequential inputs. Also, I am not sure the errors are any better than before; the variance has been reduced, but it is still significant considering the scale of the input.

I will continue to experiment more with it and wait for any other suggestions you have.

rracinskij commented 8 years ago

The offsets were changed from random to linear just for convenience; it shouldn't affect performance. The inputs were rescaled to between 0 and 1 to avoid possible saturation problems. An important detail is that the target is selected from the next step of the sequence. During similar experiments with a univariate model I noticed that using a sequence of previous values instead of a scalar as the input at every step remarkably increases the prediction accuracy, but that would require a 2D model for the multivariate case.
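
A minimal sketch of the kind of rescaling discussed above (my own assumption, not the exact code from the linked script): min-max scale each column of the sequence tensor into [0, 1] before training, to keep the sigmoid/tanh units away from saturation.

-- rescale each column of sequence to [0, 1] (assumes no column is constant)
local function rescale(sequence)
   local scaled = sequence:clone()
   for col = 1, scaled:size(2) do
      local c = scaled:select(2, col)   -- view of one column
      local cMin, cMax = c:min(), c:max()
      c:add(-cMin):div(cMax - cMin)     -- shift and scale in place
   end
   return scaled
end

local scaledSequence = rescale(sequence)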

nicholas-leonard commented 8 years ago

@rracinskij Could you submit a PR adding this script as an example to the examples folder?

rracinskij commented 8 years ago

@nicholas-leonard I submitted a PR and tried to make my initial script a bit cleaner. I have also noticed that training fails when using sigmoid as the activation function; I have no idea why.

rracinskij commented 8 years ago

I spent some time testing the rnn time-series example with a simple toy dataset (https://github.com/rracinskij/rnntest01/blob/master/rnn-multiple-time-series-tests-description.md).

The results are rather disappointing, as it is not easy to train even the simplest 0.01, 0.02, 0.03 sequence. I couldn't get any meaningful results with a Fibonacci-style sequence (0.01, 0.01, 0.02, 0.03, 0.05, ...). Surprisingly, a linear combination of both of the above sequences together with their values at previous steps trains quite easily.

So at the moment it looks like the model setup should be different in some way. The model is also very sensitive to the number of hidden units.

nicholas-leonard commented 8 years ago

@rracinskij try:

r = nn.Recurrent(
   opt.hiddenSize,
   nn.Linear(opt.inputLast-opt.inputFirst+1, opt.hiddenSize), -- input layer
   nn.Linear(opt.hiddenSize, opt.hiddenSize), -- recurrent layer
   nn.Sigmoid(), -- transfer function
   opt.rho
)
rnn1 = nn.Sequential()
   :add(r)
   :add(nn.Linear(opt.hiddenSize, opt.outputLast-opt.outputFirst+1)) 
   :add(nn.Sigmoid()) -- if the target is also between 0 and 1

rnn2 = nn.Sequential()
    :add(nn.LSTM(opt.inputLast-opt.inputFirst+1, opt.hiddenSize))
    :add(nn.Linear(opt.hiddenSize, opt.outputLast-opt.outputFirst+1)) 
    :add(nn.Sigmoid()) -- if the target is also between 0 and 1
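
One possible way to plug either suggestion into the original training loop (my assumption, mirroring the Sequencer setup used earlier in the thread): wrap the model so it consumes a table of time steps and train against the rescaled targets with MSE.

rnn = nn.Sequencer(rnn2)  -- or nn.Sequencer(rnn1)
criterion = nn.SequencerCriterion(nn.MSECriterion())
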
rracinskij commented 8 years ago

@nicholas-leonard using nn.Sigmoid() instead of nn.Tanh() gives outputs that look like averages of the targets (or inputs?). For example, with 0.31, 0.32, ..., 0.38 as targets, Sigmoid gives 0.1860, 0.1863, 0.1865, 0.1868, 0.1870, 0.1872, 0.1875, 0.1877 as outputs, while switching to Tanh() gives 0.3118, 0.3159, 0.3199, 0.3239, 0.3279, 0.3318, 0.3358, 0.3397. That applies to both the linear and LSTM/GRU models.

And both models still can't crack the Fibonacci-style 0.01, 0.01, 0.02, 0.03, 0.05, 0.08 sequence.

NB: according to Yoshua Bengio (https://www.quora.com/How-can-one-apply-deep-learning-to-time-series-forecasting), RNN is the best tool for time series prediction.

nicholas-leonard commented 8 years ago

@rracinskij Another thing you could try is to quantize the output and target into classes. So for Fibonacci you could have classes 1, 2, ..., 100, each representing the integers 1, 2, ..., 100. Then you can use a LookupTable and SoftMax and train it just like a language model.
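
A rough sketch of the quantization idea for the rescaled sequences (the helper names and binning convention are my own assumption): map a value in [0, 1] to one of nIndex classes, and map a predicted class back to the centre of its bin.

local nIndex = 100

-- class 1 covers [0, 1/nIndex), class nIndex covers [1 - 1/nIndex, 1]
local function quantize(x)
   return math.max(1, math.min(nIndex, math.floor(x * nIndex) + 1))
end

-- map a class index back to the centre of its bin
local function dequantize(class)
   return (class - 0.5) / nIndex
end

print(quantize(0.034))  -- 4
print(dequantize(4))    -- 0.035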

In that same line of reasoning, to allow for greater numbers, you could create a binary representation of the inputs and targets. So for 8 bits that is 16 classes (the first bit is 0 or 1, the second bit is 0 or 1, and so on). Throw a stacked LSTM at it and see what happens :)
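
And a rough sketch of the binary-representation idea (assuming 8 bits per number and one class per bit value, so a LookupTable can embed each bit position): turn an integer into a sequence of bit classes, with class 1 meaning the bit is off and class 2 meaning it is on.

local nBits = 8

local function toBitClasses(n)
   local bits = torch.LongTensor(nBits)
   for i = 1, nBits do
      bits[i] = n % 2 + 1        -- class 1 = bit 0, class 2 = bit 1
      n = math.floor(n / 2)
   end
   return bits
end

print(toBitClasses(13))  -- 13 = 00001101 in binary, least significant bit first: 2 1 2 2 1 1 1 1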

rracinskij commented 8 years ago

@nicholas-leonard Switching to quantized targets and LookupTable + SoftMax allows even the simple recurrent model from the examples folder to learn the Fibonacci sequence easily. So I'm wondering why it performs so much better than continuous models?

And it is still unclear why nn.Sigmoid() is much worse than nn.Tanh() in the original setup.

nicholas-leonard commented 8 years ago

Why is discrete space so much easier to learn than continuous space?

The way I see it is as follows. If you divide the space between 0 and 1 into 100 discrete ranges, where each range gets its own output/input unit and target class, then you are effectively assigning 100 neurons, each with its own input/output weights. On the other hand, for continuous space, the space between 0 and 1 only has one neuron and its associated input/output weights, so far fewer parameters. It also means that for continuous space the hidden neurons/parameters have a more difficult task.

Why is Sigmoid worse than Tanh? I don't have a good idea why.

rracinskij commented 8 years ago

Your approach is really smart. Could you please explain the role of the LookupTable in this setup? Would it work with multivariate time series?

nicholas-leonard commented 8 years ago

The LookupTable is the input layer. It basically maps each quantized value to an embedding vector; the embeddings are learned. Yes, it could work with multivariate time series.

rracinskij commented 8 years ago

Adapting the model from the simple recurrent example to the quantized multivariate time-series case is a bit trickier (if possible at all) than it might seem at first:

local r = nn.Recurrent(
   hiddenSize,
   nn.LookupTable(nIndex, hiddenSize),
   nn.Linear(hiddenSize, hiddenSize),
   nn.Tanh(), 
   rho
)
local rnn = nn.Sequential()
   :add(r)    
   :add(nn.Linear(hiddenSize, nIndex))
   :add(nn.LogSoftMax())

  1. It is unclear how to use a LookupTable for more than one input variable.
  2. There might be more than one output to be classified.

nicholas-leonard commented 8 years ago

@rracinskij Okay, so let's say you have inputSize = 4 variables that you want to quantize. You can either use the same LookupTable for all of them or a different one for each. Your input would then be of size batchSize x inputSize. The lookup table can handle this:

batchSize = 2
inputSize = 4
nIndex = 10
hiddenSize = 3
input = torch.LongTensor(batchSize, inputSize):random(1,nIndex)
lookup = nn.LookupTable(nIndex, hiddenSize)
output = lookup:forward(input)
print(output)

(1,.,.) = 
  0.2128  1.0470  1.4555
 -0.2879  1.8431 -0.4295
 -0.3924  2.5976 -0.9050
 -0.3217 -0.5281 -0.9261

(2,.,.) = 
 -0.5678 -0.2925  0.7227
  0.2128  1.0470  1.4555
 -0.3217 -0.5281 -0.9261
 -0.3217 -0.5281 -0.9261
[torch.DoubleTensor of size 2x4x3]

So the output of this has size batchSize x inputSize x hiddenSize. To feed this into a Linear layer, you need to collapse the last two dimensions:

col = nn.Collapse(2)
out2 = col:forward(output)
print(out2)
Columns 1 to 10
 0.2128  1.0470  1.4555 -0.2879  1.8431 -0.4295 -0.3924  2.5976 -0.9050 -0.3217
-0.5678 -0.2925  0.7227  0.2128  1.0470  1.4555 -0.3217 -0.5281 -0.9261 -0.3217

Columns 11 to 12
-0.5281 -0.9261
-0.5281 -0.9261
[torch.DoubleTensor of size 2x12]

For the outputs, you will need a multi-class classification criterion. You can build one by combining ParallelCriterion with ClassNLLCriterions.
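
A sketch of what such a multi-class criterion could look like (the head structure, variable names, and sizes here are my own assumption, not code from this thread): one Linear + LogSoftMax classifier per quantized output variable via ConcatTable, paired with a ParallelCriterion wrapping one ClassNLLCriterion per variable.

require 'nn'

local nVar, nIndex, hiddenSize, batchSize = 4, 10, 3, 2

-- one classifier head per output variable
local heads = nn.ConcatTable()
for v = 1, nVar do
   heads:add(nn.Sequential()
      :add(nn.Linear(hiddenSize, nIndex))
      :add(nn.LogSoftMax()))
end

-- one ClassNLLCriterion per head
local criterion = nn.ParallelCriterion()
for v = 1, nVar do
   criterion:add(nn.ClassNLLCriterion())
end

-- usage: outputs is a table of nVar tensors (batchSize x nIndex log-probabilities),
-- targets is a table of nVar tensors of class indices (batchSize each)
local hidden = torch.randn(batchSize, hiddenSize)
local outputs = heads:forward(hidden)
local targets = {}
for v = 1, nVar do
   targets[v] = torch.LongTensor(batchSize):random(1, nIndex)
end
print(criterion:forward(outputs, targets))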

rracinskij commented 8 years ago

@nicholas-leonard Thanks a lot!