moon-hotel / TransformerTranslation

A Transformer Framework Based Translation Task

Possible bug in key_padding_mask handling? #6

Closed weizhenhuan closed 1 year ago

weizhenhuan commented 1 year ago

In multi_head_attention_forward() in MyTransformer.py, isn't this line wrong?

attn_output_weights = attn_output_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

Shouldn't the positions to be filled be where key_padding_mask.unsqueeze(1).unsqueeze(2) == 0? It looks like an `== 0` is missing.
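For context, whether the `== 0` is needed depends on the mask convention; the sketch below is a minimal standalone illustration (not this repo's code) comparing the two common conventions: PyTorch's nn.MultiheadAttention-style boolean mask where True marks padding, versus a 0/1 mask where 0 marks padding and therefore requires the `== 0` comparison.

```python
import torch

# Shapes follow the usual multi-head attention layout:
# attn_output_weights: (batch, num_heads, tgt_len, src_len)
# key_padding_mask:    (batch, src_len)
batch, num_heads, tgt_len, src_len = 2, 1, 3, 4
attn_output_weights = torch.zeros(batch, num_heads, tgt_len, src_len)

# Convention A (PyTorch nn.MultiheadAttention): True marks PADDING positions,
# so the boolean mask is used directly in masked_fill.
key_padding_mask_a = torch.tensor([[False, False, True, True],
                                   [False, False, False, True]])
masked_a = attn_output_weights.masked_fill(
    key_padding_mask_a.unsqueeze(1).unsqueeze(2), float('-inf'))

# Convention B: 1 marks VALID tokens and 0 marks padding,
# so the "== 0" comparison is needed to select the padding positions.
key_padding_mask_b = torch.tensor([[1, 1, 0, 0],
                                   [1, 1, 1, 0]])
masked_b = attn_output_weights.masked_fill(
    key_padding_mask_b.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))

# Both variants fill the same padding positions with -inf.
print(torch.equal(masked_a, masked_b))  # True
```

So the question comes down to which convention the mask produced in this repo follows: if 0 marks padding, the `== 0` would indeed be required.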