bentrevett / pytorch-seq2seq

Tutorials on implementing a few sequence-to-sequence (seq2seq) models with PyTorch and TorchText.
MIT License
5.37k stars 1.34k forks

Question about tutorial 1 and 2 Decoder #168

Closed: djaekim closed this issue 9 months ago

djaekim commented 3 years ago

Hello, I had a question about prediction = self.fc_out(output)

In the decoder in tutorial 2, why is it output = torch.cat((embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), dim = 1) rather than output = torch.cat((embedded.squeeze(0), output.squeeze(0), context.squeeze(0)), dim = 1)?

In the tutorial 2 text, it says: [image]

Thank you!

bentrevett commented 3 years ago

When decoding we have a sequence length of one, and in that case output == hidden: output contains the hidden states from all time-steps, while hidden is the hidden state from the final time-step, so with a single time-step the two are identical. Both of the code snippets in your question therefore do exactly the same thing.
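This is easy to check directly. Below is a minimal sketch (sizes are made up for illustration, not taken from the tutorial) showing that for a single-layer, unidirectional GRU run over a sequence of length one, the output and hidden tensors hold the same values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this demonstration.
batch_size, emb_dim, hid_dim = 4, 8, 16

rnn = nn.GRU(emb_dim, hid_dim)  # single layer, unidirectional

# One decoding step: a sequence of length one.
embedded = torch.randn(1, batch_size, emb_dim)   # [seq len = 1, batch, emb dim]
hidden_0 = torch.randn(1, batch_size, hid_dim)   # [n layers = 1, batch, hid dim]

output, hidden = rnn(embedded, hidden_0)

# output: hidden states from ALL time-steps  -> shape [1, batch, hid dim]
# hidden: hidden state from the FINAL time-step -> shape [1, batch, hid dim]
# With a single time-step, they contain identical values:
print(torch.equal(output, hidden))  # True
```

With more than one time-step (or more than one layer) the two tensors would differ, which is why the distinction matters in the other tutorials.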