jadore801120 / attention-is-all-you-need-pytorch

A PyTorch implementation of the Transformer model in "Attention is All You Need".

masking is not complete #149

Open JianBingJuanDaCong opened 4 years ago

JianBingJuanDaCong commented 4 years ago

https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/5c0264915ab43485adc576f88971fc3d42b10445/transformer/Layers.py#L39-L40

As a numeric example, suppose an English sentence (len = 12) is translated to German (len = 11), with the number of attention heads = 8.

In the above code, dec_enc_attn_mask is passed as the mask. It has shape (batch_size, 1, 1, 12) and is applied to the attention scores, which have shape (batch_size, 8, 11, 12).

Basically, only the English (source) sentence is masked, not the German (target) sentence. No additional masking is applied in the later steps, in particular in the precision calculation. This mis-calculates the precision, which in turn hurts training.
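For concreteness, here is a minimal PyTorch sketch (not the repo's exact code) of how a mask with those shapes is typically applied inside scaled dot-product attention; the variable names and the -1e9 fill value are assumptions for illustration:

```python
import torch

batch_size, n_head, len_de, len_en = 2, 8, 11, 12

# Decoder-encoder attention scores: one row per German (target) position,
# one column per English (source) position.
attn = torch.randn(batch_size, n_head, len_de, len_en)

# dec_enc_attn_mask: 1 for real English tokens, 0 for padding.
# Shape (batch_size, 1, 1, len_en) broadcasts over heads and target positions.
dec_enc_attn_mask = torch.ones(batch_size, 1, 1, len_en)
dec_enc_attn_mask[:, :, :, -3:] = 0  # pretend the last 3 source positions are padding

# Padded source columns are suppressed before the softmax; note that nothing
# here zeroes out rows corresponding to padded German (target) positions.
attn = attn.masked_fill(dec_enc_attn_mask == 0, -1e9)
attn = torch.softmax(attn, dim=-1)

print(attn.shape)  # torch.Size([2, 8, 11, 12])
```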

Ar-Kareem commented 4 years ago

I don't think that's true. The German sentence in your example is being masked in the previous line.

Here is what I think happens. The inputs to the forward function are dec_input, which is the full German (target) sentence, and enc_output, which is the encoded English (source) sentence.

Then the first MultiHeadAttention is called with dec_input and its corresponding mask. Its output, the masked (self-attended) German sentence, is saved to dec_output.

https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/5c0264915ab43485adc576f88971fc3d42b10445/transformer/Layers.py#L37-L38
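For reference, a minimal sketch of the kind of mask that first call receives: a decoder self-attention mask combining target-side padding with the subsequent (causal) mask. The helper names and shapes here are illustrative assumptions, not a verbatim excerpt from the repo:

```python
import torch

def get_pad_mask(seq, pad_idx):
    # (batch, 1, len): True where the token is real, False where it is padding.
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    # (1, len, len): lower-triangular, so position i only sees positions <= i.
    len_s = seq.size(1)
    return torch.tril(torch.ones((1, len_s, len_s), device=seq.device)).bool()

trg_seq = torch.tensor([[5, 7, 9, 3, 0, 0]])  # toy German target ids, 0 = <pad>
slf_attn_mask = get_pad_mask(trg_seq, pad_idx=0) & get_subsequent_mask(trg_seq)
print(slf_attn_mask.shape)  # torch.Size([1, 6, 6]): each German position may attend
                            # only to earlier, non-padding German positions.
```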

So the code you pointed out receives the full English encoder output and the already-masked German representation. Thus, you only need to mask the English sentence in the second MultiHeadAttention.

https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/5c0264915ab43485adc576f88971fc3d42b10445/transformer/Layers.py#L39-L40
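To make the argument concrete, here is a sketch of a decoder layer's forward pass under this reading. It paraphrases the linked lines rather than quoting them; the import path, constructor arguments, and the omission of the position-wise feed-forward step are assumptions:

```python
import torch.nn as nn
from transformer.SubLayers import MultiHeadAttention  # repo module; exact path assumed

class DecoderLayerSketch(nn.Module):
    def __init__(self, d_model=512, n_head=8, d_k=64, d_v=64, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        # First attention: German attends to German; slf_attn_mask covers both
        # target-side padding and the subsequent (causal) restriction.
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # Second attention (the lines linked above): the already-masked German
        # representation queries the English encoder output, so only English
        # (source) padding needs masking via dec_enc_attn_mask.
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        return dec_output, dec_slf_attn, dec_enc_attn
```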