shreejalt closed this issue 3 years ago
Hi @shreejalt,
I remember spending quite some time on ensuring that the shapes are right but I might be wrong. Here's my explanation to what you pointed out:
The `input_data` tensor is of shape `bs * seq_len * input_dims`. So by looping through it like `input_data[:, t, :]` we're iterating over the time window.
Then when feeding the LSTM like:
```python
weighted_input = torch.mul(a_t, input_data[:, t, :].to(device))  # (bs * input_size)
self.lstm.flatten_parameters()
_, (h_t, c_t) = self.lstm(weighted_input.unsqueeze(0), (h_t, c_t))
```
the input `weighted_input.unsqueeze(0)` is of shape `1 * bs * input_dims` (a sequence of length 1, matching the `seq_len * bs * input_dims` convention), and the hidden and cell states are of shape `1 * batch_size * hidden_size`.
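To make the shapes concrete, here is a minimal sketch of the per-timestep loop described above (not the repo's exact code; `bs`, `seq_len`, `input_dims` and `hidden_size` are illustrative values, and the attention weighting is omitted):

```python
import torch
import torch.nn as nn

bs, seq_len, input_dims, hidden_size = 4, 10, 3, 8
input_data = torch.randn(bs, seq_len, input_dims)  # (bs, seq_len, input_dims)

lstm = nn.LSTM(input_dims, hidden_size)  # batch_first=False (the default)
h_t = torch.zeros(1, bs, hidden_size)
c_t = torch.zeros(1, bs, hidden_size)

for t in range(seq_len):
    x_t = input_data[:, t, :]   # (bs, input_dims): one time step
    step = x_t.unsqueeze(0)     # (1, bs, input_dims): a sequence of length 1
    _, (h_t, c_t) = lstm(step, (h_t, c_t))

print(h_t.shape)  # torch.Size([1, 4, 8])
```

Because `unsqueeze(0)` makes the leading (time) dimension 1, each call processes exactly one step, which is why indexing `[:, t, :]` works even without `batch_first=True`.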
Hope that helps!
Thanks a lot @JulesBelveze
So, you extract the temporal `batch_size * input_dims` slice from `inp[:, t, :]` and then `unsqueeze(0)` makes it a sequence of length 1 on each iteration of the for loop. Gotcha.
Thanks a lot.
Also, one more doubt: is it necessary to write a for loop in the `forward` function for the simple Encoder/Decoder (the one without Attention)? Because you are not passing the previously generated output/hidden state from one LSTM cell to the next, but a fixed input `y_hist`.
So instead of the for loop, if I write:

```python
def forward(self, y_hist):
    ...
    outputs, (hn, cn) = self.lstm(y_hist, (h0, c0))
```

Will it work?
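A minimal sketch of that loop-free idea (the `PlainDecoder` module here is hypothetical, not the repo's actual Decoder; the names `y_hist`, `h0`, `c0` follow the thread):

```python
import torch
import torch.nn as nn

class PlainDecoder(nn.Module):
    def __init__(self, input_dims, hidden_size):
        super().__init__()
        # batch_first=True so y_hist can stay (bs, seq_len, input_dims)
        self.lstm = nn.LSTM(input_dims, hidden_size, batch_first=True)

    def forward(self, y_hist):
        bs = y_hist.size(0)
        h0 = torch.zeros(1, bs, self.lstm.hidden_size)
        c0 = torch.zeros(1, bs, self.lstm.hidden_size)
        # One call processes the whole sequence; no Python loop needed
        outputs, (hn, cn) = self.lstm(y_hist, (h0, c0))
        return outputs, (hn, cn)

dec = PlainDecoder(3, 8)
out, (hn, cn) = dec(torch.randn(4, 10, 3))
print(out.shape)  # torch.Size([4, 10, 8])
```

This only works when nothing from step `t` feeds back into step `t + 1` (as in the attention decoder); for a fixed input sequence, one `self.lstm(...)` call is both simpler and faster.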
@shreejalt yes you're right, I'm also using a loop in the vanilla case, but you should be able to do it that way (which would be simpler btw) :)
Feel free to open a PR if you have ideas on how to improve the repo !
@shreejalt closing this as your problems seem to be resolved 😄
Thanks @JulesBelveze for your intuitive work on LSTM Autoencoders. I had one doubt. In your model.py file, you don't pass `batch_first=True` in your LSTM initialization in either the Encoder or the Decoder class. But in your forward function, you iterate like `inp[:, t, :] for t in range(seq_len)`.
According to the PyTorch documentation, the input to an LSTM should have shape `(seq_len, batch_size, input_dims)` if the `batch_first` flag is False (which is the default). So aren't you iterating over the batch dim rather than the seq_len dim? Please correct me if I am wrong.
Shouldn't it be `inp[t, :, :] for t in range(seq_len)`?
Or are you transposing the input matrix before passing it?
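A quick way to check the convention being asked about (illustrative sizes, not the repo's code): with `batch_first=False` (the default) `nn.LSTM` expects `(seq_len, bs, input_dims)`, and with `batch_first=True` it expects `(bs, seq_len, input_dims)`.

```python
import torch
import torch.nn as nn

seq_len, bs, input_dims, hidden = 10, 4, 3, 8
lstm_default = nn.LSTM(input_dims, hidden)                    # batch_first=False
lstm_bf = nn.LSTM(input_dims, hidden, batch_first=True)

out1, _ = lstm_default(torch.randn(seq_len, bs, input_dims))  # time-major input
out2, _ = lstm_bf(torch.randn(bs, seq_len, input_dims))       # batch-major input

print(out1.shape)  # torch.Size([10, 4, 8]) -> (seq_len, bs, hidden)
print(out2.shape)  # torch.Size([4, 10, 8]) -> (bs, seq_len, hidden)
```

Note the flag only matters when the full sequence is fed in one call; per-step slices with `unsqueeze(0)` sidestep it, as discussed above.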
Thanks