JulesBelveze / time-series-autoencoder

PyTorch Dual-Attention LSTM-Autoencoder For Multivariate Time Series
Apache License 2.0

Regarding input dimensions #9

Closed shreejalt closed 3 years ago

shreejalt commented 3 years ago

Thanks @JulesBelveze for your intuitive work on LSTM autoencoders. I had one doubt. In your model.py file, you don't set batch_first=True in the LSTM initialization of either the Encoder or the Decoder class. But in your forward function you iterate like inp[:, t, :] for t in range(seq_len). According to the PyTorch documentation, the input to an LSTM should have dimensions (seq_len, batch_size, input_dims) when the batch_first flag is False (which is the default).

So aren't you iterating over the batch dimension rather than the sequence dimension? Please correct me if I am wrong.

Shouldn't it be inp[t, :, :] for t in range(seq_len)?

Or are you transposing the input matrix before passing it?

Thanks
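
For reference, a minimal runnable sketch of the shape convention being discussed; all dimension values below are illustrative and not taken from the repo:

import torch
import torch.nn as nn

# Illustrative dimensions
bs, seq_len, input_dims, hidden_size = 4, 10, 3, 8

# Default layout expected by nn.LSTM (batch_first=False): (seq_len, batch_size, input_dims)
lstm = nn.LSTM(input_size=input_dims, hidden_size=hidden_size)
x_seq_first = torch.randn(seq_len, bs, input_dims)
out, (h_n, c_n) = lstm(x_seq_first)
print(out.shape)  # torch.Size([10, 4, 8]) -> (seq_len, bs, hidden_size)
print(h_n.shape)  # torch.Size([1, 4, 8])  -> (num_layers, bs, hidden_size)

# With batch_first=True the layout becomes (batch_size, seq_len, input_dims)
lstm_bf = nn.LSTM(input_size=input_dims, hidden_size=hidden_size, batch_first=True)
x_batch_first = torch.randn(bs, seq_len, input_dims)
out_bf, _ = lstm_bf(x_batch_first)
print(out_bf.shape)  # torch.Size([4, 10, 8]) -> (bs, seq_len, hidden_size)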

JulesBelveze commented 3 years ago

Hi @shreejalt, I remember spending quite some time making sure the shapes are right, but I might be wrong. Here's my explanation of what you pointed out: the input_data tensor is of shape (bs, seq_len, input_dims), so by looping through it like input_data[:, t, :] we're iterating over the time window. Then, when feeding the LSTM like:

weighted_input = torch.mul(a_t, input_data[:, t, :].to(device))  # (bs * input_size)
self.lstm.flatten_parameters()
_, (h_t, c_t) = self.lstm(weighted_input.unsqueeze(0), (h_t, c_t))

the input weighted_input.unsqueeze(0) is of shape (1, bs, input_dims), i.e. the expected (seq_len, bs, input_dims) layout with a sequence length of 1, and the hidden and cell states are of shape (1, batch_size, hidden_size).

Hope that helps!
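
A minimal, self-contained sketch of this per-timestep feeding pattern; the attention weights a_t and the dimension values are placeholders, not copied from model.py:

import torch
import torch.nn as nn

bs, seq_len, input_size, hidden_size = 4, 10, 3, 8
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)  # batch_first=False (default)

input_data = torch.randn(bs, seq_len, input_size)  # (bs, seq_len, input_size)
h_t = torch.zeros(1, bs, hidden_size)              # (num_layers, bs, hidden_size)
c_t = torch.zeros(1, bs, hidden_size)

for t in range(seq_len):
    a_t = torch.rand(bs, input_size)                      # placeholder attention weights
    weighted_input = torch.mul(a_t, input_data[:, t, :])  # (bs, input_size)
    # unsqueeze(0) adds a sequence dimension of length 1, so the LSTM sees
    # (1, bs, input_size), which matches the (seq_len, bs, input_size) layout
    _, (h_t, c_t) = lstm(weighted_input.unsqueeze(0), (h_t, c_t))

print(h_t.shape)  # torch.Size([1, 4, 8])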

shreejalt commented 3 years ago

Thanks a lot @JulesBelveze. So you extract the temporal (batch_size, input_dims) slice from inp[:, t, :] and then unsqueeze(0), which makes the sequence length 1 on each iteration of the for loop. Gotcha.

Thanks a lot.

Also, one more doubt: is it necessary to write a for loop in the forward function of the plain Encoder/Decoder (the non-attention one)? You aren't passing the previously generated output/hidden state from one LSTM cell to the next, only the fixed input y_hist. So, instead of the for loop, what if I write:

def forward(self, y_hist):
    ...
    outputs, (hn, cn) = self.lstm(y_hist, (h0, c0))

Will it work?
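
A rough sketch of that single-call version; the class and parameter names are assumptions for illustration, and note the batch_first=True so y_hist can stay in (bs, seq_len, input_size) layout without a transpose:

import torch
import torch.nn as nn

class VanillaEncoder(nn.Module):
    # Hypothetical simplified encoder: one LSTM call over the whole sequence
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, y_hist):
        # y_hist: (bs, seq_len, input_size) because batch_first=True here
        bs = y_hist.size(0)
        h0 = torch.zeros(1, bs, self.hidden_size, device=y_hist.device)
        c0 = torch.zeros(1, bs, self.hidden_size, device=y_hist.device)
        outputs, (hn, cn) = self.lstm(y_hist, (h0, c0))
        return outputs, (hn, cn)

enc = VanillaEncoder(input_size=3, hidden_size=8)
out, (hn, cn) = enc(torch.randn(4, 10, 3))
print(out.shape, hn.shape)  # torch.Size([4, 10, 8]) torch.Size([1, 4, 8])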

JulesBelveze commented 3 years ago

@shreejalt yes, you're right. I'm also using a loop in the vanilla case, but you should be able to do it that way (which would be simpler, btw) :)

Feel free to open a PR if you have ideas on how to improve the repo !

JulesBelveze commented 3 years ago

@shreejalt closing this as your problems seem to be resolved 😄