diyang opened this issue 6 years ago
Interesting, I'm working on LSTNet too. I rewrote the whole thing with Gluon, and yes, the RNNCell could not be used. I also tried removing the skip-RNN layer, but neither variant works very well for my problem.
@QiXuanWang I have used the MXNet R API to implement the skip RNN. You can find it in this function: https://github.com/diyang/deeplearning.mxnet/blob/master/LSTnet/src/lstnet_model.R
I used a queue to hold the hidden states of the past 24 hours: at each step I pop the head of the queue and push the newly produced hidden state for the current hour onto the tail.
```r
rnn.skip.unroll <- function(data,
                            num.rnn.layer = 1,
                            seq.len,
                            num.hidden,
                            seasonal.period,
                            dropout = 0,
                            config = "gru") {
  param.cells <- list()
  last.states <- list()
  for (i in 1:num.rnn.layer) {
    if (config == "gru") {
      param.cells[[i]] <- list(
        gates.i2h.weight = mx.symbol.Variable(paste0("l", i, ".gates.i2h.weight")),
        gates.i2h.bias   = mx.symbol.Variable(paste0("l", i, ".gates.i2h.bias")),
        gates.h2h.weight = mx.symbol.Variable(paste0("l", i, ".gates.h2h.weight")),
        gates.h2h.bias   = mx.symbol.Variable(paste0("l", i, ".gates.h2h.bias")),
        trans.i2h.weight = mx.symbol.Variable(paste0("l", i, ".trans.i2h.weight")),
        trans.i2h.bias   = mx.symbol.Variable(paste0("l", i, ".trans.i2h.bias")),
        trans.h2h.weight = mx.symbol.Variable(paste0("l", i, ".trans.h2h.weight")),
        trans.h2h.bias   = mx.symbol.Variable(paste0("l", i, ".trans.h2h.bias")))
      state <- list(h = mx.symbol.Variable(paste0("l", i, ".gru.init.h")))
    } else {
      param.cells[[i]] <- list(
        i2h.weight = mx.symbol.Variable(paste0("l", i, ".i2h.weight")),
        i2h.bias   = mx.symbol.Variable(paste0("l", i, ".i2h.bias")),
        h2h.weight = mx.symbol.Variable(paste0("l", i, ".h2h.weight")),
        h2h.bias   = mx.symbol.Variable(paste0("l", i, ".h2h.bias")))
      state <- list(c = mx.symbol.Variable(paste0("l", i, ".lstm.init.c")),
                    h = mx.symbol.Variable(paste0("l", i, ".lstm.init.h")))
    }
    last.states[[i]] <- state
  }

  data_seq_slice <- mx.symbol.SliceChannel(data = data, num_outputs = seq.len,
                                           axis = 2, squeeze_axis = 1)
  last.hidden <- list()
  # FIFO queue holding the per-layer states produced seasonal.period steps ago
  seasonal.states <- list()

  for (seqidx in 1:seq.len) {
    hidden <- data_seq_slice[[seqidx]]
    # stacked GRU/LSTM cells
    for (i in 1:num.rnn.layer) {
      # no dropout on the first layer; use a separate variable so `dropout`
      # itself is not overwritten with 0 for every later layer
      layer.dropout <- ifelse(i == 1, 0, dropout)
      if (seqidx <= seasonal.period) {
        # warm-up phase: no state from seasonal.period steps ago yet
        prev.state <- last.states[[i]]
      } else {
        # pop the queue head: this layer's state from seasonal.period steps earlier
        prev.state <- seasonal.states[[1]]
        seasonal.states <- seasonal.states[-1]
      }
      if (config == "gru") {
        next.state <- gru.cell(num.hidden,
                               indata = hidden,
                               prev.state = prev.state,
                               param = param.cells[[i]],
                               seqidx = seqidx,
                               layeridx = i,
                               dropout = layer.dropout)
      } else {
        next.state <- lstm.cell(num.hidden,
                                indata = hidden,
                                prev.state = prev.state,
                                param = param.cells[[i]],
                                seqidx = seqidx,
                                layeridx = i,
                                dropout = layer.dropout)
      }
      hidden <- next.state$h
      last.states[[i]] <- next.state
    }
    # push this timestep's per-layer states onto the queue tail
    seasonal.states <- c(seasonal.states, last.states)
    # aggregate the top-layer output from each timestep
    last.hidden <- c(last.hidden, hidden)
  }

  return(list(outputs = last.hidden, last.states = last.states))
}
```
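The queue mechanism above can be sketched in a framework-agnostic way. This is a minimal Python illustration, not the R/MXNet code itself; the `step` callable standing in for `gru.cell`/`lstm.cell` is a hypothetical placeholder:

```python
from collections import deque

def skip_rnn_unroll(inputs, period, step, init_state):
    """Unroll a skip RNN: the state fed at time t is the state
    produced at time t - period (the same hour of the previous
    day when period == 24 and inputs are hourly)."""
    queue = deque()   # holds the last `period` hidden states
    outputs = []
    for t, x in enumerate(inputs):
        if t < period:
            prev = init_state        # warm-up: no state from t - period yet
        else:
            prev = queue.popleft()   # pop head: state produced at step t - period
        h = step(x, prev)
        queue.append(h)              # push current state onto the tail
        outputs.append(h)
    return outputs

# toy check: with a step function that just records its inputs, the
# previous state seen at time t (t >= period) is the output of t - period
period = 24
outs = skip_rnn_unroll(list(range(24 * 7)), period,
                       step=lambda x, prev: (x, prev), init_state=None)
assert outs[24][1] == outs[0]
```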
@QiXuanWang By the way, according to the paper, if your data is not periodic, or the period is dynamic, you should choose the LSTNet variant LSTNet-Attn, i.e., replace the skip RNN with a temporal attention layer.
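For reference, the core of such a temporal attention layer can be sketched roughly as follows. This is a minimal NumPy illustration of similarity-based attention over a window of past hidden states, under my own simplifying assumptions, not the paper's exact formulation:

```python
import numpy as np

def temporal_attention(H, h_t):
    """H:   (window, hidden) matrix of past RNN hidden states.
    h_t: (hidden,) current hidden state.
    Returns a context vector: the attention-weighted sum of past states."""
    scores = H @ h_t                         # similarity of h_t to each past state
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights /= weights.sum()
    return weights @ H                       # (hidden,) context vector

H = np.random.randn(24, 8)   # e.g. the last 24 hourly hidden states
h_t = np.random.randn(8)
context = temporal_attention(H, h_t)
assert context.shape == (8,)
```

Because the attention weights are learned per step rather than tied to a fixed skip length, the model can adapt when the period drifts.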
@sandeep-krishnamurthy could you help add the labels Example and RNN? Thanks!
@diyang thanks very much for mentioning LSTNet-Attn. My current problem could actually use this. Will try it out.
Hi @safrooze, could you contribute your Gluon implementation back as the official example for LSTNet? We are promoting Gluon now, and SequentialRNNCell can't actually be used for the skip RNN.
@roywei Absolutely. I'll send a PR.
@safrooze @roywei is the LSTNet-Attn method available in Gluon yet? Thanks!
Hi all, does anybody know what happened to the multivariate_time_series example? Or have all the examples been moved?
https://github.com/apache/incubator-mxnet/blob/master/example/multivariate_time_series
Thanks for your example and source code; it helped me a lot in developing my own LSTNet. However, I have noticed some flaws in your implementation of the skip RNN.
The following is your way of implementing the skip RNN.
What I have noticed is that this does not actually create a skip RNN: what this RNN does is select the hidden state of the last hour (the 24th hour) of every day in the week. If the sequence length is 24*7, the skip-RNN output here is 7 hidden states, each corresponding to the last hour of one of the 7 days.
According to the paper, the input for each hour of a given day should instead be paired with the hidden state of the same hour of the previous day. The output of this skip RNN should cover the last 24 hours of the final day, i.e., 24 hidden states, one for each hour of the last day.
mx.rnn.SequentialRNNCell cannot handle this type of recurrence; you may need to write your own native unroll to formulate this type of RNN.
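To make the difference concrete, here is a small index-level check in plain Python (my own illustration of the two behaviors described above, not code from either implementation): subsampling one state per period yields only `seq_len // period` outputs, whereas a true skip RNN produces an output at every step, each conditioned on the state from `period` steps earlier, so the last `period` outputs cover every hour of the final day.

```python
seq_len, period = 24 * 7, 24

# the subsampled variant criticized above: only the 24th hour of each day
subsampled = list(range(period - 1, seq_len, period))
assert len(subsampled) == 7

# a true skip RNN: every step t >= period links back to step t - period
skip_links = {t: t - period for t in range(period, seq_len)}
assert skip_links[seq_len - 1] == seq_len - 1 - period

# its last `period` outputs cover every hour of the last day
last_day = list(range(seq_len - period, seq_len))
assert len(last_day) == 24
```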