flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/
Other
13.98k stars 2.1k forks

Keeping the order of sentences in the batch in DocumentRNNEmbedding #865

Closed · yosipk closed this issue 5 years ago

yosipk commented 5 years ago

In DocumentRNNEmbedding, the sentences in the batch are permuted (sorted by length) because PyTorch's pack_padded_sequence and pad_packed_sequence require sentences ordered by length. The rest of the code keeps the sentences in this permuted order, so the outputs are permuted as well. This is fine when the order of the sentences in the batch doesn't matter, or when labels are attached to the Sentence objects themselves (as in Flair). However, some tasks where the association between a sentence and its label (or, more generally, some other object) is not 1-1 assume that the order of sentences in the batch does not change.

Therefore, it would be good to preserve the order of sentences in the batch. For that, we need to keep the sorting indices so that we can undo the permutation introduced by sorting; this can be done right after the RNN embeddings are computed.
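The index bookkeeping proposed here can be sketched in plain Python. This is a minimal illustration, not Flair's code: the helper names (`sort_by_length`, `invert_permutation`, `embed_in_original_order`) and the `fake_rnn` callback standing in for the RNN step are hypothetical. The key idea is that the inverse of the sorting permutation maps each original position back to its slot in the sorted output.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
Sentence = Sequence[str]  # hypothetical: a sentence represented as a list of tokens


def sort_by_length(batch: Sequence[Sentence]) -> Tuple[List[Sentence], List[int]]:
    """Sort sentences by decreasing length and also return the sort indices.

    sort_idx[k] is the original position of the k-th sentence in sorted order.
    """
    sort_idx = sorted(range(len(batch)), key=lambda i: len(batch[i]), reverse=True)
    return [batch[i] for i in sort_idx], sort_idx


def invert_permutation(perm: Sequence[int]) -> List[int]:
    """Return inv such that inv[perm[k]] == k (the 'unsorting' indices)."""
    inv = [0] * len(perm)
    for sorted_pos, original_pos in enumerate(perm):
        inv[original_pos] = sorted_pos
    return inv


def embed_in_original_order(
    batch: Sequence[Sentence],
    embed_sorted: Callable[[List[Sentence]], List[T]],
) -> List[T]:
    """Run a step that needs length-sorted input, then restore the input order."""
    sorted_batch, sort_idx = sort_by_length(batch)
    sorted_outputs = embed_sorted(sorted_batch)  # outputs are in sorted order
    unsort_idx = invert_permutation(sort_idx)
    # Output i is the sorted output that came from original sentence i.
    return [sorted_outputs[j] for j in unsort_idx]


def fake_rnn(sentences: List[Sentence]) -> List[str]:
    """Stand-in for the RNN: insists on decreasing lengths, like pack_padded_sequence."""
    lengths = [len(s) for s in sentences]
    assert lengths == sorted(lengths, reverse=True), "input must be sorted by length"
    return ["emb(" + " ".join(s) + ")" for s in sentences]


batch = [["a", "b"], ["c", "d", "e", "f"], ["g"]]
print(embed_in_original_order(batch, fake_rnn))
# ['emb(a b)', 'emb(c d e f)', 'emb(g)']
```

The RNN step only ever sees sorted input, while the caller gets its outputs back in the order it passed the sentences in.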

rahul7iitk commented 5 years ago

Hi, just a small comment: in PyTorch 1.1, you can pass an extra parameter, enforce_sorted=False, to rnn.pack_padded_sequence. Then you won't need to pass your input as a sorted array. You can also unsort your outputs again after passing them to rnn.pack_padded. See this: https://city.shaform.com/en/2019/01/15/sort-sequences-in-pytorch/
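A minimal sketch of the enforce_sorted=False route, assuming PyTorch 1.1 or later. The toy GRU, tensor sizes, and lengths are arbitrary choices for illustration. In the releases we are aware of, the resulting PackedSequence records the permutation it applied internally, and both pad_packed_sequence and the RNN's final hidden state use it to put the batch back in input order; the per-sentence comparison at the end checks this empirically rather than relying on the docstring.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Toy batch of 3 already-embedded "sentences", deliberately NOT sorted by length.
lengths = [2, 4, 1]
batch = torch.randn(3, 4, 5)  # (batch, max_len, embedding_dim); steps past each length act as padding

rnn = nn.GRU(input_size=5, hidden_size=3, batch_first=True)

# enforce_sorted=False lets PyTorch sort internally and remember how to undo it.
packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=False)
packed_out, h_n = rnn(packed)
outputs, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Row i of `outputs` should correspond to input sentence i.
for i, n in enumerate(lengths):
    alone, _ = rnn(batch[i : i + 1, :n])  # run sentence i on its own, without packing
    same = torch.allclose(outputs[i, :n], alone[0], atol=1e-5)
    print(i, out_lengths[i].item(), same)
```

If the rows came back in sorted order instead, the comparison would fail for any sentence whose position changed during sorting.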

yosipk commented 5 years ago

I see that pack_padded_sequence has the enforce_sorted option, but I don't see it in pad_packed_sequence, whose documentation says "Batch elements will be ordered decreasingly by their length". So we use the sorting approach, the same as in the linked article: first sort the sentences by length and keep the "unsorting" permutation, which we apply at the end to restore the original order of the sentences in the batch.
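The sort-and-unsort approach described above can be sketched with tensor indexing, again assuming PyTorch and a toy batch-first GRU; `run_rnn_keep_order` is a hypothetical helper, not Flair's actual implementation. Because the default enforce_sorted=True leaves the packed output in sorted order, the inverse permutation (the `argsort` of the sort indices) is applied to both the padded outputs and the final hidden state.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def run_rnn_keep_order(rnn: nn.GRU, batch: torch.Tensor, lengths: torch.Tensor):
    """Run a batch-first RNN on unsorted, padded input; return results in input order."""
    # Sort by decreasing length, as pack_padded_sequence expects by default.
    sorted_lengths, sort_idx = lengths.sort(descending=True)
    # The argsort of a permutation is its inverse: these are the "unsorting" indices.
    unsort_idx = sort_idx.argsort()

    packed = pack_padded_sequence(batch[sort_idx], sorted_lengths, batch_first=True)
    packed_out, h_n = rnn(packed)
    outputs, _ = pad_packed_sequence(packed_out, batch_first=True)  # still in sorted order

    # Undo the sort: row i belongs to input sentence i again.
    # h_n has shape (num_layers * num_directions, batch, hidden), so permute dim 1.
    return outputs[unsort_idx], h_n[:, unsort_idx]


torch.manual_seed(0)
rnn = nn.GRU(input_size=5, hidden_size=3, batch_first=True)
lengths = torch.tensor([2, 4, 1])  # toy lengths, deliberately unsorted
batch = torch.randn(3, 4, 5)       # (batch, max_len, embedding_dim)
outputs, h_n = run_rnn_keep_order(rnn, batch, lengths)
```

Keeping `unsort_idx` next to the packing call means the permutation is undone immediately after the RNN, so nothing downstream ever sees the sorted order.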

What would you propose for keeping the order of sentences in the batch the same on input and output?