hunkim / PyTorchZeroToAll

Simple PyTorch Tutorials Zero to ALL!
http://bit.ly/PyTorchZeroAll

Does `batch_first` really also work on the shape of RNN's hidden vector? #16

Open iamkissg opened 6 years ago

iamkissg commented 6 years ago

Hi, in your code for 12_1_rnn_basics, you mention that the hidden state has the shape

(batch, num_layers * num_directions, hidden_size) for batch_first=True

However, PyTorch's docs only say:

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature)

So I'm confused about which one is right. I printed the shape of the RNN's hidden vector after executing this line, and it turned out to be [1, 3, 2].

Could you please tell me which one is right, and explain why the shape is [1, 3, 2]?

Thank you very much.

hunkim commented 6 years ago

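It's a bug in the comment. batch_first only affects the input and output tensors, not the hidden state, so the correct shape is:

(num_layers * num_directions, batch, hidden_size) for batch_first=True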

Can you send me a PR?
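
For anyone landing here, a minimal sketch of the shapes involved may help. The sizes below (batch of 3, sequence length 5, input size 4, hidden size 2) are assumed values picked to reproduce the [1, 3, 2] shape from the question; they are not copied from the tutorial file. It runs the same input through `nn.RNN` with `batch_first=True` and `batch_first=False` and prints the shapes of the output and the hidden state.

```python
import torch
import torch.nn as nn

# Assumed sizes, chosen so the hidden state comes out as [1, 3, 2].
batch_size, seq_len, input_size, hidden_size = 3, 5, 4, 2

# Single-layer, unidirectional RNN, so num_layers * num_directions == 1.
rnn_bf = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
rnn_sf = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=False)

inputs = torch.randn(batch_size, seq_len, input_size)  # (batch, seq, feature)
h0 = torch.zeros(1, batch_size, hidden_size)           # (layers * directions, batch, hidden)

# batch_first=True: input and output are (batch, seq, feature).
out_bf, hidden_bf = rnn_bf(inputs, h0)

# batch_first=False: input and output are (seq, batch, feature),
# so the same data is transposed before being fed in.
out_sf, hidden_sf = rnn_sf(inputs.transpose(0, 1), h0)

print(out_bf.shape)     # torch.Size([3, 5, 2])
print(out_sf.shape)     # torch.Size([5, 3, 2])
print(hidden_bf.shape)  # torch.Size([1, 3, 2])
print(hidden_sf.shape)  # torch.Size([1, 3, 2])
```

The output tensor changes layout with `batch_first`, while the hidden state keeps the layout (num_layers * num_directions, batch, hidden_size) in both cases. With one layer, one direction, a batch of 3, and a hidden size of 2, that gives [1, 3, 2].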

