yunjey / pytorch-tutorial

PyTorch Tutorial for Deep Learning Researchers
MIT License

Input format of nn.LSTM class #122

Closed: ani0075saha closed this issue 6 years ago

ani0075saha commented 6 years ago

The official PyTorch documentation for torch.nn.LSTM says the input has shape (seq_len, batch, input_size), but in your recurrent_neural_network example I observed that the input shape is [100, 28, 28], where sequence_length = 28, input_size = 28, and batch_size = 100. Is this correct, or should we transpose the tensor? I am confused.

yunjey commented 6 years ago

@ani0075 See here. You can feed input of shape (batch, seq_len, input_size) by passing batch_first=True when constructing the LSTM. Either layout works; use whichever you prefer.
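A minimal sketch comparing the two layouts, assuming a recent PyTorch release. The batch, sequence, and input sizes follow the question. The values hidden_size = 128 and num_layers = 2 are placeholders chosen for illustration and are not taken from the tutorial. The second LSTM is given the first one's weights so that the two outputs can be compared directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Sizes from the question: 100 sequences, 28 time steps, 28 features per step.
batch_size, seq_len, input_size = 100, 28, 28
hidden_size, num_layers = 128, 2  # hypothetical values, for illustration only

x = torch.randn(batch_size, seq_len, input_size)  # (batch, seq_len, input_size)

# Layout 1: batch_first=True accepts (batch, seq_len, input_size) as-is.
lstm_bf = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
out_bf, (h_bf, c_bf) = lstm_bf(x)
print(out_bf.shape)  # torch.Size([100, 28, 128])

# Layout 2: the default (batch_first=False) expects (seq_len, batch, input_size),
# so swap the first two dimensions before calling the module.
lstm_sf = nn.LSTM(input_size, hidden_size, num_layers)
lstm_sf.load_state_dict(lstm_bf.state_dict())  # copy weights so results are comparable
out_sf, (h_sf, c_sf) = lstm_sf(x.transpose(0, 1))
print(out_sf.shape)  # torch.Size([28, 100, 128])

# Same computation; only the axis order of the output differs.
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6))
```

Note that batch_first only changes the layout of the input and output tensors. The hidden and cell states h_n and c_n keep the (num_layers, batch, hidden_size) layout in both cases, so h_bf.shape is (2, 100, 128) here.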