facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

MSA Column attention's softmax axis is not the same as the padding_mask #57

Closed andrewcchang closed 3 years ago

andrewcchang commented 3 years ago

https://github.com/facebookresearch/esm/blob/5680ba75c6792694d4e12d77f01f1a0c9ee482c8/esm/axial_attention.py#L205-L225

If I understand the einsum correctly, attn_weights has shape [head_size, seq_len, batch_size, msa_row_size, msa_row_size], i.e. [H, C, B, R, R].

If that is true, shouldn't we take the softmax over the 'C' axis? i.e.

attn_probs = attn_weights.softmax(1)
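
For concreteness, here is a quick shape check with dummy tensors (the sizes are arbitrary, and the einsum string is my reading of the linked lines, not copied verbatim):

```python
import torch

# Arbitrary sizes: H heads, C columns (seq_len), B batch, R MSA rows, D head dim.
H, C, B, R, D = 2, 5, 1, 4, 8
q = torch.randn(R, C, B, H, D)
k = torch.randn(R, C, B, H, D)

attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
print(attn_weights.shape)  # torch.Size([2, 5, 1, 4, 4]) -> [H, C, B, R, R]
```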

Thank you!

tomsercu commented 3 years ago

Shape of attn_weights: Correct.
The softmax is correct in the implementation: in normal self-attention the softmax goes over the length of the sequence, and here it goes over the rows. See how that dimension gets summed away in L222 -- that's the dimension you want normalized to sum to 1.
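
As a sanity check, a minimal sketch (dummy sizes; the value-mixing einsum string is written from my reading of the linked code, not guaranteed to match it verbatim) showing that the softmax axis is exactly the axis that gets contracted away:

```python
import torch

H, C, B, R, D = 2, 5, 1, 4, 8          # dummy sizes: heads, columns, batch, MSA rows, head dim
attn_weights = torch.randn(H, C, B, R, R)
v = torch.randn(R, C, B, H, D)

attn_probs = attn_weights.softmax(-1)  # normalize over the key-row axis j

# Each query row's weights over the R key rows now sum to 1 ...
assert torch.allclose(attn_probs.sum(-1), torch.ones(H, C, B, R))

# ... and that same axis j is exactly the one the value-mixing einsum contracts:
context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
print(context.shape)  # torch.Size([4, 5, 1, 2, 8]) -> [R, C, B, H, D]
```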

andrewcchang commented 3 years ago

Hi Tom, thank you so much for the quick response. Sorry that my original title was a bit misleading; I was initially confused by the apparent axis mismatch between the padding_mask, attn_weights, and the softmax.

If I understand the data preprocessing correctly, we pad sequences to a max_length, so the padding_mask is True at the padded positions along the 'C' (seq_len) dimension, which is dim=1 of attn_weights. In L216 those padded positions in attn_weights are filled with -10000, so I thought we should also take the softmax over dim=1. Thank you for pointing out that the einsum in L222 sums over the rows; that's why the softmax is over the last axis, softmax(-1).
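
For anyone hitting the same confusion, a small sketch of the shape bookkeeping (a [B, R, C] padding mask and the permute/unsqueeze below are my reading of the snippet, not the repo code verbatim):

```python
import torch

H, C, B, R = 2, 5, 1, 4
attn_weights = torch.randn(H, C, B, R, R)

# padding_mask: True at padded tokens, one flag per (batch, row, column) position.
padding_mask = torch.zeros(B, R, C, dtype=torch.bool)
padding_mask[:, -1, :] = True          # pretend the last MSA row is all padding

# Rearranged to broadcast over [H, C, B, i, j]: the mask lines up with the
# *key-row* axis j (the last one), not only with the column axis C.
mask = padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3)  # [1, C, B, 1, R]
masked = attn_weights.masked_fill(mask, -10000)

# Masking and softmax therefore act on the same axis: after softmax(-1) the
# padded key row gets (near-)zero probability and the remaining rows renormalize.
attn_probs = masked.softmax(-1)
print(attn_probs[..., -1].max())  # ~0
```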

Thank you!

tomsercu commented 3 years ago

yw! And thanks for looking into the details here!