kexinhuang12345 / MolTrans

MolTrans: Molecular Interaction Transformer for Drug Target Interaction Prediction (Bioinformatics)
https://academic.oup.com/bioinformatics/advance-article/doi/10.1093/bioinformatics/btaa880/5929692
BSD 3-Clause "New" or "Revised" License

About computing pairwise interaction #23

Open Tigerrr07 opened 1 year ago

Tigerrr07 commented 1 year ago

Hi Kexin, I have the following questions.

  1. I find the code for computing the pairwise interaction a little complicated. Since you are using a dot product, can I use torch.matmul(d_encoded_layers, p_encoded_layers.transpose(-1, -2)) directly instead of the following code?

https://github.com/kexinhuang12345/MolTrans/blob/47ac16b8c158b080ba6cdaec74cd7aa9c1332b73/models.py#L86-L100

Besides, the view operation on line 96 of the above code confuses me a lot. I tested it with a simple example, and it did not compute the dot product between sub-structural pairs.

import torch

max_d = 2
max_p = 3

# batch_size = 1, hidden dim = 2
d_encoded_layers = torch.zeros(1, max_d, 2)
d_encoded_layers[0, 0, 0] = 1 
d_encoded_layers[0, 0, 1] = 1
p_encoded_layers = torch.zeros(1, max_p, 2)
p_encoded_layers[0, 0, 0] = 1
p_encoded_layers[0, 0, 1] = 2
p_encoded_layers[0, 1, 0] = 3
p_encoded_layers[0, 1, 1] = 4 
p_encoded_layers[0, 2, 0] = 5
p_encoded_layers[0, 2, 1] = 6

print(d_encoded_layers)
print(p_encoded_layers)

d_aug = torch.unsqueeze(d_encoded_layers, 2).repeat(1, 1, max_p, 1) # repeat along protein size
p_aug = torch.unsqueeze(p_encoded_layers, 1).repeat(1, max_d, 1, 1) # repeat along drug size
i = d_aug * p_aug
print(i)
i_v = i.view(1, -1, max_d, max_p) 
print(i_v)
i_v = torch.sum(i_v, dim = 1)
print(i_v)

output:

tensor([[[1., 1.],
         [0., 0.]]])
tensor([[[1., 2.],
         [3., 4.],
         [5., 6.]]])
tensor([[[[1., 2.],
          [3., 4.],
          [5., 6.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.]]]])
tensor([[[[1., 2., 3.],
          [4., 5., 6.]],

         [[0., 0., 0.],
          [0., 0., 0.]]]])
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])

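The final i_v looks meaningless: the representation of drug sub1 is all zeros, so the second row of the map should be all zeros, but it is not. I think the view operation changes the arrangement of the data. It only reinterprets the memory of the (batch, max_d, max_p, hidden) tensor and does not move the hidden dimension to dim 1, so the sum runs over the wrong elements. Maybe the following code is more correct?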

i_s = torch.sum(i, dim=-1)
print(i_s) 

output:

tensor([[[ 3.,  7., 11.],
         [ 0.,  0.,  0.]]])
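A minimal sketch comparing three ways of forming the map, assuming PyTorch and encoder outputs shaped (batch, max_len, hidden): the view-then-sum pattern from the linked code, a `permute` fix of that pattern (equivalent to `i.sum(dim=-1)`, i.e. `i_s`), and the `torch.matmul` form from question 1. The helper names `interaction_via_*` are hypothetical, not part of MolTrans. On the toy example above, the last two agree with `i_s`, while the view version reproduces the unexpected `i_v`.

```python
import torch


def interaction_via_view(d_enc, p_enc):
    """View-then-sum pattern from the linked models.py, kept for comparison."""
    batch, max_d, _ = d_enc.shape
    max_p = p_enc.shape[1]
    d_aug = d_enc.unsqueeze(2).repeat(1, 1, max_p, 1)  # (batch, max_d, max_p, hidden)
    p_aug = p_enc.unsqueeze(1).repeat(1, max_d, 1, 1)  # (batch, max_d, max_p, hidden)
    i = d_aug * p_aug
    # view() only reinterprets contiguous memory, so dim 1 here is NOT the hidden dim
    return i.view(batch, -1, max_d, max_p).sum(dim=1)


def interaction_via_permute(d_enc, p_enc):
    """Same element-wise product, but moves the hidden dim to position 1 before summing."""
    # Broadcasting (batch, max_d, 1, hidden) * (batch, 1, max_p, hidden) replaces repeat()
    i = d_enc.unsqueeze(2) * p_enc.unsqueeze(1)
    return i.permute(0, 3, 1, 2).sum(dim=1)  # (batch, max_d, max_p)


def interaction_via_matmul(d_enc, p_enc):
    """Batched dot product between every drug and every protein substructure."""
    return torch.matmul(d_enc, p_enc.transpose(-1, -2))  # (batch, max_d, max_p)


# Toy example from this issue: batch 1, max_d = 2, max_p = 3, hidden dim 2
d = torch.tensor([[[1., 1.], [0., 0.]]])
p = torch.tensor([[[1., 2.], [3., 4.], [5., 6.]]])

print(interaction_via_view(d, p).tolist())     # [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]
print(interaction_via_permute(d, p).tolist())  # [[[3.0, 7.0, 11.0], [0.0, 0.0, 0.0]]]
print(interaction_via_matmul(d, p).tolist())   # [[[3.0, 7.0, 11.0], [0.0, 0.0, 0.0]]]
```

Besides being shorter, the `matmul` version never materializes the (batch, max_d, max_p, hidden) intermediate tensor, which saves memory for long protein sequences.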

  2. I think padding tokens should be filtered out of the interaction map $I$ before it is fed into the CNN. I do this by passing in d_mask and p_mask:

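  # i: interaction map of shape (batch, max_d, max_p), after summing over the hidden dim
  d_mask = d_mask.reshape(-1, self.max_d, 1).bool()  # masked_fill_ needs a bool mask
  p_mask = p_mask.reshape(-1, 1, self.max_p).bool()
  # mask padding tokens
  i.masked_fill_(~d_mask, 0)
  i.masked_fill_(~p_mask, 0)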
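A small standalone sketch of this masking step, assuming the map already has shape (batch, max_d, max_p) and that the masks hold 1 for real tokens and 0 for padding. If the masks are integer tensors, `~` is a bitwise NOT (giving -2/-1 rather than False/True) and `masked_fill_` rejects non-bool masks, so they are converted with `.bool()` first. The function name `mask_interaction_map` is hypothetical.

```python
import torch


def mask_interaction_map(i, d_mask, p_mask):
    """Zero the rows of padded drug tokens and the columns of padded protein tokens.

    Assumes i is (batch, max_d, max_p) and the masks are 1 = real token, 0 = padding.
    """
    batch, max_d, max_p = i.shape
    d_keep = d_mask.reshape(batch, max_d, 1).bool()  # broadcasts over the protein axis
    p_keep = p_mask.reshape(batch, 1, max_p).bool()  # broadcasts over the drug axis
    # Out-of-place masked_fill leaves the input map unchanged
    return i.masked_fill(~d_keep, 0).masked_fill(~p_keep, 0)


i = torch.ones(1, 3, 4)
d_mask = torch.tensor([[1, 1, 0]])     # last drug token is padding
p_mask = torch.tensor([[1, 1, 1, 0]])  # last protein token is padding
print(mask_interaction_map(i, d_mask, p_mask).tolist())
# [[[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
```

An equivalent one-liner is `i * (d_keep & p_keep)`; `masked_fill` is kept here only to mirror the snippet above.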

Sorry to bother you.