facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

[Feature] MSA Transformer implementation of self_attn_mask #579

Open rmaguire31 opened 1 year ago

rmaguire31 commented 1 year ago

I have noticed that computing embeddings or zero-shot scores for multiple variants against the same MSA is very inefficient, because attention over the wild-type and homolog sequences in the alignment has to be recomputed for every variant.
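A rough back-of-the-envelope count makes the redundancy concrete. This is a hedged sketch with illustrative assumptions: each separate pass contains one variant plus the same context (wild type and homologs), only sequence-to-sequence attention scores in one column of one layer and head are counted, and the sizes 256 and 100 are arbitrary. The function name `attention_pairs` is hypothetical.

```python
def attention_pairs(n_context: int, n_variants: int) -> dict:
    """Count sequence-to-sequence attention scores per column, per layer, per head.

    Assumption: n_context counts the wild type plus homologs, and each variant
    is added on top of that shared context.
    """
    # One forward pass per variant: each pass attends over (n_context + 1) sequences.
    separate = n_variants * (n_context + 1) ** 2
    # One batched pass with a dense mask still computes every score, masked or not.
    batched_dense = (n_context + n_variants) ** 2
    # Scores actually allowed by the mask: context-to-context, plus each variant
    # attending to itself and to every context sequence.
    batched_needed = n_context ** 2 + n_variants * (n_context + 1)
    return {
        "separate": separate,
        "batched_dense": batched_dense,
        "batched_needed": batched_needed,
    }


print(attention_pairs(n_context=256, n_variants=100))
# {'separate': 6604900, 'batched_dense': 126736, 'batched_needed': 91236}
```

Even a dense masked implementation, which still computes the masked-out scores, avoids re-attending over the shared context once per variant.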

I would like to specify an attention mask over the rows (sequences) of the MSA, so that N variants can be embedded in parallel: information should propagate only from the MSA context embeddings to each variant's embeddings, never from a variant back into the context embeddings or between variants. This would reduce the number of attention computations required and increase inference throughput. This setup seems conceptually similar to causal self-attention, which masks out entries of `attn_probs`, so I believe I can do this with a row-wise `self_attn_mask` of the following form:

\begin{pmatrix}
    1 & 0 & 0 & \dots & 1 & 1 & 1 & \dots \\
    0 & 1 & 0 & \dots & 1 & 1 & 1 & \dots \\
    0 & 0 & 1 & \dots & 1 & 1 & 1 & \dots \\
    \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots \\
    0 & 0 & 0 & \dots & 1 & 1 & 1 & \dots \\
    0 & 0 & 0 & \dots & 1 & 1 & 1 & \dots \\
    0 & 0 & 0 & \dots & 1 & 1 & 1 & \dots \\
    \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots
\end{pmatrix}
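Read in block form, assuming (this is our reading, since the matrix does not label its blocks) that the first $N$ rows and columns are the variants, the remaining $C$ are the MSA context sequences (wild type plus homologs), and 1 means "may attend", the mask and its effect on the attention probabilities are:

```math
M = \begin{pmatrix}
    I_N & \mathbf{1}_{N \times C} \\
    \mathbf{0}_{C \times N} & \mathbf{1}_{C \times C}
\end{pmatrix},
\qquad
A_{ij} = \frac{M_{ij}\, \exp(s_{ij})}{\sum_{k} M_{ik}\, \exp(s_{ik})},
\qquad
s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d}}
```

Every row of $M$ keeps at least one allowed entry, so the softmax stays well defined. Setting $s_{ij} = -\infty$ wherever $M_{ij} = 0$ gives the same $A$, which is how additive attention masks are usually written.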

However, the MSA Transformer, as currently implemented, does not support `self_attn_mask`. Is this something that could be added, or is the MSA Transformer no longer in active development?
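A minimal NumPy sketch of the requested semantics follows. It is not ESM's API: `variant_context_mask` and `masked_attention` are hypothetical helpers, the variants-first ordering is an assumption, and queries, keys, and values are taken directly from the embeddings (no projections) for brevity. It applies the mask to single-head attention across the sequences of one MSA column. One caveat: in ESM's MSA Transformer, mixing between sequences happens in the column attention, while the row attention is tied (its map is computed from all rows together). A mask like this would therefore likely need to be respected in both places to fully stop variants from influencing the context.

```python
import numpy as np


def variant_context_mask(n_variants: int, n_context: int) -> np.ndarray:
    """Boolean mask over MSA rows: True means query row i may attend to key row j.

    Assumed ordering: [variants..., context...], as in the matrix in the issue.
    """
    size = n_variants + n_context
    mask = np.zeros((size, size), dtype=bool)
    # Each variant may attend only to itself among the variants.
    mask[:n_variants, :n_variants] = np.eye(n_variants, dtype=bool)
    # Every row (variant or context) may attend to all context rows.
    mask[:, n_variants:] = True
    return mask


def masked_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                     mask: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention across rows, with a boolean mask."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # disallowed pairs get zero probability
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v


# Usage: embeddings for one column, one row per sequence (random, illustrative only).
rng = np.random.default_rng(0)
n_variants, n_context, dim = 3, 5, 8
x = rng.normal(size=(n_variants + n_context, dim))
mask = variant_context_mask(n_variants, n_context)
out = masked_attention(x, x, x, mask)

# Perturbing the variant rows leaves the context outputs unchanged.
x_perturbed = x.copy()
x_perturbed[:n_variants] += 1.0
out_perturbed = masked_attention(x_perturbed, x_perturbed, x_perturbed, mask)
assert np.allclose(out[n_variants:], out_perturbed[n_variants:])
```

Because context rows only attend to context rows, their outputs match what a context-only pass would produce, which is what makes batching the N variants safe under this mask.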