dauparas / ProteinMPNN

Code for the ProteinMPNN paper

About the variable "order_mask_backward" in protein_mpnn_utils.py, line 1085 #34

Open phonez opened 1 year ago

phonez commented 1 year ago

Hi, I'm learning your ProteinMPNN framework. While going through your script protein_mpnn_utils.py, I was confused by the variable order_mask_backward, which is defined in line 1085 and used in line 1086:

order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) 

As I understand it (based on line 1086, i.e., the second line above), order_mask_backward should be a tensor that records, for each residue, which residues are decoded before it in the reverse order (a value of 1 meaning those residues can be seen). In that case, the indices along dim=-1 and dim=-2 both represent residue positions, so that order_mask_backward can be gathered by E_idx (E_idx is a tensor that records, for each residue, which residues are recognized as its neighbors, with shape [num_batch, num_residues, num_neighbors]).
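To make the indexing concrete, here is a toy illustration of the gather in line 1086 (the mask values and neighbor indices below are made up, not from the repo): torch.gather along dim=2 looks up order_mask_backward[b, q, E_idx[b, q, n]], so the values stored in E_idx index the last dimension of order_mask_backward, which therefore has to be a residue position.

import torch

order_mask = torch.tensor([[[0., 1., 0., 0.],
                            [0., 0., 0., 0.],
                            [1., 1., 0., 1.],
                            [1., 1., 0., 0.]]])           # [1, num_residues, num_residues]
E_idx = torch.tensor([[[1, 2], [0, 2], [1, 3], [2, 0]]])  # [1, num_residues, k], made-up neighbors
mask_attend = torch.gather(order_mask, 2, E_idx).unsqueeze(-1)  # [1, num_residues, k, 1]
print(mask_attend.squeeze(-1))
# tensor([[[1., 0.],
#          [0., 0.],
#          [1., 1.],
#          [0., 1.]]])
# row q lists, for each neighbor of residue q, whether that neighbor is visible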

However, to my understanding, the order_mask_backward defined in the first line above records, for each pair of decoding-order values (q, p), whether there exists a residue pair (i, j) with i > j such that residue i has decoding order q and residue j has decoding order p. If such a pair exists, the value is 1, else 0. Here, q and p are the indices along dim=-2 and dim=-1 of the tensor order_mask_backward, while i and j are residue positions in the sequence.

To clarify, consider the simple example below.

import torch
import torch.nn.functional as F

num = 4 # num of residues
a = torch.Tensor([2,3,0,1]).long() # random decoding order, i.e., a[position_of_residue] = value of decoding order

one_hot_a = F.one_hot(a, num_classes=num).float()
one_hot_a = one_hot_a.unsqueeze(0)
result = torch.einsum('ij, biq, bjp->bqp', (1-torch.triu(torch.ones(num, num))), one_hot_a, one_hot_a) #  given by line 1085
result
tensor([[[0., 0., 1., 1.],
         [1., 0., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 1., 0.]]])
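As a sanity check (my own loops, not from the repo), the same tensor can be rebuilt directly from this interpretation, continuing with the variables defined above:

expected = torch.zeros(num, num)
for i in range(num):
    for j in range(num):
        if i > j:                       # residue pair with i after j in the chain
            expected[a[i], a[j]] = 1.0  # indexed by their decoding-order values
assert torch.equal(result[0], expected)  # passes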

For example, result[0][0][2] = 1, meaning that there exists a residue pair (i, j) with i > j whose decoding orders are (0, 2). In fact, residue 2 and residue 0 constitute such a pair. This example suggests that my understanding of the order_mask_backward given by line 1085 may be right. However, in that case, order_mask_backward does not go along with E_idx in line 1086, because the indices along dim=-2 and dim=-1 do not represent residue positions.

Modifying the subscripts in torch.einsum() as follows may solve that problem.

torch.einsum('ji, bqi, bpj->bqp', 1-torch.triu(torch.ones(num, num)), one_hot_a, one_hot_a)
tensor([[[0., 1., 0., 0.],
         [0., 0., 0., 0.],
         [1., 1., 0., 1.],
         [1., 1., 0., 0.]]])

In the above result, result[0][0][1] = 1, meaning that residue 1 is decoded before residue 0 in the reverse order (their decoding-order values are 3 and 2, respectively), so residue 1 can be seen when decoding residue 0.
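To check (again my own sketch, not from the repo) that this modified equation really is indexed by residue position, its entry [q][p] should be 1 exactly when a[p] > a[q], i.e. when residue p is decoded before residue q in the backward order. Continuing from the example above:

result2 = torch.einsum('ji, bqi, bpj->bqp',
                       1 - torch.triu(torch.ones(num, num)),
                       one_hot_a, one_hot_a)
expected2 = (a.unsqueeze(0) > a.unsqueeze(1)).float()  # expected2[q, p] = 1 iff a[p] > a[q]
assert torch.equal(result2[0], expected2)  # passes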

I'm not sure whether my understanding is correct, nor whether this has an influence on the model training results.

ak422 commented 1 year ago

And for backward decoding it should be the nodes with j <= i, so the code should be:

order_mask_backward = torch.einsum('ji, bqi, bpj->bqp', (1 - torch.triu(torch.ones(mask_size, mask_size, device=device), diagonal=1)), permutation_matrix_reverse, permutation_matrix_reverse)
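For reference, a quick sketch (not from the repo) of what diagonal=1 changes: with the default torch.triu, 1 - triu keeps only the entries with j < i, while with diagonal=1 it also keeps the diagonal j == i, so each residue additionally sees its own position.

import torch

n = 4
strict = 1 - torch.triu(torch.ones(n, n))                 # j < i only
with_diag = 1 - torch.triu(torch.ones(n, n), diagonal=1)  # j <= i, diagonal included
print(with_diag - strict)  # identity matrix: the two differ only in self-visibility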

MaoSihong commented 1 year ago

(Quotes phonez's original post above.)

In my view, result[0][0], which equals tensor([0., 0., 1., 1.]), indicates that the residue types at indices 2 and 3 are known while decoding at position 0.
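One caveat (my own observation): this position-based reading matches the example only because the toy decoding order a = [2, 3, 0, 1] happens to be its own inverse, so indexing by decoding-order value and indexing by residue position coincide here. A quick check:

import torch

a = torch.tensor([2, 3, 0, 1])
print(torch.equal(a[a], torch.arange(4)))  # True: a[a[i]] == i, i.e. a is an involution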