awslabs / dgl-lifesci

Python package for graph neural networks in chemistry and biology
Apache License 2.0

ACNN not working with dgl batch #117

Open manangoel99 opened 3 years ago

manangoel99 commented 3 years ago

To reproduce:

from dgllife.model.model_zoo.acnn import ACNN
import dgl
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from dgllife.utils import ACNN_graph_construction_and_featurization
model = ACNN()

protein = Chem.MolFromPDBFile("./6LU7.pdb")
protein_pos = torch.Tensor(protein.GetConformer().GetPositions())

ligand1 = Chem.MolFromSmiles("O=C(CC(c1ccccc1)c1ccccc1)N1CCN(S(=O)(=O)c2ccccc2[N+](=O)[O-])CC1")
AllChem.EmbedMolecule(ligand1)

ligand2 = Chem.MolFromSmiles("Cc1cc(C(=O)Nc2ccc(OCC(N)=O)cc2)c(C)n1C1CC1")
AllChem.EmbedMolecule(ligand2)

pos1 = torch.Tensor(ligand1.GetConformer().GetPositions())
pos2 = torch.Tensor(ligand2.GetConformer().GetPositions())

g1 = ACNN_graph_construction_and_featurization(ligand1, protein, pos1, protein_pos)
g2 = ACNN_graph_construction_and_featurization(ligand2, protein, pos2, protein_pos)
print(g1, g2)
batch = dgl.graph([g1, g2])

This throws the following error:

Traceback (most recent call last):
  File "ACNN.py", line 24, in <module>
    batch = dgl.graph([g1, g2])
  File "/home/manan/miniconda3/lib/python3.8/site-packages/dgl/convert.py", line 151, in graph
    u, v, urange, vrange = utils.graphdata2tensors(data, idtype)
  File "/home/manan/miniconda3/lib/python3.8/site-packages/dgl/utils/data.py", line 169, in graphdata2tensors
    src, dst = elist2tensor(data, idtype)
  File "/home/manan/miniconda3/lib/python3.8/site-packages/dgl/utils/data.py", line 28, in elist2tensor
    u, v = zip(*elist)
  File "/home/manan/miniconda3/lib/python3.8/site-packages/dgl/heterograph.py", line 1968, in __getitem__
    raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
dgl._ffi.base.DGLError: Invalid key "0". Must be one of the edge types.
manangoel99 commented 3 years ago

Also, do pos1 and pos2 have to be the positions of the ligand after docking, i.e. the docked pose?

mufeili commented 3 years ago

Which versions of dgl and dgllife are you using? I tried your code snippet on a different protein PDB file, and it seems to work fine.

Also do pos1 and pos2 have to be the positions of the ligand after docking i.e. the docked pose?

If you want to use the model for docking, then yes.

manangoel99 commented 3 years ago

The dgl version is 0.5.3 and the dgllife version is 0.2.6. I tried using 4WTG instead, and that also threw the same error.

mufeili commented 3 years ago

Why did you do the following?

batch = dgl.graph([g1, g2])

If you want to batch two graphs, you need to do:

batch = dgl.batch([g1, g2])
manangoel99 commented 3 years ago

Apologies, I posted the wrong snippet. Here is the correct one:

from dgllife.model.model_zoo.acnn import ACNN
import dgl
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from dgllife.utils import ACNN_graph_construction_and_featurization
model = ACNN()

protein = Chem.MolFromPDBFile("./4wtg.pdb")
protein_pos = torch.Tensor(protein.GetConformer().GetPositions())

ligand1 = Chem.MolFromSmiles("O=C(CC(c1ccccc1)c1ccccc1)N1CCN(S(=O)(=O)c2ccccc2[N+](=O)[O-])CC1")
AllChem.EmbedMolecule(ligand1)

ligand2 = Chem.MolFromSmiles("Cc1cc(C(=O)Nc2ccc(OCC(N)=O)cc2)c(C)n1C1CC1")
AllChem.EmbedMolecule(ligand2)

pos1 = torch.Tensor(ligand1.GetConformer().GetPositions())
pos2 = torch.Tensor(ligand2.GetConformer().GetPositions())

g1 = ACNN_graph_construction_and_featurization(ligand1, protein, pos1, protein_pos)
g2 = ACNN_graph_construction_and_featurization(ligand2, protein, pos2, protein_pos)
batch = dgl.batch([g1, g2])

print(model(batch))

When I use dgl.batch, I get the following error:

Traceback (most recent call last):
  File "ACNN.py", line 25, in <module>
    print(model(batch))
  File "/home/manan/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/manan/miniconda3/lib/python3.8/site-packages/dgllife-0.2.6-py3.8.egg/dgllife/model/model_zoo/acnn.py", line 250, in forward
  File "/home/manan/miniconda3/lib/python3.8/site-packages/dgl/view.py", line 182, in __getitem__
    return self._graph._get_e_repr(self._etid, self._edges)[key]
KeyError: 'distance'

I tracked it down to https://github.com/awslabs/dgl-lifesci/blob/30aa61c8fd1a3cd23368da82af8631fb1dc60fcb/python/dgllife/model/model_zoo/acnn.py#L246. Something happens to the edges of the graph on that line that causes the KeyError: before it, the key distance exists in the graph.

mufeili commented 3 years ago

Can you share 4wtg.pdb so I can reproduce the issue?

manangoel99 commented 3 years ago

4WTG is just a structure I picked at random: https://www.rcsb.org/structure/4WTG. GitHub doesn't allow me to attach a PDB file.

mufeili commented 3 years ago

I don't think the code snippet passes the input to the model. Can you also include those lines of code?

manangoel99 commented 3 years ago

It's the last line of the code snippet:

print(model(batch))

batch is the input to model, which is an instance of ACNN.

mufeili commented 3 years ago

I've figured out what's going on.

  1. There's a chance that there are no edges for some edge types due to the high threshold. In that case, ACNN_graph_construction_and_featurization still constructs a placeholder for the edge type and the associated features.
  2. dgl.batch skips edge types without edges when concatenating features. As a result, in your case two edge types in batch do not have the feature distance.
  3. dgl.to_homogeneous copies a node/edge feature only when the feature exists for all nodes/edges.
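The three points above can be illustrated with a simplified pure-Python model. This is not DGL's implementation: toy graphs are plain dicts, and the edge type names ligand and ligand_protein are hypothetical. The model only mimics the behaviour described above, with dgl.batch skipping edge types that have no edges and dgl.to_homogeneous keeping a feature only when every edge type has it.

```python
from typing import Dict, List

# Toy representation (hypothetical, not DGL's data structures):
# edge type -> {"num_edges": int, "feats": {feature name: list of per-edge values}}
ToyGraph = Dict[str, dict]

def toy_graph(n_ligand_edges: int) -> ToyGraph:
    return {
        "ligand": {"num_edges": n_ligand_edges,
                   "feats": {"distance": [1.0] * n_ligand_edges}},
        # Point 1: no edges for this type, but a placeholder feature still exists.
        "ligand_protein": {"num_edges": 0, "feats": {"distance": []}},
    }

def toy_batch(graphs: List[ToyGraph]) -> ToyGraph:
    """Point 2: edge types without edges are skipped during feature concatenation."""
    batched: ToyGraph = {}
    for etype in graphs[0]:
        total = sum(g[etype]["num_edges"] for g in graphs)
        non_empty = [g[etype] for g in graphs if g[etype]["num_edges"] > 0]
        feats = {}
        if non_empty:
            # Assumes all non-empty parts store the same feature names.
            for name in non_empty[0]["feats"]:
                feats[name] = [v for part in non_empty for v in part["feats"][name]]
        batched[etype] = {"num_edges": total, "feats": feats}
    return batched

def toy_to_homogeneous_edata(g: ToyGraph) -> List[str]:
    """Point 3: a feature is kept only if every edge type has it."""
    names = [set(part["feats"]) for part in g.values()]
    return sorted(set.intersection(*names)) if names else []

bg = toy_batch([toy_graph(2), toy_graph(3)])
print(bg["ligand_protein"]["feats"])           # {}
print(toy_to_homogeneous_edata(bg))            # []
print(toy_to_homogeneous_edata(toy_graph(2)))  # ['distance']
```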

We may want to support feature concatenation for edge types with no edges in dgl.batch, but that may take longer. For the time being, one option is to add a manual check in ACNN before invoking dgl.to_homogeneous that adds placeholders for the features if necessary. What do you think?