bentrevett / pytorch-pos-tagging

A tutorial on how to implement models for part-of-speech tagging using PyTorch and TorchText.
MIT License

Tutorial 1 - BiLSTMPOSTagger: why doesn't it use pack/pad? #8

Open actforjason opened 3 years ago

actforjason commented 3 years ago

For varying-length sequences, why isn't pack/pad used?

def forward(self, text):
    # text = [seq len, batch size]
    embedded = self.dropout(self.embedding(text))

    # embedded = [seq len, batch size, emb dim]
    outputs, (hidden, cell) = self.lstm(embedded)

    # outputs = [seq len, batch size, hid dim * n directions]
    predictions = self.fc(self.dropout(outputs))

    # predictions = [seq len, batch size, output dim]
    return predictions
actforjason commented 3 years ago

Oh, there is a line self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx) that specifies padding_idx=pad_idx, but why aren't pack_padded_sequence and pad_packed_sequence needed?

bentrevett commented 3 years ago

You only need to use pack_padded_sequence and pad_packed_sequence if you want to do something with the final hidden (and cell) states, denoted in the code as hidden and cell, respectively, from outputs, (hidden, cell) = self.lstm(embedded). This is because packing and padding allow us to get the final hidden state from the last non-pad element in each sequence within the batch.
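To make that concrete, here is a minimal sketch (not from the tutorial; the sizes, names and lengths are made up) of the case where packing does matter: you pass the true lengths, the LSTM skips the pad positions, and hidden ends up holding the state of the last non-pad token in each sequence.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

vocab_size, emb_dim, hid_dim, pad_idx = 100, 32, 64, 1
embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
lstm = nn.LSTM(emb_dim, hid_dim)

# text = [seq len, batch size]; text_lengths are the true (un-padded) lengths
text = torch.randint(2, vocab_size, (10, 4))
text_lengths = torch.tensor([10, 7, 5, 3])
text[7:, 1] = pad_idx
text[5:, 2] = pad_idx
text[3:, 3] = pad_idx

embedded = embedding(text)

# packing tells the LSTM to stop at each sequence's true length,
# so hidden/cell come from the last non-pad element of every sequence
packed = pack_padded_sequence(embedded, text_lengths)
packed_outputs, (hidden, cell) = lstm(packed)

# unpacking restores [seq len, batch size, hid dim], zero-padded after each true length
outputs, _ = pad_packed_sequence(packed_outputs)

print(hidden.shape)   # [1, batch size, hid dim]
print(outputs.shape)  # [seq len, batch size, hid dim]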

When doing POS tagging we want to do something with the whole sequence of hidden states, which we get from the outputs variable in outputs, (hidden, cell) = self.lstm(embedded). Yes, some of those hidden states will come from pad tokens, but I don't believe there is a nice way in PyTorch to get a sequence of hidden states only up to a certain point. However, this doesn't really matter, because when we pass the padding tag index to CrossEntropyLoss (as its ignore_index) we're telling PyTorch not to calculate losses over these pad tokens. All we lose is the time taken to run the LSTM over the pad tokens, and that is minimized by the BucketIterator, which is designed to reduce the amount of padding within a batch.
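For reference, a minimal sketch of the loss side of this (the TAG_PAD_IDX name and the numbers are illustrative): with ignore_index set to the pad tag, targets at pad positions contribute neither to the loss nor to the gradients.

import torch
import torch.nn as nn

TAG_PAD_IDX = 0  # illustrative value for the index of the <pad> tag

criterion = nn.CrossEntropyLoss(ignore_index=TAG_PAD_IDX)

# predictions = [seq len * batch size, output dim], tags = [seq len * batch size]
predictions = torch.randn(6, 5, requires_grad=True)
tags = torch.tensor([3, 1, TAG_PAD_IDX, 2, TAG_PAD_IDX, 4])

# positions whose target is TAG_PAD_IDX are excluded from the averaged loss
loss = criterion(predictions, tags)
loss.backward()

print(predictions.grad[2])  # all zeros: no gradient flows through the pad positions
print(predictions.grad[4])  # all zeros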

The padding_idx=pad_idx argument to nn.Embedding does something different. It means that whenever a token that has been numericalized to pad_idx is passed to the nn.Embedding, it returns a tensor of zeros at that position; see: https://stackoverflow.com/a/61173091
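A tiny sketch of that behaviour (the sizes are arbitrary): the embedding row for pad_idx is all zeros, and it also receives zero gradient, so it never gets updated during training.

import torch
import torch.nn as nn

pad_idx = 1
embedding = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=pad_idx)

tokens = torch.tensor([2, pad_idx, 4])
vectors = embedding(tokens)

print(vectors[1])  # tensor of zeros for the pad token

vectors.sum().backward()
print(embedding.weight.grad[pad_idx])  # also zeros: the pad row is never trained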

actforjason commented 3 years ago

Thank you for your reply.

It means whenever a token that has been numericalized to pad_idx is passed to the nn.Embedding it will return a tensor of zeros at that position.

But why do we need a tensor of zeros at the position of pad_idx? Will the pad positions that aren't packed away, but are still there, affect the training of the LSTM (especially with bidirectional=True)? Or, what is the purpose of packing other than reducing training time?

Besides, can I use pack_padded_sequence and pad_packed_sequence instead of padding_idx=pad_idx?