Closed weizhenhuan closed 1 year ago
MyTransformer.py的 multi_head_attention_forward()里,这里是不是错了? attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')) #
应该填充的位置是key_padding_mask.unsqueeze(1).unsqueeze(2) == 0的地方吧。 少了一个 == 0
MyTransformer.py的 multi_head_attention_forward()里,这里是不是错了? attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')) #
应该填充的位置是key_padding_mask.unsqueeze(1).unsqueeze(2) == 0的地方吧。 少了一个 == 0