pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch
https://pytorch.org/text
BSD 3-Clause "New" or "Revised" License

Include padding mask in generation #2096

Closed: joecummings closed this 1 year ago

joecummings commented 1 year ago

Bug

Batched input should produce the same outputs as single input, e.g.:

  1. [seq1, ..., seq_m] -> generate -> [output1, ..., output_m]
  2. [seq1] -> generate -> [output1]
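Why the two cases can differ: to batch sequences of different lengths, the shorter ones are padded to a common length. The minimal sketch below assumes right-padding with a hypothetical pad id of 0 (`pad_batch` and `PAD_IDX` are illustrative names, not torchtext API). It shows that seq1 carries extra pad tokens inside the batch that it does not have when run alone.

```python
def pad_batch(seqs, pad_idx):
    """Right-pad variable-length token-id lists to the longest length in the batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_idx] * (max_len - len(s)) for s in seqs]


PAD_IDX = 0  # hypothetical pad id, chosen only for illustration
seq1 = [5, 6, 7]
seq2 = [8, 9, 10, 11, 12]

batch = pad_batch([seq1, seq2], PAD_IDX)
# batch[0] == [5, 6, 7, 0, 0]  -> seq1 now has two pad tokens appended
# batch[1] == [8, 9, 10, 11, 12]
```

Unless the model is told which positions are padding, it processes those extra tokens like real input, so output1 can change.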

Before this fix, the two calls did not produce the same output1. The cause was that src_key_padding_mask was not being propagated to the model's forward function, so the padding added when batching could affect the result.
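To illustrate the mechanism, here is a minimal single-head attention sketch in plain Python, assuming the PyTorch boolean convention for `src_key_padding_mask` (True marks a padding position to ignore). The function `attend` and all vectors are hypothetical. Appending a pad key changes the attention output unless that key is masked. Masked positions receive a score of -inf, so they get zero weight after the softmax.

```python
import math


def attend(query, keys, values, key_padding_mask=None):
    """Scaled dot-product attention for one query vector.

    key_padding_mask[i] == True means key i is padding and must be ignored.
    (Not handled: a mask that hides every key, which would divide by zero.)
    """
    d = len(query)
    scores = []
    for i, k in enumerate(keys):
        if key_padding_mask is not None and key_padding_mask[i]:
            scores.append(float("-inf"))  # exp(-inf) == 0.0, so zero attention weight
        else:
            scores.append(sum(q * kk for q, kk in zip(query, k)) / math.sqrt(d))
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
pad_vec = [0.5, 0.5]  # arbitrary embedding standing in for a pad token

single = attend(query, keys, values)                         # like [seq1] alone
unmasked = attend(query, keys + [pad_vec], values + [pad_vec])  # padded, no mask
masked = attend(query, keys + [pad_vec], values + [pad_vec],
                key_padding_mask=[False, False, True])        # padded, with mask
```

Here `masked` matches `single`, while `unmasked` does not. This is the behavior the missing mask produced during batched generation.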

Fix

Create a padding mask, add it to model_kwargs, and pass it through to the model's forward function.
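The steps of the fix can be sketched as follows. This is not the actual torchtext change. `make_padding_mask` and `toy_forward` are hypothetical, and the toy model simply averages token ids while skipping masked positions. In real PyTorch code the mask would more likely be a boolean tensor such as `input_ids.eq(pad_idx)`. The key name `src_key_padding_mask` and the use of `model_kwargs` come from the PR description.

```python
def make_padding_mask(batch, pad_idx):
    """True where a position holds padding (src_key_padding_mask convention)."""
    return [[tok == pad_idx for tok in row] for row in batch]


def toy_forward(input_ids, src_key_padding_mask=None):
    """Hypothetical stand-in for a model forward pass: the mean token id per row,
    ignoring positions that the mask flags as padding."""
    outputs = []
    for i, row in enumerate(input_ids):
        keep = [tok for j, tok in enumerate(row)
                if src_key_padding_mask is None or not src_key_padding_mask[i][j]]
        outputs.append(sum(keep) / len(keep))
    return outputs


PAD_IDX = 0  # hypothetical pad id
batch = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 12]]

# 1. Create the padding mask, 2. add it to model_kwargs, 3. pass it to forward.
model_kwargs = {}
model_kwargs["src_key_padding_mask"] = make_padding_mask(batch, PAD_IDX)
batched = toy_forward(batch, **model_kwargs)

single = toy_forward([[5, 6, 7]])   # seq1 on its own
unmasked = toy_forward(batch)       # the pre-fix behavior: mask never passed
```

With the mask threaded through, `batched[0]` equals `single[0]` (6.0). Without it, the two pad tokens drag row 0 down to 3.6. One caveat of this approach is that it assumes the pad id never appears as a real token.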