sgrvinod / a-PyTorch-Tutorial-to-Image-Captioning

Show, Attend, and Tell | a PyTorch Tutorial to Image Captioning

Why use LSTMCell instead of LSTM directly? #181

Open morestart opened 2 years ago

AndreiMoraru123 commented 1 year ago

@morestart, you probably already know the answer by now, but in case anyone else is wondering: nn.LSTM in PyTorch is a multi-layer network, which is why you can select the number of layers. nn.LSTMCell, on the other hand, is a single cell that you step manually. The author uses the latter here because the attention-weighted context has to be recomputed at every decoding step during training. With a multi-layer nn.LSTM you could not do that, as the layer connections and the forward pass over the whole sequence are hard-coded.
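
To make that concrete, here is a minimal sketch of one decoding step. This is not the repo's actual Decoder or Attention code; the nn.Linear scorer and the dimensions are made up for illustration. The point is that the context vector depends on the current hidden state, so it has to be rebuilt before every cell call, which nn.LSTM's fused sequence loop does not let you do:

import torch
import torch.nn as nn

embed_dim, encoder_dim, decoder_dim = 512, 2048, 512

attention = nn.Linear(decoder_dim, encoder_dim)           # stand-in for a real attention module
step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)  # one cell, stepped manually

def decode_step(embedding, encoder_out, h, c):
    # encoder_out: (batch, num_pixels, encoder_dim), embedding: (batch, embed_dim)
    scores = (encoder_out @ attention(h).unsqueeze(2)).squeeze(2)  # (batch, num_pixels)
    alpha = scores.softmax(dim=1)                                  # attention weights
    context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)        # (batch, encoder_dim)
    # context depends on the current h, so it is recomputed before each cell call
    h, c = step(torch.cat([embedding, context], dim=1), (h, c))
    return h, c, alpha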

thanhtvt commented 1 year ago

@AndreiMoraru123 So if I set the number of layers in the LSTM to 2, is that the same as building a two-level for loop with LSTMCell?

AndreiMoraru123 commented 1 year ago

@thanhtvt Exactly!

This is precisely the example PyTorch gives in its docs. From the LSTM page:

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)       # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)  # sequence_length=5, batch_size=3,
                               #   last dim must equal the LSTM's input_size (10)
h0 = torch.randn(2, 3, 20)     # num_layers=2 again, batch_size=3 must match,
                               #   hidden_size=20 must match
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))  # output is [5, 3, 20]: sequence length and
                                         #   batch size from the input, hidden size from the LSTM

Then on the LSTMCell page, it's pretty much the same thing, but unrolled with a for loop:

rnn = nn.LSTMCell(10, 20)      # input_size=10, hidden_size=20
input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
hx = torch.randn(3, 20)        # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
for i in range(input.size(0)):       # step through time manually
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
output = torch.stack(output, dim=0)  # output.size() is [2, 3, 20]: the [3, 20] hx's stacked across the time dimension
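
To make the two-layer case concrete as well, here is a rough sketch of what num_layers=2 unrolls into: two stacked LSTMCells inside one time loop. The weights here are freshly initialized, so this mirrors the structure of nn.LSTM(10, 20, 2), not the numbers of any particular trained instance:

import torch
import torch.nn as nn

cell1 = nn.LSTMCell(10, 20)         # layer 1: input_size -> hidden_size
cell2 = nn.LSTMCell(20, 20)         # layer 2: consumes layer 1's hidden state
x = torch.randn(2, 3, 10)           # (time_steps, batch, input_size)
h1, c1 = torch.zeros(3, 20), torch.zeros(3, 20)
h2, c2 = torch.zeros(3, 20), torch.zeros(3, 20)
output = []
for t in range(x.size(0)):          # outer loop over time
    h1, c1 = cell1(x[t], (h1, c1))  # one step per layer at each time step
    h2, c2 = cell2(h1, (h2, c2))
    output.append(h2)
output = torch.stack(output, dim=0) # (time_steps, batch, hidden_size): the top layer's states, like nn.LSTM's output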