microsoft / ProphetNet

A research project for natural language generation, containing the official implementations by MSRA NLC team.
MIT License

fix-bug: fix attn transpose bug #52

Open tqnwhz opened 2 years ago

tqnwhz commented 2 years ago

Hi, I think I have found a bug in the code.

In the extract_features function of NgramTransformerDecoder, a transpose is applied to attn, which is the output of NgramTransformerDecoderLayer. The code snippet is as follows:

class NgramTransformerDecoder(FairseqIncrementalDecoder):
    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        # ......
        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out['encoder_out'] if encoder_out is not None else None,
                encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                ngram_mask_matrix=ngram_mask_matrix,
                i_buckets_main_stream=i_buckets_main_stream,
                i_bucket_relative_stream=i_bucket_relative_stream,
                real_positions=real_positions
            )
            inner_states.append(x)
        # TODO [(1+ngram)*T, B, C] -> [B, (1+ngram)*T, C]
        x_list = x.transpose(0, 1).chunk(1 + self.ngram, 1)
        if attn is not None:
            attn_list = attn.transpose(0, 1).chunk(1 + self.ngram, 1)
        else:
            attn_list = None

        return x_list, {'attn': attn_list}

As the code comment indicates, the purpose is to change the dims from [(1+ngram)*T, B, C] to [B, (1+ngram)*T, C]. The variable attn returned by NgramTransformerDecoderLayer is the second output of its encoder_attn (fairseq.modules.MultiheadAttention).
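For reference, the intended reshaping of x can be traced on shapes alone. The sketch below is plain Python (no PyTorch) that mimics the shape rules of Tensor.transpose and torch.chunk; the helper names transpose_shape and chunk_shapes and the sizes ngram=2, T=4, B=3, C=8 are hypothetical, chosen only for illustration.

```python
def transpose_shape(shape, dim0, dim1):
    """Shape of tensor.transpose(dim0, dim1): the two dims are swapped."""
    s = list(shape)
    s[dim0], s[dim1] = s[dim1], s[dim0]
    return tuple(s)


def chunk_shapes(shape, chunks, dim):
    """Shapes returned by torch.chunk(tensor, chunks, dim).

    Mirrors PyTorch's rule: each piece has ceil(size / chunks) elements
    along `dim` (the last may be smaller), so fewer than `chunks` pieces
    can come back. Assumes the size along `dim` is > 0.
    """
    size = shape[dim]
    assert size > 0
    step = (size + chunks - 1) // chunks
    return [shape[:dim] + (min(step, size - start),) + shape[dim + 1:]
            for start in range(0, size, step)]


# Hypothetical example sizes, not taken from any ProphetNet config.
ngram, T, B, C = 2, 4, 3, 8

x_shape = ((1 + ngram) * T, B, C)          # [(1+ngram)*T, B, C]
x_bt = transpose_shape(x_shape, 0, 1)      # [B, (1+ngram)*T, C]
x_list = chunk_shapes(x_bt, 1 + ngram, 1)  # 1+ngram pieces of [B, T, C]
print(x_bt, x_list)  # (3, 12, 8) [(3, 4, 8), (3, 4, 8), (3, 4, 8)]
```

Here the transpose moves the batch dim to the front, so chunking dim 1 splits the (1+ngram)*T axis into one [B, T, C] piece per stream, as the comment intends.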

In fairseq v0.9.0, the relevant part of MultiheadAttention's forward function is as follows:

class MultiheadAttention(nn.Module):
    def forward(
        self,
        # ...
    ):
        # ......
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, attn_weights

As shown, the second output of forward (attn_weights) starts with the shape (bsz, self.num_heads, tgt_len, src_len). After the transpose and the mean over heads, its shape is (bsz, tgt_len, src_len). That is the actual shape of attn in extract_features, not [(1+ngram)*T, B, C] as the comment describes, so transposing dims 0 and 1 and then chunking along dim 1 splits the batch dimension instead of the (1+ngram)*T dimension. By the way, the shape and transpose of x in extract_features are correct. Also, attn is not actually used during training or inference, which I guess is why the bug has gone unnoticed for 2 years.
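The same shape arithmetic can be applied to attn. This plain-Python sketch assumes need_weights=True and need_head_weights=False, uses hypothetical helpers that mimic PyTorch's shape rules for transpose, mean and chunk, and uses made-up sizes (ngram=2, T=4, B=2, 4 heads, src_len=7). With batch size 2, the transposed attn is chunked along the batch dimension, and torch.chunk returns only 2 pieces instead of 1+ngram=3.

```python
def transpose_shape(shape, dim0, dim1):
    """Shape of tensor.transpose(dim0, dim1)."""
    s = list(shape)
    s[dim0], s[dim1] = s[dim1], s[dim0]
    return tuple(s)


def mean_shape(shape, dim):
    """Shape of tensor.mean(dim=dim): the reduced dim disappears."""
    return tuple(n for i, n in enumerate(shape) if i != dim)


def chunk_shapes(shape, chunks, dim):
    """Shapes from torch.chunk: ceil(size / chunks) per piece, last may be smaller."""
    size = shape[dim]
    assert size > 0
    step = (size + chunks - 1) // chunks
    return [shape[:dim] + (min(step, size - start),) + shape[dim + 1:]
            for start in range(0, size, step)]


# Hypothetical sizes; tgt_len is the concatenated main + ngram streams.
ngram, T, B, heads, S = 2, 4, 2, 4, 7
tgt_len = (1 + ngram) * T

# Tail of MultiheadAttention.forward (fairseq v0.9.0), per the snippet above.
w = (B, heads, tgt_len, S)    # .view(bsz, num_heads, tgt_len, src_len)
w = transpose_shape(w, 1, 0)  # (num_heads, bsz, tgt_len, src_len)
attn = mean_shape(w, 0)       # averaged over heads: (bsz, tgt_len, src_len)

# What extract_features currently does: transpose(0, 1) then chunk dim 1.
buggy = chunk_shapes(transpose_shape(attn, 0, 1), 1 + ngram, 1)
print(attn, buggy)  # (2, 12, 7) [(12, 1, 7), (12, 1, 7)]
```

Each piece holds one sample of the batch with all (1+ngram)*T positions, rather than one stream for the whole batch.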

However, anyone who modifies the model and needs to use the variable attn, as I did, will find that the transpose gives it a confusing shape. It took me some time to track down the bug.
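One possible fix, which is an assumption here and not necessarily the exact change in this PR, is to drop the transpose for attn, i.e. replace attn.transpose(0, 1).chunk(1 + self.ngram, 1) with attn.chunk(1 + self.ngram, 1), since attn is already batch-first. A shape-only sketch of that version, with a hypothetical split_attn helper and made-up sizes:

```python
def chunk_shapes(shape, chunks, dim):
    """Shapes from torch.chunk: ceil(size / chunks) per piece, last may be smaller."""
    size = shape[dim]
    assert size > 0
    step = (size + chunks - 1) // chunks
    return [shape[:dim] + (min(step, size - start),) + shape[dim + 1:]
            for start in range(0, size, step)]


def split_attn(attn_shape, ngram):
    """Hypothetical helper: split attn of shape (B, (1+ngram)*T, S) per stream.

    attn is already batch-first, so no transpose is applied; chunking
    dim 1 yields 1+ngram pieces of shape (B, T, S).
    """
    return chunk_shapes(attn_shape, 1 + ngram, 1)


ngram, T, B, S = 2, 4, 2, 7
print(split_attn((B, (1 + ngram) * T, S), ngram))
# [(2, 4, 7), (2, 4, 7), (2, 4, 7)]
```

Each piece then has shape (B, T, src_len), lining up one-to-one with the [B, T, C] pieces of x_list.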

I hope this PR can be merged.

ghost commented 2 years ago

CLA assistant check
All CLA requirements met.