Closed: andrewcchang closed this issue 3 years ago
Shape of attn_weights: Correct.
The softmax in the implementation is correct: in normal self-attention the softmax goes over the length of the sequence; here it goes over the rows. See how that dimension gets summed away in L222 -- that's the dimension you want normalized to sum to 1.
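To make this concrete, here is a minimal PyTorch sketch of the axis bookkeeping (toy shapes, and an illustrative einsum string rather than a copy of the repo's L222): the last axis of `attn_weights` is the one the output einsum contracts away, so that is the axis `softmax(-1)` must normalize.

```python
import torch

H, C, B, R, D = 2, 5, 1, 4, 8                # heads, seq_len, batch, MSA rows, head dim

attn_weights = torch.randn(H, C, B, R, R)    # [H, C, B, R, R], as discussed in this thread
v = torch.randn(R, C, B, H, D)               # values laid out as [R, C, B, H, D]

# normalize over the last axis (key rows): each query row's weights sum to 1
attn_probs = attn_weights.softmax(-1)
print(torch.allclose(attn_probs.sum(-1), torch.ones(H, C, B, R)))   # True

# the output einsum contracts exactly that axis ("j"), i.e. it gets summed away
context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
print(context.shape)                         # torch.Size([4, 5, 1, 2, 8]) == [R, C, B, H, D]
```

If the softmax were taken over dim=1 instead, it would normalize across unrelated columns for a fixed pair of rows, which is not the dimension the contraction sums over.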
Hi Tom, thank you so much for your quick response. Sorry that my original title was a bit misleading. I was initially a bit confused about the axis mismatch between padding_mask, attn_weights, and softmax.
If I understand the data preprocessing correctly, we pad sequences to a max_length. Therefore, the padding_mask will be True at the padded positions along the 'C' (seq_len) dimension, which is dim=1. In L216, we replace those padded positions in attn_weights with -10000, so I thought we should also take the softmax over dim=1. But thank you for pointing out that the einsum in L222 sums over the rows; that's why we take softmax(-1).
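A small sketch of that masking step may help later readers. The -10000 fill value is from the discussion above; the exact permute/unsqueeze pattern here is only an assumption about one way to broadcast a [B, R, C] mask onto [H, C, B, R, R], not a quote of L216:

```python
import torch

H, C, B, R = 2, 5, 1, 4                      # heads, seq_len (max_length), batch, MSA rows

# padding_mask is [B, R, C]: True at padded positions along the 'C' (seq_len) axis
padding_mask = torch.zeros(B, R, C, dtype=torch.bool)
padding_mask[:, :, 3:] = True                # pretend the last two columns are padding

attn_weights = torch.randn(H, C, B, R, R)

# broadcast the mask onto attn_weights: [B, R, C] -> [1, C, B, 1, R],
# so the padded positions are masked out along the key-row axis
mask = padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3)
attn_weights = attn_weights.masked_fill(mask, -10000)

# the softmax still runs over the last axis: within each (head, column, batch,
# query-row) slice the weights over key rows sum to 1
attn_probs = attn_weights.softmax(-1)
print(torch.allclose(attn_probs.sum(-1), torch.ones(H, C, B, R)))   # True
```

So the mask is defined along 'C', but it is applied by broadcasting into the attention map, while the normalization happens over the trailing row axis that the output einsum contracts.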
Thank you!
yw! And thanks for looking into the details here!
https://github.com/facebookresearch/esm/blob/5680ba75c6792694d4e12d77f01f1a0c9ee482c8/esm/axial_attention.py#L205-L225
If I understand the einsum correctly, attn_weights is of shape [head_size, seq_len, batch_size, msa_row_size, msa_row_size], or [H, C, B, R, R]. If the above is true, shouldn't we take the softmax over the 'C' axis, i.e. dim=1?
Thank you!
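For completeness, a quick shape check of the layout described in the question (toy tensors; the einsum string is illustrative and only assumed to match the linked code):

```python
import torch

R, C, B, H, D = 4, 5, 1, 2, 8                # MSA rows, seq_len, batch, heads, head dim

q = torch.randn(R, C, B, H, D)
k = torch.randn(R, C, B, H, D)

# contract the head dimension d; i/j index the query/key MSA rows
attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k) / D ** 0.5
print(attn_weights.shape)                    # torch.Size([2, 5, 1, 4, 4]) == [H, C, B, R, R]

# dim=1 is the 'C' (seq_len) axis; the two trailing dims are the row-vs-row
# attention map, and the last of them (key rows) is what softmax(-1) normalizes
```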