yunjey / pytorch-tutorial

PyTorch Tutorial for Deep Learning Researchers
MIT License

RNN input size question #195

Closed OrangeC93 closed 4 years ago

OrangeC93 commented 4 years ago

I'm new to PyTorch. Can anyone answer a question that has confused me a lot?

In the RNN tutorial, images are reshaped into (batch, seq_len, input_size):

images = images.reshape(-1, sequence_length, input_size)

But from what I learned, the input dimensions should be (seq_len, batch, input_size)?

GatoY commented 4 years ago

Hi @OrangeC93,

self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

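Note "batch_first=True". With that flag the LSTM expects input shaped (batch, seq_len, input_size). The default is "batch_first=False", which expects (seq_len, batch, input_size).

That's the reason. You can refer to the source code of the RNN cell here.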
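
To make the difference concrete, here is a minimal sketch comparing both layouts. The sizes (batch of 4, 28 steps of 28 features, hidden size 16, 2 layers) and the random "images" tensor are assumptions chosen for illustration, not values taken from the tutorial.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed): each 28x28 image is read as 28 steps of 28 features.
batch_size, sequence_length, input_size = 4, 28, 28
hidden_size, num_layers = 16, 2

# Stand-in for a batch of images shaped (batch, channels, height, width).
images = torch.randn(batch_size, 1, 28, 28)

# batch_first=True: the LSTM expects (batch, seq_len, input_size).
lstm_bf = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x_bf = images.reshape(-1, sequence_length, input_size)
out_bf, (h_bf, c_bf) = lstm_bf(x_bf)

# Default batch_first=False: the LSTM expects (seq_len, batch, input_size).
# Swap the first two axes with transpose; a plain reshape would scramble the data.
lstm_sf = nn.LSTM(input_size, hidden_size, num_layers)
x_sf = x_bf.transpose(0, 1).contiguous()
out_sf, (h_sf, c_sf) = lstm_sf(x_sf)

print(out_bf.shape)  # torch.Size([4, 28, 16])
print(out_sf.shape)  # torch.Size([28, 4, 16])
# The hidden state is (num_layers, batch, hidden_size) in both cases.
print(h_bf.shape, h_sf.shape)
```

Only the input and output tensors follow batch_first; the returned hidden and cell states keep the (num_layers, batch, hidden_size) layout either way.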

OrangeC93 commented 4 years ago

Wow, I got it! Many thanks!