chao1224 / GraphMVP

Pre-training Molecular Graph Representation with 3D Geometry, ICLR'22 (https://openreview.net/forum?id=xQUe1pOKPam)
https://chao1224.github.io/GraphMVP
MIT License

Encode molecules starting from their SMILES #26

Open marcosbodio opened 2 months ago

marcosbodio commented 2 months ago

Hello, I would like to know whether it is possible to use GraphMVP to encode molecules starting from their SMILES strings. I have read this issue, but it does not help much. I would be really grateful if you could provide some explanation and, ideally, an example. Thank you!

chao1224 commented 2 months ago

Hi @marcosbodio,

Thank you for your question.

marcosbodio commented 1 month ago

Hi @chao1224, thank you for your answer. In your paper, Table 5 lists results on the DTA tasks with Davis and KIBA. These datasets contain the SMILES of molecules, so how did you apply GraphMVP (or GraphMVP-G, GraphMVP-C) to them? It would be very useful to see the code, because that would clarify the proper way to use your model starting from a molecule's SMILES.

chao1224 commented 1 month ago

Hi @marcosbodio,

Sure, you can check this Python script; specifically, this line sets which dataset to use.

marcosbodio commented 1 month ago

Hi @chao1224, I have looked at the script you linked above, and I think it is for fine-tuning your model, which I would prefer to avoid.

I was hoping to use a checkpoint of your model, for example `output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.3_EBM_dot_prod_0.1_normalize_l2_detach_target_2_100_0/pretraining_model.pth` in `GraphMVP_simple_features_for_classification.zip` (shared here).

I wonder if I could do something like this:

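```python
import torch
from rdkit import Chem
from rdkit.Chem.rdDistGeom import EmbedMolecule

from src_classification.GEOM_dataset_preparation import mol_to_graph_data_obj_simple_3D

smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)  # add explicit hydrogens before generating 3D coordinates
# Generate one 3D conformer; EmbedMolecule returns -1 when embedding fails
if EmbedMolecule(mol=mol) == -1:
    raise ValueError(f'could not embed a 3D conformer for {smiles}')
data = mol_to_graph_data_obj_simple_3D(mol)  # graph built from the molecule and its conformer
```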

and then feed `data` to the model loaded from the checkpoint to compute an embedding of the molecule. What do you think?
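A GNN encoder returns one vector per atom, so getting a single embedding per molecule usually needs a pooling step. The sketch below shows mean pooling in plain PyTorch. It assumes the encoder output is a node-embedding tensor of shape `[num_atoms, emb_dim]` and that a PyG-style `batch` vector maps each atom to its molecule. Whether mean pooling is the right choice for GraphMVP embeddings is an assumption here, not something stated in this thread; `mean_pool` is a hypothetical helper.

```python
import torch


def mean_pool(node_emb: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Average node embeddings per graph (hypothetical helper).

    node_emb: [num_nodes, emb_dim] float tensor (assumed encoder output shape)
    batch:    [num_nodes] long tensor, batch[i] = graph index of node i
    returns:  [num_graphs, emb_dim] tensor of per-graph mean embeddings
    """
    num_graphs = int(batch.max().item()) + 1
    emb_dim = node_emb.size(1)
    # Sum the embeddings of all nodes that belong to each graph
    sums = torch.zeros(num_graphs, emb_dim, dtype=node_emb.dtype, device=node_emb.device)
    sums.index_add_(0, batch, node_emb)
    # Count nodes per graph; clamp avoids division by zero for unused graph ids
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(node_emb.dtype)


# Usage: two molecules with 3 and 2 atoms, 4-dimensional node embeddings
node_emb = torch.arange(20, dtype=torch.float32).view(5, 4)
batch = torch.tensor([0, 0, 0, 1, 1])
graph_emb = mean_pool(node_emb, batch)
print(graph_emb.shape)  # torch.Size([2, 4])
```

For a single molecule, `batch` is simply a vector of zeros. If PyTorch Geometric is installed, `torch_geometric.nn.global_mean_pool(x, batch)` computes the same per-graph mean.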

chao1224 commented 1 month ago

Hi @marcosbodio,

Yes, I think this is right if you want to use the 3D representation.

  1. When we create the checkpoints, we save the following modules (code):

     ```python
     saver_dict = {
         'model': molecule_model_2D.state_dict(),
         'model_3D': molecule_model_3D.state_dict(),
         'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
         'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
     }
     ```

  2. What you wrote above can be fed into `model_3D`.
  3. If you only want to use the 2D checkpoint, which is `model` above, you can follow this pseudocode:

     ```python
     from rdkit import Chem
     # Module path follows the linked file; it may need adjusting to your working directory
     from src_classification.datasets.molecule_datasets import mol_to_graph_data_obj_simple

     smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
     mol = Chem.MolFromSmiles(smiles)  # 2D graph only, no 3D conformer is generated
     data = mol_to_graph_data_obj_simple(mol)
     ```

     where `mol_to_graph_data_obj_simple` is defined in this [function](https://github.com/chao1224/GraphMVP/blob/main/src_classification/datasets/molecule_datasets.py).

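Item 1 describes a checkpoint that bundles several `state_dict`s under the keys `'model'`, `'model_3D'`, `'AE_2D_3D_model'` and `'AE_3D_2D_model'`, but a `.pth` file can also hold a single flat `state_dict`. The helper below is a hypothetical sketch (not part of GraphMVP) that returns the requested module from the bundled layout and otherwise treats a flat dictionary as the 2D encoder. Recognising the flat 2D layout by the presence of `x_embedding1.weight` is an assumption.

```python
from collections import OrderedDict
from collections.abc import Mapping

import torch

# Keys used when the pretraining script bundles several modules (saver_dict)
BUNDLE_KEYS = ('model', 'model_3D', 'AE_2D_3D_model', 'AE_3D_2D_model')


def extract_state_dict(checkpoint: Mapping, part: str = 'model') -> Mapping:
    """Return the state_dict of one module (hypothetical helper, not GraphMVP API).

    Handles two layouts:
      * bundled: {'model': ..., 'model_3D': ..., ...} as in saver_dict
      * flat: 2D encoder parameters saved directly (assumed to contain
        'x_embedding1.weight')
    """
    if any(key in checkpoint for key in BUNDLE_KEYS):
        if part not in checkpoint:
            raise KeyError(f'{part!r} not in checkpoint; available: {list(checkpoint)}')
        return checkpoint[part]
    if part == 'model' and 'x_embedding1.weight' in checkpoint:
        return checkpoint  # flat file: treat it as the 2D encoder
    raise KeyError(f'{part!r} not found; first keys: {list(checkpoint)[:5]}')


# Usage with toy tensors standing in for real weights
bundled = {'model': OrderedDict(w=torch.zeros(1)), 'model_3D': OrderedDict(v=torch.ones(1))}
flat = OrderedDict([('x_embedding1.weight', torch.zeros(120, 300))])
print(list(extract_state_dict(bundled, 'model_3D')))  # ['v']
print(list(extract_state_dict(flat)))  # ['x_embedding1.weight']
```

Raising a `KeyError` for a missing part, instead of silently returning the flat dictionary, makes it obvious when a file simply does not contain the 3D module.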
marcosbodio commented 1 month ago

Hi @chao1224,

I have tried to load one of your model checkpoints, but I do not see `model_3D`. Here is what I did:

```python
import torch

model_path = 'output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.3_EBM_dot_prod_0.1_normalize_l2_detach_target_2_100_0/pretraining_model.pth'
model = torch.load(f=model_path, map_location=torch.device('cpu'))
print(model.keys())
print('model_3D' in model)
```

where `model_path` points to a file inside your `GraphMVP_simple_features_for_classification.zip` (shared here).

The previous code prints the following:

```
odict_keys(['x_embedding1.weight', 'x_embedding2.weight', 'gnns.0.mlp.0.weight', 'gnns.0.mlp.0.bias', 'gnns.0.mlp.2.weight', 'gnns.0.mlp.2.bias', 'gnns.0.edge_embedding1.weight', 'gnns.0.edge_embedding2.weight', 'gnns.1.mlp.0.weight', 'gnns.1.mlp.0.bias', 'gnns.1.mlp.2.weight', 'gnns.1.mlp.2.bias', 'gnns.1.edge_embedding1.weight', 'gnns.1.edge_embedding2.weight', 'gnns.2.mlp.0.weight', 'gnns.2.mlp.0.bias', 'gnns.2.mlp.2.weight', 'gnns.2.mlp.2.bias', 'gnns.2.edge_embedding1.weight', 'gnns.2.edge_embedding2.weight', 'gnns.3.mlp.0.weight', 'gnns.3.mlp.0.bias', 'gnns.3.mlp.2.weight', 'gnns.3.mlp.2.bias', 'gnns.3.edge_embedding1.weight', 'gnns.3.edge_embedding2.weight', 'gnns.4.mlp.0.weight', 'gnns.4.mlp.0.bias', 'gnns.4.mlp.2.weight', 'gnns.4.mlp.2.bias', 'gnns.4.edge_embedding1.weight', 'gnns.4.edge_embedding2.weight', 'batch_norms.0.weight', 'batch_norms.0.bias', 'batch_norms.0.running_mean', 'batch_norms.0.running_var', 'batch_norms.0.num_batches_tracked', 'batch_norms.1.weight', 'batch_norms.1.bias', 'batch_norms.1.running_mean', 'batch_norms.1.running_var', 'batch_norms.1.num_batches_tracked', 'batch_norms.2.weight', 'batch_norms.2.bias', 'batch_norms.2.running_mean', 'batch_norms.2.running_var', 'batch_norms.2.num_batches_tracked', 'batch_norms.3.weight', 'batch_norms.3.bias', 'batch_norms.3.running_mean', 'batch_norms.3.running_var', 'batch_norms.3.num_batches_tracked', 'batch_norms.4.weight', 'batch_norms.4.bias', 'batch_norms.4.running_mean', 'batch_norms.4.running_var', 'batch_norms.4.num_batches_tracked'])
False
```

Am I loading the wrong checkpoint?
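The printed keys (`x_embedding1`, `x_embedding2`, `gnns.0` to `gnns.4`, `batch_norms.*`) appear to be the parameters of a single 5-layer GNN encoder, i.e. only a flat 2D `state_dict` rather than the bundled `saver_dict`. Before calling `load_state_dict` on an encoder instance, its size can be read off the tensors. A minimal sketch, assuming one `gnns.<i>.` prefix per layer and `x_embedding1.weight` of shape `[num_atom_types, emb_dim]`; `infer_gnn_config` is a hypothetical helper.

```python
import re
from collections.abc import Mapping

import torch


def infer_gnn_config(state_dict: Mapping) -> dict:
    """Infer layer count and embedding width from a flat GNN state_dict.

    Assumes keys of the form 'gnns.<i>.*' (one prefix per layer) and
    'x_embedding1.weight' with shape [num_atom_types, emb_dim].
    """
    layer_ids = {
        int(m.group(1))
        for key in state_dict
        if (m := re.match(r'gnns\.(\d+)\.', key))
    }
    num_atom_types, emb_dim = state_dict['x_embedding1.weight'].shape
    return {'num_layer': len(layer_ids), 'emb_dim': emb_dim, 'num_atom_types': num_atom_types}


# Usage with a toy state_dict shaped like the printed keys (5 layers, width 300)
toy = {'x_embedding1.weight': torch.zeros(120, 300)}
for i in range(5):
    toy[f'gnns.{i}.mlp.0.weight'] = torch.zeros(600, 300)
print(infer_gnn_config(toy))  # {'num_layer': 5, 'emb_dim': 300, 'num_atom_types': 120}
```

With the real file one would pass `torch.load(model_path, map_location='cpu')` instead of `toy`; loading into the encoder with PyTorch's default `load_state_dict(..., strict=True)` then reports any remaining key or shape mismatch.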

chao1224 commented 1 month ago

Hi @marcosbodio ,

I need to double-check the checkpoint files when I have time. Meanwhile, you should be able to use this checkpoint, which is one of the SOTA PaiNN pretraining methods (paper link).