IBM / pytorch-seq2seq

An open source framework for seq2seq models in PyTorch.
https://ibm.github.io/pytorch-seq2seq/public/index.html
Apache License 2.0

Doubt on "pytorch-seq2seq/seq2seq/models/EncoderRNN.py" #159

Closed: caozhen-alex closed this issue 6 years ago

caozhen-alex commented 6 years ago

    if self.variable_lengths:
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
    output, hidden = self.rnn(embedded)
    if self.variable_lengths:
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

Hi, why are there two if self.variable_lengths checks here?

Also, this code doesn't specify h0 and c0. Does that mean they default to zero?

Looking forward to your response. @kylegao91

pskrunner14 commented 6 years ago

@caozhen-alex the first check packs the padded batch before it goes into the RNN, and the second unpacks the RNN's packed output back into a padded tensor. We use nn.utils.rnn.pack_padded_sequence when the input sequences have variable lengths: we need to feed them to an RNN but still get each sequence's output at the right time step. Without packing, you have to unroll the RNN to a fixed length, which gives you a fixed-length output. Since the desired outputs have different lengths, you would then have to mask them yourself.

For example:

>>> import torch
>>> from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
>>> x = torch.LongTensor([[1, 2, 3], [7, 8, 0]])
>>> x = x.view(2, 3, 1)
>>> x
tensor([[[1],
         [2],
         [3]],

        [[7],
         [8],
         [0]]])
>>> lens = [3, 2]
>>> y = pack_padded_sequence(x, lens, batch_first=True)
>>> y       # tensor after packing 
PackedSequence(data=tensor([[1],
        [7],
        [2],
        [8],
        [3]]), batch_sizes=tensor([2, 2, 1]))
>>> z = pad_packed_sequence(y, batch_first=True)
>>> z        # the original padded tensor, plus the lengths
(tensor([[[1],
         [2],
         [3]],

        [[7],
         [8],
         [0]]]), tensor([3, 2]))

As for h0 and c0: the LSTM handles the hidden and cell states on its own. If no initial state is passed, nn.LSTM initializes (h0, c0) to zeros, and the hidden value it returns is a tuple of (hidden state, cell state), so we don't need to specify them manually.

For example:

>>> import seq2seq
>>> from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq
>>> vocab_length = 10
>>> max_len = 5
>>> hidden_size = 7
>>> encoder = EncoderRNN(vocab_length, max_len, hidden_size, rnn_cell='lstm', bidirectional=True, variable_lengths=True)
>>> x = torch.LongTensor([[1, 2, 3, 0, 0]])
>>> output, hidden = encoder(x, input_lengths=[3])
>>> hidden
(tensor([[[-0.1805, -0.0206, -0.0415, -0.0419, -0.1093, -0.1006,  0.3364]],

        [[ 0.0262, -0.1567,  0.0960,  0.0525,  0.0513,  0.0613,  0.0911]]],
       grad_fn=<StackBackward>), tensor([[[-0.3646, -0.0378, -0.0712, -0.3447, -0.1835, -0.4023,  0.4809]],

        [[ 0.0828, -0.3054,  0.2106,  0.2864,  0.1017,  0.0886,  0.1711]]],
       grad_fn=<StackBackward>))
>>> h0, c0 = hidden
>>> h0
tensor([[[-0.1805, -0.0206, -0.0415, -0.0419, -0.1093, -0.1006,  0.3364]],

        [[ 0.0262, -0.1567,  0.0960,  0.0525,  0.0513,  0.0613,  0.0911]]],
       grad_fn=<StackBackward>)
>>> c0
tensor([[[-0.3646, -0.0378, -0.0712, -0.3447, -0.1835, -0.4023,  0.4809]],

        [[ 0.0828, -0.3054,  0.2106,  0.2864,  0.1017,  0.0886,  0.1711]]],
       grad_fn=<StackBackward>)

>>> hidden[0].size()   # (num_layers * num_directions, batch, hidden_size): 1 layer x 2 directions
torch.Size([2, 1, 7])
>>> hidden[1].size()
torch.Size([2, 1, 7])

I'm not sure the maintainer is still active on the project. I hope this helped :)

pskrunner14 commented 6 years ago

Closing this for now

caozhen-alex commented 6 years ago

@pskrunner14 Hi, thank you very much for your explanation. I understand the first part now. For the second one, why did you write h0, c0 = hidden? I thought it should be h3, c3, or h5, c5.

pskrunner14 commented 6 years ago

@caozhen-alex it's just an arbitrary naming convention I've used. The important part is that the hidden value returned by the LSTM is a tuple of (hidden state, cell state), in that order.
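Since the encoder returns what self.rnn returns, the tuple unpacked as h0, c0 above actually holds the states after the last time step (PyTorch's nn.LSTM documentation calls them h_n and c_n). A minimal sketch to check this, assuming a plain bidirectional nn.LSTM with random, unpadded input rather than the library's EncoderRNN:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 7
# Toy bidirectional 1-layer LSTM, batch_first like the EncoderRNN example above.
lstm = nn.LSTM(input_size=3, hidden_size=hidden_size, batch_first=True, bidirectional=True)

x = torch.randn(1, 5, 3)          # (batch=1, seq_len=5, features=3), no padding
output, (h_n, c_n) = lstm(x)      # the returned hidden is the tuple (h_n, c_n)

# Both are (num_layers * num_directions, batch, hidden_size) = (2, 1, 7).
print(h_n.shape, c_n.shape)

# output concatenates the forward and backward directions along the last dim.
forward_out = output[:, :, :hidden_size]
backward_out = output[:, :, hidden_size:]

# Forward direction: final state equals its output at the last time step.
print(torch.allclose(h_n[0], forward_out[:, -1]))   # True
# Backward direction runs right-to-left, so its final state is its output at t=0.
print(torch.allclose(h_n[1], backward_out[:, 0]))   # True
```

So names like h5, c5 (for a length-5 input) would describe these tensors more literally, but any name works for the tuple unpacking.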

caozhen-alex commented 6 years ago

@pskrunner14 I see your point. Thank you very much for your clear explanation. By the way, how can I look at h0 and c0? I saw that the PyTorch documentation says they are set to zero if they are not provided.
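PyTorch's nn.LSTM documentation states that h_0 and c_0 default to zeros when they are not provided, and the LSTM never returns the initial state. One way to "look at" them is to build the zero tensors yourself, pass them in, and confirm the results match the default call. A minimal sketch, assuming a plain bidirectional nn.LSTM with toy sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_layers, num_directions, batch, hidden_size = 1, 2, 1, 7
lstm = nn.LSTM(input_size=3, hidden_size=hidden_size, batch_first=True, bidirectional=True)
x = torch.randn(batch, 5, 3)

# Explicit initial states. Their layout is (num_layers * num_directions, batch, hidden_size)
# even when batch_first=True, since batch_first does not apply to hidden/cell states.
h0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
print(h0)   # the initial hidden state: just zeros

out_default, (h_default, c_default) = lstm(x)               # no initial state passed
out_explicit, (h_explicit, c_explicit) = lstm(x, (h0, c0))  # explicit zero initial state

print(torch.allclose(out_default, out_explicit))   # True
print(torch.allclose(h_default, h_explicit))       # True
print(torch.allclose(c_default, c_explicit))       # True
```

The same pattern works for a non-zero starting state: pass any (h0, c0) tuple of that shape as the second argument.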