apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

SKIP RNN is incorrect in LSTnet #10805

Open diyang opened 6 years ago

diyang commented 6 years ago

Thanks for your example and source code; it helped me a lot in developing my own LSTnet. However, I have noticed some flaws in your implementation of the Skip RNN.

The following is how you implemented the Skip RNN:

import mxnet as mx

####################
# Skip-RNN Component
####################
# Stack each recurrent cell with a dropout cell after it
stacked_rnn_cells = mx.rnn.SequentialRNNCell()
for i, recurrent_cell in enumerate(skipcells):
    stacked_rnn_cells.add(recurrent_cell)
    stacked_rnn_cells.add(mx.rnn.DropoutCell(dropout))
# Unroll over all q steps; the recurrence here is still step t-1 -> step t
outputs, states = stacked_rnn_cells.unroll(length=q, inputs=cnn_reg_features, merge_outputs=False)

# Take output from cells p steps apart, counting back from the last step
p = int(seasonal_period / time_interval)
output_indices = list(range(0, q, p))
outputs.reverse()
skip_outputs = [outputs[i] for i in output_indices]
skip_rnn_features = mx.sym.concat(*skip_outputs, dim=1)

What I have noticed is that this does not actually create a skip RNN. The stacked cells are still unrolled as an ordinary RNN (each step feeds the next), and the code then merely selects the hidden states for the last hour (the 24th hour) of each day in the week. If the sequence length is 24*7, the skip RNN output here is 7 hidden states, each corresponding to the last hour of one of the 7 days.
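To make the indexing concrete, here is a small pure-Python sketch that replays `outputs.reverse()` followed by `range(0, q, p)` on plain step indices instead of MXNet symbols. The helper name `selected_steps` is hypothetical, and the sketch assumes q = 24*7 hourly steps with p = 24:

```python
def selected_steps(q, p):
    """Return the original time-step indices kept by the Skip-RNN snippet.

    Mirrors `outputs.reverse()` followed by indexing with range(0, q, p).
    """
    steps = list(range(q))  # stand-in for the q per-step outputs
    steps.reverse()
    return [steps[i] for i in range(0, q, p)]

# One week of hourly data with a daily period (q = 24 * 7, p = 24)
print(selected_steps(24 * 7, 24))  # [167, 143, 119, 95, 71, 47, 23]
```

Every selected index satisfies `index % 24 == 23`, i.e. it is the last hour of some day, and the recurrence that produced those states never skipped anything.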

According to the paper, the input variables for each hour of a given day should be paired with the hidden state for the same hour of the previous day. The output of the skip RNN should then be the hidden states for the 24 hours of the last day, i.e. 24 hidden states, one for each hour of that day.
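A minimal NumPy sketch of the intended recurrence, assuming a plain tanh RNN cell as a stand-in for the GRU used in the paper; `skip_rnn`, the weight shapes, and the random data are illustrative only. Each hidden state h_t is computed from h_{t-p} rather than h_{t-1}, and the last p states are returned:

```python
import numpy as np

def skip_rnn(x, p, W_xh, W_hh, b):
    """Skip RNN sketch: h_t = tanh(W_xh @ x_t + W_hh @ h_{t-p} + b).

    x: (T, D) inputs; p: skip length in steps (24 for hourly data, daily period).
    Returns the hidden states of the last p steps, shape (p, H).
    """
    T, H = x.shape[0], W_hh.shape[0]
    h = np.zeros((T, H))
    for t in range(T):
        # State from the same hour one period earlier; zeros during the first period
        h_prev = h[t - p] if t >= p else np.zeros(H)
        h[t] = np.tanh(W_xh @ x[t] + W_hh @ h_prev + b)
    return h[T - p:]

rng = np.random.default_rng(0)
x = rng.normal(size=(24 * 7, 3))
W_xh, W_hh, b = 0.3 * rng.normal(size=(8, 3)), 0.3 * rng.normal(size=(8, 8)), np.zeros(8)
out = skip_rnn(x, 24, W_xh, W_hh, b)
print(out.shape)  # (24, 8)
```

Perturbing `x[0]` changes only `out[0]` (hour 0 of the last day), because that is the only returned state on the chain t = 0, 24, 48, ..., 144.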

mx.rnn.SequentialRNNCell cannot handle this type of recurrence; you might need to write your own implementation to formulate this type of RNN.
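One common workaround, sketched here under stated assumptions (a tanh cell standing in for a GRU; `cell`, `skip_rnn_loop` and `skip_rnn_reshaped` are hypothetical names), is to fold the sequence so that an ordinary step-to-step RNN sees the skip structure: reshape the T steps into T // p periods of p phases, then run the RNN across periods with the p phases treated as a batch. The sketch checks that this matches the direct h_{t-p} loop:

```python
import numpy as np

def cell(x_t, h_prev, W_xh, W_hh, b):
    # Plain tanh RNN cell (stand-in for a GRU); accepts one vector or a batch of rows
    return np.tanh(x_t @ W_xh.T + h_prev @ W_hh.T + b)

def skip_rnn_loop(x, p, W_xh, W_hh, b):
    # Reference form: h_t is computed from h_{t-p}; zero state for the first period
    T, H = x.shape[0], W_hh.shape[0]
    h = np.zeros((T, H))
    for t in range(T):
        h_prev = h[t - p] if t >= p else np.zeros(H)
        h[t] = cell(x[t], h_prev, W_xh, W_hh, b)
    return h[T - p:]

def skip_rnn_reshaped(x, p, W_xh, W_hh, b):
    # Fold (T, D) into (T // p, p, D): seq[k, j] is x[k * p + j], i.e. period k, phase j
    T, D = x.shape
    assert T % p == 0, "this sketch assumes T is a multiple of p"
    seq = x.reshape(T // p, p, D)
    # An ordinary step-to-step RNN over k, with the p phases as the batch
    h = np.zeros((p, W_hh.shape[0]))
    for k in range(T // p):
        h = cell(seq[k], h, W_xh, W_hh, b)
    return h  # row j is the final state for phase j, i.e. h_{T - p + j}

rng = np.random.default_rng(1)
x = rng.normal(size=(24 * 7, 3))
W_xh, W_hh, b = rng.normal(size=(8, 3)), rng.normal(size=(8, 8)), rng.normal(size=8)
print(np.allclose(skip_rnn_loop(x, 24, W_xh, W_hh, b),
                  skip_rnn_reshaped(x, 24, W_xh, W_hh, b)))  # True
```

The skip connection lives in the data layout rather than in the cell, so in principle a stock RNN layer can be reused after an equivalent reshape; the sketch assumes the sequence length is a multiple of p.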

QiXuanWang commented 6 years ago

Interesting, I'm working on LSTNet too. I rewrote the whole thing in Gluon and, yes, the RNNCell could not be used. I actually also tried removing the skip-RNN layer, though neither version works very well for my problem.

diyang commented 6 years ago

@QiXuanWang I have used MXNet R to implement the skip RNN. You can find it in this function: https://github.com/diyang/deeplearning.mxnet/blob/master/LSTnet/src/lstnet_model.R

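I used a queue to hold the hidden states of the last 24 hours. After the first 24 steps, at each step I pop the head of the queue (the state from the same hour of the previous day), feed it into the cell, and then push the newly produced hidden state for the current hour onto the tail of the queue.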

rnn.skip.unroll<-function(data, 
                     num.rnn.layer=1,
                     seq.len,
                     num.hidden,
                     seasonal.period,
                     dropout=0,
                     config="gru")
{
  param.cells <- list()
  last.states <- list()
  for( i in 1:num.rnn.layer){
    if(config == "gru"){
      param.cells[[i]] <- list(gates.i2h.weight = mx.symbol.Variable(paste0("l", i, ".gates.i2h.weight")),
                               gates.i2h.bias = mx.symbol.Variable(paste0("l", i, ".gates.i2h.bias")),
                               gates.h2h.weight = mx.symbol.Variable(paste0("l", i, ".gates.h2h.weight")),
                               gates.h2h.bias = mx.symbol.Variable(paste0("l", i, ".gates.h2h.bias")),

                               trans.i2h.weight = mx.symbol.Variable(paste0("l", i, ".trans.i2h.weight")),
                               trans.i2h.bias = mx.symbol.Variable(paste0("l", i, ".trans.i2h.bias")),
                               trans.h2h.weight = mx.symbol.Variable(paste0("l", i, ".trans.h2h.weight")),
                               trans.h2h.bias = mx.symbol.Variable(paste0("l", i, ".trans.h2h.bias")))
      state <- list(h=mx.symbol.Variable(paste0("l", i, ".gru.init.h")))
    }else{
      param.cells[[i]] <- list(i2h.weight = mx.symbol.Variable(paste0("l", i, ".i2h.weight")),
                               i2h.bias = mx.symbol.Variable(paste0("l", i, ".i2h.bias")),
                               h2h.weight = mx.symbol.Variable(paste0("l", i, ".h2h.weight")),
                               h2h.bias = mx.symbol.Variable(paste0("l", i, ".h2h.bias")))
      state <- list(c=mx.symbol.Variable(paste0("l", i, ".lstm.init.c")),
                    h=mx.symbol.Variable(paste0("l", i, ".lstm.init.h")))
    }
    last.states[[i]] <- state
  }

  data_seq_slice = mx.symbol.SliceChannel(data=data, num_outputs=seq.len, axis=2, squeeze_axis=1)

  last.hidden <- list()
  # FIFO queue of per-layer states from the previous seasonal.period steps
  seasonal.states <- list()
  for (seqidx in 1:seq.len){
    hidden <- data_seq_slice[[seqidx]]
    # stacked RNN layers (GRU or LSTM)
    if(seqidx <= seasonal.period){
      # first period: ordinary step-to-step recurrence; states are queued
      for (i in 1:num.rnn.layer){
        # no dropout on the first layer; use a separate variable so the
        # dropout argument is not overwritten for later layers and steps
        layer.dropout <- ifelse(i==1, 0, dropout)
        prev.state <- last.states[[i]]

        if(config == "gru"){
          next.state <- gru.cell(num.hidden,
                                 indata = hidden,
                                 prev.state = prev.state,
                                 param = param.cells[[i]],
                                 seqidx = seqidx,
                                 layeridx = i,
                                 dropout = layer.dropout)
        }else{
          next.state <- lstm.cell(num.hidden,
                                  indata = hidden,
                                  prev.state = prev.state,
                                  param = param.cells[[i]],
                                  seqidx = seqidx,
                                  layeridx = i,
                                  dropout = layer.dropout)
        }
        hidden <- next.state$h
        last.states[[i]] <- next.state
      }
      # push this step's per-layer states onto the queue tail
      seasonal.states <- c(seasonal.states, last.states)
    }else{
      for (i in 1:num.rnn.layer){
        layer.dropout <- ifelse(i==1, 0, dropout)
        # pop the queue head: this layer's state from seasonal.period steps ago
        prev.state <- seasonal.states[[1]]
        seasonal.states <- seasonal.states[-1]
        if(config == "gru"){
          next.state <- gru.cell(num.hidden,
                                 indata = hidden,
                                 prev.state = prev.state,
                                 param = param.cells[[i]],
                                 seqidx = seqidx,
                                 layeridx = i,
                                 dropout = layer.dropout)
        }else{
          next.state <- lstm.cell(num.hidden,
                                  indata = hidden,
                                  prev.state = prev.state,
                                  param = param.cells[[i]],
                                  seqidx = seqidx,
                                  layeridx = i,
                                  dropout = layer.dropout)
        }
        hidden <- next.state$h
        last.states[[i]] <- next.state
      }
      # push this step's per-layer states onto the queue tail
      seasonal.states <- c(seasonal.states, last.states)
    }

    # Aggregate outputs from each time step
    last.hidden <- c(last.hidden, hidden)
  }
  list.all <- list(outputs = last.hidden, last.states = last.states)

  return(list.all)
}
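
The queue idea described above can be sketched in Python with `collections.deque` for a single layer; the function `skip_rnn_queue` and its `cell(x, h_prev)` interface are hypothetical. One difference to note: this sketch pre-fills the queue with the initial state, so the first period also starts from it, whereas the R function above runs an ordinary step-to-step recurrence for the first `seasonal.period` steps and only then starts popping from the queue.

```python
from collections import deque

def skip_rnn_queue(inputs, p, cell, h0):
    """Queue-based skip RNN sketch (single layer).

    inputs: iterable of per-step inputs; p: seasonal period in steps;
    cell(x, h_prev) -> h_new is any recurrent cell; h0: initial state.
    """
    # Pre-fill with the initial state so each of the first p steps uses h0
    queue = deque([h0] * p)
    outputs = []
    for x in inputs:
        h_prev = queue.popleft()  # state from exactly p steps earlier
        h_new = cell(x, h_prev)
        queue.append(h_new)       # will be consumed p steps from now
        outputs.append(h_new)
    return outputs

# Toy scalar "cell" h_t = h_{t-2} + x_t with p = 2
print(skip_rnn_queue(range(6), 2, lambda x, h: h + x, 0))  # [0, 1, 2, 4, 6, 9]
```

The queue never holds more than p states, so memory stays constant in the sequence length.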

diyang commented 6 years ago

@QiXuanWang By the way, according to the paper, if your data is not periodic, or the period is dynamic, you should choose the LSTnet variant LSTnet-Attn, which replaces the skip RNN with a temporal attention layer.
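For comparison, a minimal NumPy sketch of temporal attention over the hidden states of an ordinary RNN, loosely in the spirit of LSTnet-Attn: the final hidden state acts as the query, the attention weights are a softmax over dot-product scores, and the weighted context is concatenated with the query. The dot-product score and the name `temporal_attention` are assumptions for illustration, not the paper's exact formulation or an MXNet API:

```python
import numpy as np

def temporal_attention(H):
    """Dot-product temporal attention sketch over RNN hidden states.

    H: (q, hidden) hidden states of an ordinary RNN over the last q steps.
    The final state H[-1] is the query. Returns (output, alpha), where output
    concatenates the attention context with the query, shape (2 * hidden,).
    """
    query = H[-1]
    scores = H @ query                 # one score per time step
    scores = scores - scores.max()     # stabilise the softmax
    alpha = np.exp(scores) / np.exp(scores).sum()
    context = alpha @ H                # attention-weighted sum of hidden states
    return np.concatenate([context, query]), alpha

rng = np.random.default_rng(0)
H = rng.normal(size=(168, 8))
out, alpha = temporal_attention(H)
print(out.shape, round(float(alpha.sum()), 6))  # (16,) 1.0
```

Because the weights are learned from the states themselves rather than fixed to a lag of p, this kind of layer does not need the period to be known in advance.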

roywei commented 6 years ago

@sandeep-krishnamurthy could you help add the labels Example and RNN? Thanks!

QiXuanWang commented 6 years ago

@diyang thanks very much for mentioning LSTnet-Attn. My current problem could actually use it. I will try it out.

roywei commented 5 years ago

Hi @safrooze, could you contribute your Gluon implementation back as the official LSTNet example? We are promoting Gluon now, and we can't actually use SequentialRNNCell for the skip RNN.

safrooze commented 5 years ago

@roywei Absolutely. I'll send a PR.

jonathandgough commented 5 years ago

@safrooze @roywei is the LSTnet-Attn method available in Gluon yet? Thanks!

Mobealy commented 2 years ago

Hi all, does anybody know what happened to the multivariate_time_series example? Or have all the examples been moved?

https://github.com/apache/incubator-mxnet/blob/master/example/multivariate_time_series