Open marcosbodio opened 2 months ago
Hi @marcosbodio,
Thank you for your question.
no
.yes
What you need to do is to replace the 2D graph + 2D GNN (GIN in our paper) with SMILES + BERT (or any other sequence encoder).
Hi @chao1224, thank you for your answer. I see in your paper that Table 5 lists results on DTA tasks with Davis and KIBA. These datasets contain SMILES strings of molecules, so how did you use GraphMVP (or GraphMVP-G, GraphMVP-C) on them? It would be very useful to see the code, because that would clarify the proper way of using your model starting from the SMILES of a molecule.
Hi @marcosbodio,
Sure, you can check this Python script; specifically, this line assigns which dataset to use.
Hi @chao1224, I have looked at the script that you linked above, and I think it is for fine-tuning your model, which I would prefer to avoid.
I was hoping to use a checkpoint of your model, for example output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.3_EBM_dot_prod_0.1_normalize_l2_detach_target_2_100_0/pretraining_model.pth in GraphMVP_simple_features_for_classification.zip (shared here).
I wonder if I could do something like this:
import torch
from rdkit import Chem
from rdkit.Chem.rdDistGeom import EmbedMolecule
from src_classification.GEOM_dataset_preparation import mol_to_graph_data_obj_simple_3D
smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)
EmbedMolecule(mol=mol)
data = mol_to_graph_data_obj_simple_3D(mol)
and then feed data to the model loaded from the checkpoint to compute an embedding of the SMILES. What do you think?
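The conformer step in the snippet above can fail silently: RDKit's `EmbedMolecule` returns -1 when it cannot generate coordinates, and `MolFromSmiles` returns `None` for an unparsable string. Below is a minimal sketch of the same steps with those two checks added. The helper name `smiles_to_3d_mol` and the fixed `seed` are illustrative assumptions, not part of the GraphMVP code base, and this thread does not establish whether RDKit-generated conformers match the geometry used for pretraining.

```python
from rdkit import Chem
from rdkit.Chem.rdDistGeom import EmbedMolecule


def smiles_to_3d_mol(smiles, seed=0):
    """Parse a SMILES string and attach one RDKit-generated 3D conformer.

    Hypothetical helper for illustration only.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    # Explicit hydrogens generally give better embedded geometries.
    mol = Chem.AddHs(mol)

    # EmbedMolecule returns the new conformer id, or -1 on failure.
    conf_id = EmbedMolecule(mol, randomSeed=seed)
    if conf_id == -1:
        raise RuntimeError(f"3D embedding failed for: {smiles!r}")
    return mol


if __name__ == "__main__":
    mol = smiles_to_3d_mol("CCO")
    # One row of (x, y, z) coordinates per atom, hydrogens included.
    print(mol.GetConformer().GetPositions().shape)
```

The returned molecule can then be passed to `mol_to_graph_data_obj_simple_3D` exactly as in the snippet above.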
Hi @marcosbodio,
Yes, I think this is right if you want to use the 3D representation.
saver_dict = {
'model': molecule_model_2D.state_dict(),
'model_3D': molecule_model_3D.state_dict(),
'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
}
The 3D encoder weights are stored under `model_3D`. If you instead use the 2D encoder stored under `model` above, then you can follow this pseudocode:
smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
mol = Chem.MolFromSmiles(smiles)
data = mol_to_graph_data_obj_simple(mol)
where `mol_to_graph_data_obj_simple` is in this [function](https://github.com/chao1224/GraphMVP/blob/main/src_classification/datasets/molecule_datasets.py).
Hi @chao1224,
I have tried to load one of your model checkpoints, but I do not see `model_3D`. Here is what I did:
model_path = 'output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.3_EBM_dot_prod_0.1_normalize_l2_detach_target_2_100_0/pretraining_model.pth'
model = torch.load(f=model_path, map_location=torch.device('cpu'))
print(model.keys())
print('model_3D' in model)
where `model_path` is from your file GraphMVP_simple_features_for_classification.zip (shared here).
The previous code prints the following:
odict_keys(['x_embedding1.weight', 'x_embedding2.weight', 'gnns.0.mlp.0.weight', 'gnns.0.mlp.0.bias', 'gnns.0.mlp.2.weight', 'gnns.0.mlp.2.bias', 'gnns.0.edge_embedding1.weight', 'gnns.0.edge_embedding2.weight', 'gnns.1.mlp.0.weight', 'gnns.1.mlp.0.bias', 'gnns.1.mlp.2.weight', 'gnns.1.mlp.2.bias', 'gnns.1.edge_embedding1.weight', 'gnns.1.edge_embedding2.weight', 'gnns.2.mlp.0.weight', 'gnns.2.mlp.0.bias', 'gnns.2.mlp.2.weight', 'gnns.2.mlp.2.bias', 'gnns.2.edge_embedding1.weight', 'gnns.2.edge_embedding2.weight', 'gnns.3.mlp.0.weight', 'gnns.3.mlp.0.bias', 'gnns.3.mlp.2.weight', 'gnns.3.mlp.2.bias', 'gnns.3.edge_embedding1.weight', 'gnns.3.edge_embedding2.weight', 'gnns.4.mlp.0.weight', 'gnns.4.mlp.0.bias', 'gnns.4.mlp.2.weight', 'gnns.4.mlp.2.bias', 'gnns.4.edge_embedding1.weight', 'gnns.4.edge_embedding2.weight', 'batch_norms.0.weight', 'batch_norms.0.bias', 'batch_norms.0.running_mean', 'batch_norms.0.running_var', 'batch_norms.0.num_batches_tracked', 'batch_norms.1.weight', 'batch_norms.1.bias', 'batch_norms.1.running_mean', 'batch_norms.1.running_var', 'batch_norms.1.num_batches_tracked', 'batch_norms.2.weight', 'batch_norms.2.bias', 'batch_norms.2.running_mean', 'batch_norms.2.running_var', 'batch_norms.2.num_batches_tracked', 'batch_norms.3.weight', 'batch_norms.3.bias', 'batch_norms.3.running_mean', 'batch_norms.3.running_var', 'batch_norms.3.num_batches_tracked', 'batch_norms.4.weight', 'batch_norms.4.bias', 'batch_norms.4.running_mean', 'batch_norms.4.running_var', 'batch_norms.4.num_batches_tracked'])
False
Am I loading the wrong checkpoint?
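The printed keys (`x_embedding1.weight`, `gnns.0.mlp.0.weight`, ...) look like a flat state dict of a single GNN, rather than the nested `saver_dict` layout with `model` and `model_3D` entries. Here is a minimal sketch of a loader-side check that handles both layouts. The helper name `select_state_dict` is hypothetical, and treating a flat checkpoint as the 2D encoder is an assumption based only on the key names shown above.

```python
from collections import OrderedDict


def select_state_dict(checkpoint, key="model"):
    """Return encoder weights from a nested or a flat checkpoint.

    Hypothetical helper for illustration only.
    """
    # Nested layout: {'model': ..., 'model_3D': ..., ...}
    if key in checkpoint and isinstance(checkpoint[key], dict):
        return checkpoint[key]

    # Flat layout: parameter names at the top level.
    # Assumption: such a file holds only the 2D GNN weights.
    if key == "model":
        return checkpoint

    raise KeyError(
        f"{key!r} not found; top-level keys look like a flat state dict"
    )


if __name__ == "__main__":
    # Stand-in values; a real checkpoint would come from
    # torch.load(model_path, map_location="cpu").
    flat = OrderedDict([("x_embedding1.weight", 0), ("gnns.0.mlp.0.weight", 1)])
    nested = {"model": {"w": 0}, "model_3D": {"w": 1}}
    print(select_state_dict(flat) is flat)
    print(select_state_dict(nested, "model_3D"))
```

With this check, asking for `model_3D` on a flat checkpoint raises a `KeyError` instead of silently loading the wrong weights.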
Hi @marcosbodio,
I need to double-check the checkpoint files when I have time. Meanwhile, you should be able to use this checkpoint, which is one of the SOTA PaiNN pretraining methods (paper link).
Hello, I would like to know if it is possible to use GraphMVP to encode molecules starting from their SMILES. I have read this issue, but it does not help much. I would be really grateful if you could provide some explanation, and ideally an example. Thank you!