DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

Default node dimension changed from 66 to 67? #191

Open jasperhyp opened 1 year ago

jasperhyp commented 1 year ago

Hi! It looks like the default node featurizer now encodes molecules so that each node is of dimension 67. For compatibility purposes, could you kindly let me know which commit changed the default node dimension from 66 to 67 using Molecule or PackedMolecule? Or could you point me to the specific index in the node feature vector that I should remove in order to stay consistent with the previous default? Thanks!
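
For reference, here is the minimal check I am using to see which dimension a given install produces (just a sketch):

from torchdrug import data

# Featurize a single carbon with the default featurizer and inspect
# the node feature dimension (66 vs. 67 in this issue).
mol = data.Molecule.from_smiles("C")
print(mol.node_feature.shape[1])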

Oxer11 commented 1 year ago

Hi, could you please provide the two versions of torchdrug that you find the default node dimensions are different? This will help us debug. Thanks~

BTW, there is a small modification to the atom features in 36832c6 that removes some time-consuming atom features. But that change can only reduce the number of features.

jasperhyp commented 1 year ago

I don't remember exactly which version it was, but I created a PackedMolecule object at the start of this year, where the atom feature dim is 66. I then used the newest torchdrug to pack some other molecules and pretrain a model, and when finetuning that model on the earlier molecules I got this error:

Traceback (most recent call last):
  File "train_torchdrug.py", line 331, in <module>
    preds = model(mol_graphs, labels)
  File "/home/.conda/envs/primekg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "train_torchdrug.py", line 265, in forward
    z_mols = self.mol_encoder(mols, mols.node_feature.float())
  File "/home/.conda/envs/primekg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.conda/envs/primekg/lib/python3.8/site-packages/torchdrug-0.2.0-py3.8.egg/torchdrug/models/gin.py", line 76, in forward
    hidden = layer(graph, layer_input)
  File "/home/.conda/envs/primekg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.conda/envs/primekg/lib/python3.8/site-packages/torchdrug-0.2.0-py3.8.egg/torchdrug/layers/conv.py", line 91, in forward
    update = self.message_and_aggregate(graph, input)
  File "/home/.conda/envs/primekg/lib/python3.8/site-packages/torchdrug-0.2.0-py3.8.egg/torchdrug/layers/conv.py", line 346, in message_and_aggregate
    update += edge_update
RuntimeError: The size of tensor a (66) must match the size of tensor b (67) at non-singleton dimension 1

which is due to the discrepancy between the atom feature dims. I do remember getting 66 dimensions a while ago.
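
For what it's worth, this is roughly how I am comparing the two on my side (just a sketch; the pickle path and variable names are placeholders for my setup):

import pickle
from torchdrug import data

# Placeholder path for the PackedMolecule saved earlier this year.
with open("old_packed_molecules.pkl", "rb") as fin:
    old_mols = pickle.load(fin)

# Feature dim of the saved molecules vs. the current default featurizer.
old_dim = old_mols.node_feature.shape[1]
new_dim = data.Molecule.from_smiles("C").node_feature.shape[1]
print(old_dim, new_dim)  # 66 vs. 67 in my case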

Oxer11 commented 1 year ago

@KiddoZhu Do you have any idea about this problem?

neurowelt commented 1 year ago

I think I'm experiencing a somewhat similar issue, so maybe this will help out a bit. I started with torchdrug last week, so I haven't worked with any version other than the one I got recently, 0.2.0.post1. I trained a PropertyPrediction task on ClinTox and wanted to see if I could convert a generated molecule from SMILES to data.Molecule and run a prediction on it, but I got a shape mismatch error from torch.

When I checked how the shapes behave for validation set molecules, I noticed the shape changes from 67 to 66 after I do molecule → smiles → molecule.

I did this:

from torch.nn import functional as F
from torchdrug import data

# valid_set and task come from the PropertyPrediction setup on ClinTox above
samples = []
categories = set()

for sample in valid_set:
    # use the label combination to pick one molecule per category
    category = tuple([v for k, v in sample.items() if k != 'graph'])

    if category not in categories:
        categories.add(category)

        # round-trip: Molecule -> SMILES -> Molecule
        print(f"Graph shape before: {sample['graph'].node_feature.shape}")
        sm_graph = sample['graph'].to_smiles()
        sample = data.Molecule.from_smiles(sm_graph)
        print(f"Graph shape after: {sample.node_feature.shape}\n")

        samples.append(sample)

samples = data.graph_collate(samples)
preds = F.sigmoid(task.predict(samples))

And got this output:

Graph shape before: torch.Size([15, 67])
Graph shape after: torch.Size([15, 66])

Graph shape before: torch.Size([58, 67])
Graph shape after: torch.Size([58, 66])

Graph shape before: torch.Size([68, 67])
Graph shape after: torch.Size([68, 66])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File [~/Programming/Python/miniforge3/envs/drugs/lib/python3.8/site-packages/torchdrug/tasks/property_prediction.py:135], in PropertyPrediction.predict(self, batch, all_loss, metric)
    133     graph = self.graph_construction_model(graph)
    134     print(graph)
--> 135 output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
    136 pred = self.mlp(output["graph_feature"])
    137 if self.normalization:
...
File [~/Programming/Python/miniforge3/envs/drugs/lib/python3.8/site-packages/torch/nn/modules/linear.py:114], in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (141x66 and 67x256)

jasperhyp commented 1 year ago

@neurowelt Thanks for the heads up! Interesting, so you got atom dim 67 from the ClinTox dataset but 66 using Molecule.from_smiles()? This is kinda the reverse of what I observed -- I got 67 if doing e.g. td.data.Molecule.from_smiles('C').atom_feature.shape[1]. I built from source and the version is 0.2.0 (commit a959f68).

neurowelt commented 1 year ago

@jasperhyp I built from the same commit and redownloaded the ClinTox dataset; now dataset.node_feature_dim == 66.

In the meantime I did some additional digging in the from_molecule function which pointed me towards these lines: https://github.com/DeepGraphLearning/torchdrug/blob/a959f68f0c19f368be9e380f5a587de6970b3c67/torchdrug/data/molecule.py#L187-L189

If I understand correctly, these functions depend on the task type (property prediction, synthon completion, center identification), but all of them live in feature.py, so if any shape differs, perhaps one of the commits made to that file can explain it (see the quick check at the end of this comment).

On the other hand, after I built the version you suggested, the redownloaded ClinTox shape changed to 66 and I no longer had this issue.

The code below:

import pickle

import rdkit
import torchdrug
from torchdrug import data

print(rdkit.__version__)
print(torchdrug.__version__)

# atom feature dim of a freshly featurized molecule
mol = data.Molecule.from_smiles('C')
print(mol.atom_feature.shape[1])

# atom feature dim of the redownloaded (cached) ClinTox dataset
with open("./ClinTox.pkl", "rb") as fin:
    clintox = pickle.load(fin)
print(clintox.data[0].atom_feature.shape[1])

Outputs:

2021.09.4
0.2.0
66
66

That's what I've got so far; I'll let you know if I find anything more.
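
And here is the quick check mentioned above: to rule out the dataset pipeline entirely, the featurizer in feature.py can be probed directly (a minimal sketch, assuming the default featurizer is feature.atom_default):

from rdkit import Chem
from torchdrug.data import feature

# Probe the default atom featurizer on a single carbon atom; its length
# should match the Molecule node feature dim reported above.
atom = Chem.MolFromSmiles("C").GetAtomWithIdx(0)
print(len(feature.atom_default(atom)))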

jasperhyp commented 1 year ago

Ok, then the problem must be the rdkit version??

import rdkit
import torchdrug
from torchdrug import data

print(rdkit.__version__)
print(torchdrug.__version__)

mol = data.Molecule.from_smiles('C')
print(mol.atom_feature.shape[1])

Outputs:

2022.09.5
0.2.0
67

neurowelt commented 1 year ago

I think we got it 😊 I kind of suspected that. I remember #170; I reinstalled rdkit for that reason, but later on I also remember running pip show rdkit-pypi and getting a 2022.*.* version. I think I just got them mixed up and trained something with one version, then tried opening it with a different one.

Seems like it's an rdkit version problem, not a torchdrug one.
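
Maybe it's worth caching the rdkit version next to any saved molecules, so a mismatch is obvious at load time (just a sketch; the file name and dictionary layout are only placeholders):

import pickle

import rdkit
from torchdrug import data

mols = data.Molecule.from_smiles("C")  # placeholder for whatever actually gets cached
with open("molecules.pkl", "wb") as fout:
    # store the rdkit version alongside the featurized molecules so a later
    # load under a different rdkit can detect the feature-dim mismatch
    pickle.dump({"rdkit_version": rdkit.__version__, "molecules": mols}, fout)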

jasperhyp commented 1 year ago

Interesting, maybe it would be good to have an assertion in the feature generator! @Oxer11
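
Something along these lines on the user side, for example (just a sketch; check_node_feature_dim and the expected value of 66 are only illustrative):

from torchdrug import data

def check_node_feature_dim(expected_dim):
    # Hypothetical guard: featurize a trivial molecule with the currently
    # installed rdkit/torchdrug and compare against the dimension the
    # pretrained model was built with.
    actual_dim = data.Molecule.from_smiles("C").node_feature.shape[1]
    assert actual_dim == expected_dim, (
        f"Default atom features have dim {actual_dim}, but the model "
        f"expects {expected_dim}; check your rdkit version."
    )

check_node_feature_dim(66)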

AH-Merii commented 1 year ago

This means that anyone trying to follow the pretraining and finetuning tutorials won't be able to run them, as the models expect an input dimension of 21. With the newer version of rdkit, using atom_feature="pretrain" yields an input dimension of 22 instead of 21.

AH-Merii commented 1 year ago

So I did a bit of investigating, and the reason I am getting 22 instead of 21 has nothing to do with the version of rdkit. It is because of allow_unknown=True in onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) in the snippet below.

@R.register("features.atom.pretrain")
def atom_pretrain(atom):
    """Atom feature for pretraining.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol

        GetChiralTag(): one-hot embedding for atomic chiral tag
    """
    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
           onehot(atom.GetChiralTag(), chiral_tag_vocab)

When I changed the value to False, I was able to get 21 instead of 22.
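
To make the off-by-one concrete: with allow_unknown=True the one-hot vector gets an extra "unknown" slot, so its length is len(vocab) + 1. A simplified stand-in for onehot (not the exact torchdrug implementation) that illustrates this:

def onehot(x, vocab, allow_unknown=False):
    # allow_unknown reserves one extra slot for out-of-vocabulary values,
    # which is exactly the 21 -> 22 difference seen above
    index = vocab.index(x) if x in vocab else -1
    if index == -1 and not allow_unknown:
        raise ValueError(f"{x} is not in the vocabulary")
    length = len(vocab) + 1 if allow_unknown else len(vocab)
    feature = [0] * length
    feature[index] = 1  # index -1 lands in the extra "unknown" slot
    return feature

atom_vocab = ["C", "N", "O"]  # toy vocabulary, just for illustration
print(len(onehot("C", atom_vocab)))                      # 3
print(len(onehot("C", atom_vocab, allow_unknown=True)))  # 4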