awslabs / dgl-lifesci

Python package for graph neural networks in chemistry and biology
Apache License 2.0

JTVAE #177

Open ParnianH98 opened 2 years ago

ParnianH98 commented 2 years ago

Hi! I want to run the Junction Tree Variational Autoencoder (JTVAE) on the original dataset used in the paper, using the pre-trained model "JTVAE_ZINC_no_kl", but I don't know exactly how to use the model for that. I had trouble getting results out of it because I'm not sure what inputs it expects. Could you give me a hint?

mufeili commented 2 years ago

Hi, what experiment did you want to reproduce? Could you provide things like experiment descriptions, reported results, etc?

ParnianH98 commented 2 years ago

I want to run the exact neural network from the paper. I have some SMILES data (the ZINC dataset used in the paper). I want to feed these SMILES strings to the network (JTVAE) and see its output as SMILES. If the results look promising, I want to extract and save the embedding of the dataset, that is, the output of the encoder. I want to use the pre-trained model "JTVAE_ZINC_no_kl" for this.

I tried the following to build the model:

modelJ = load_pretrained('JTVAE_ZINC_no_kl')
modelJ.eval()

Now I don't know what type of input to give modelJ for this purpose. Can the model produce this output? If so, how?

I tried passing SMILES strings directly, but that did not work. I know that in the paper the encoder and decoder operate on junction trees and molecular graphs built from SMILES, but I don't know what kind of input this pre-trained model expects.

Any hints toward the goals in my first paragraph would be appreciated.

mufeili commented 2 years ago

If you just want to try the reconstruction part, you can use this file and specify the path to your dataset with --test-path. Your dataset should be stored in a file with one SMILES string per line.

ph-mehdi commented 2 years ago

I have a problem with this too. I tried the following code:

model = load_pretrained('JTVAE_ZINC_no_kl')
model.eval()
smiles = []
for i in range(4):
    smiles.append(model(rdkit_mol=True))

and also:

model = load_pretrained('JTVAE_ZINC_no_kl')
model.eval()
out = model("CCCCCCC1=NN2C(=N)/C(=C\c3cc(C)n(-c4ccc(C)cc4C)c3C)C(=O)N=C2S1")

In both cases I get an error.

mufeili commented 2 years ago

Did you write this code snippet yourself or was it from somewhere else? You need to try following this file.

ph-mehdi commented 2 years ago

I followed the example given here.

When I use the file you linked instead, I get the following error in Colab:

usage: ipykernel_launcher.py [-h] [-tr TRAIN_PATH] [-te TEST_PATH]
                             [-m MODEL_PATH] [-w HIDDEN_SIZE] [-l LATENT_SIZE]
                             [-d DEPTH] [-pi PRINT_ITER] [-cpu]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-04a7a0cb-d3c6-4a96-be7e-db4746e9fbdc.json
An exception has occurred, use %tb to see the full traceback.

SystemExit: 2

/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py:2890: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

When I change args = parser.parse_args() to args = parser.parse_args(args=[]), I get a different error:

Downloading JTVAE_ZINC_no_kl_pre_trained.pth from https://data.dgl.ai/pre_trained/jtvae_ZINC_no_kl.pth...
Pretrained model loaded
/usr/local/lib/python3.7/dist-packages/dgl/base.py:45: DGLWarning: The input graph for the user-defined edge function does not contain valid edges
  return warnings.warn(message, category=category, stacklevel=1)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
[<ipython-input-61-9c871d6b3bc0>](https://localhost:8080/#) in <module>()
     22     args = parser.parse_args(args = [])
     23 
---> 24     main(args)

[/usr/local/lib/python3.7/dist-packages/dgllife/utils/jtvae/vocab.py](https://localhost:8080/#) in get_index(self, smiles)
     72             The ID for the token.
     73         """
---> 74         return self.vmap[smiles]
     75 
     76     def get_smiles(self, idx):

KeyError: 'C1=NN=CN1'

Sorry for the many questions and thank you for answering.
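
For context on the first error above: a notebook kernel is started with its own `-f <connection file>` argument, so calling `parser.parse_args()` with no arguments reads that from `sys.argv` and rejects it. The sketch below shows two common ways around this with the standard `argparse` module. The parser here is a stand-in, not the script's real one; only the `-te`/`--test-path` flag is taken from this thread, and the file names are made up.

```python
import argparse

def build_parser():
    # Stand-in parser: only -te/--test-path mirrors a flag named in this thread.
    parser = argparse.ArgumentParser()
    parser.add_argument('-te', '--test-path', type=str, default=None)
    return parser

parser = build_parser()

# Roughly what a notebook kernel leaves in sys.argv[1:] (path is made up).
notebook_argv = ['-f', '/tmp/kernel-1234.json']

# Option 1: pass the arguments explicitly instead of reading sys.argv.
args = parser.parse_args(args=['--test-path', 'my_smiles.txt'])
print(args.test_path)

# Option 2: keep the known flags and collect everything else separately.
known, unknown = parser.parse_known_args(notebook_argv)
print(known.test_path, unknown)
```

Passing `args=[]` as in the comment above also avoids the error, but it leaves every option at its default, so any path you need has to be put into the list explicitly.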

mufeili commented 2 years ago

Sorry the doc is outdated. We'll update it soon.

What's your RDKit version?
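
One way to collect the versions asked about here is the standard library's `importlib.metadata` (Python 3.8+). This is only a sketch: the distribution names below are assumptions (RDKit in particular has been published under more than one name), and some installs may not register metadata at all, in which case the lookup reports nothing.

```python
from importlib import metadata

def installed_versions(packages):
    """Map each distribution name to its version string, or None if not found."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Assumed distribution names; adjust them to match how the packages were installed.
for name, version in installed_versions(["rdkit", "rdkit-pypi", "dgl", "dgllife"]).items():
    print(f"{name}: {version or 'not found'}")
```

On Python 3.7, as in the Colab traces in this thread, `importlib.metadata` is not available; printing each imported module's `__version__` attribute is the usual fallback.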

ph-mehdi commented 2 years ago

RDKit version: 2021.09.5
dgl version: 0.6.1
dgllife version: 0.2.9

I also got an error with the link I included in the earlier quote. Is that error also due to the outdated documentation?

mufeili commented 2 years ago

I don't think so. Could you try downgrading RDKit to 2018.09.3 and see if "KeyError: 'C1=NN=CN1'" still exists?

ParnianH98 commented 2 years ago

Thank you. I ran this code with an older version of RDKit and got no errors. Could you give me a hint on how to see the autoencoder's output? That is, when I give a SMILES string to the network, I want to see what it returns. The output should usually match the input SMILES, but I want to test this myself.

mufeili commented 2 years ago

dec_smiles here is the output of the autoencoder.
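
To check how often the decoded output matches the input, a small helper like the one below can be applied to the input SMILES and the collected `dec_smiles` values. This is a hypothetical sketch, independent of whatever the linked script reports: the function name and the example strings are made up, and it compares raw strings, so canonicalizing both sides first (for example with RDKit) would be needed to treat different spellings of the same molecule as equal.

```python
def exact_match_rate(input_smiles, decoded_smiles):
    """Fraction of positions where the decoded string equals the input string."""
    if len(input_smiles) != len(decoded_smiles):
        raise ValueError("input and decoded lists must have the same length")
    if not input_smiles:
        return 0.0
    matches = sum(a == b for a, b in zip(input_smiles, decoded_smiles))
    return matches / len(input_smiles)

# Made-up values for illustration only; not real model output.
inputs = ["CCO", "c1ccccc1", "CC(=O)O"]
decoded = ["CCO", "c1ccccc1", "CC(=O)N"]
rate = exact_match_rate(inputs, decoded)
print(f"exact-match reconstruction rate: {rate:.2f}")
```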

ParnianH98 commented 2 years ago

Sure, thanks. If I want to run this code on my own SMILES dataset, could you explain what I should do?

mufeili commented 2 years ago

You can provide a file with one SMILES string per line to --test-path here.
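
A minimal sketch of preparing such a file from a Python list, assuming nothing beyond the "one SMILES string per line" format described above. The helper name, the file name `my_smiles.txt`, and the example molecules are placeholders.

```python
import os
import tempfile

def write_smiles_file(smiles_list, path):
    """Write one SMILES string per line, skipping empty entries."""
    cleaned = [s.strip() for s in smiles_list if s.strip()]
    with open(path, "w") as f:
        for s in cleaned:
            f.write(s + "\n")
    return len(cleaned)

# Placeholder molecules: ethanol, benzene, acetic acid.
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
path = os.path.join(tempfile.gettempdir(), "my_smiles.txt")  # hypothetical file name
n_written = write_smiles_file(smiles, path)
print(f"wrote {n_written} SMILES strings to {path}")
```

The resulting path is what would then be passed to `--test-path`.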

ParnianH98 commented 2 years ago

Hi, thank you for your help. I have another question. I tried to train the network myself using this, and got the error "Expect argument "v" to have data type torch.int32 and device context cpu. But got torch.int64 and cpu." from line 4512 of heterograph.py. I tried changing some parts of the code without success, and I even tried running on a 32-bit operating system. Do you know where the problem is? This is the full Colab error:

in main(args)
     54         loss, kl_div, wacc, tacc, sacc, dacc = model(
     55             batch_trees, batch_tree_graphs, batch_mol_graphs, stereo_cand_batch_idx,
---> 56             stereo_cand_labels, batch_stereo_cand_graphs, beta=args.beta)
     57         optimizer.zero_grad()
     58         loss.backward()

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/dgllife/model/model_zoo/jtvae.py in forward(self, batch_trees, batch_tree_graphs, batch_mol_graphs, stereo_cand_batch_idx, stereo_cand_labels, batch_stereo_cand_graphs, beta)
    662         mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
    663
--> 664         word_loss, topo_loss, word_acc, topo_acc = self.decoder(batch_tree_graphs, tree_vec)
    665         assm_loss, assm_acc = self.assm(batch_trees, batch_tree_graphs, mol_vec, tree_mess)
    666

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/dgllife/model/model_zoo/jtvae.py in forward(self, tree_graphs, tree_vec)
    276             # Message passing excluding the target
    277             line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
--> 278                                   reduce_func=fn.sum('h_nei', 'sum_h'))
    279             line_tree_graphs.pull(v=eid, message_func=self.gru_message,
    280                                   reduce_func=fn.sum('m', 'sum_gated_h'))

/usr/local/lib/python3.7/dist-packages/dgl/heterograph.py in pull(self, v, message_func, reduce_func, apply_node_func, etype, inplace)
   4510         if inplace:
   4511             raise DGLError('The inplace option is removed in v0.5.')
-> 4512         v = utils.prepare_tensor(self, v, 'v')
   4513
   4514         if len(v) == 0:

/usr/local/lib/python3.7/dist-packages/dgl/utils/checks.py in prepare_tensor(g, data, name)
     32             raise DGLError('Expect argument "{}" to have data type {} and device '
     33                            'context {}. But got {} and {}.'.format(
---> 34                                name, g.idtype, g.device, F.dtype(data), F.context(data)))
     35         ret = data
     36     else:

DGLError: Expect argument "v" to have data type torch.int32 and device context cpu. But got torch.int64 and cpu.

mufeili commented 2 years ago

DGLGraphs support two idtypes, int32 or int64. You can change the idtype of a DGLGraph with the APIs here. The node and edge index obtained by methods like g.nodes() and g.edges() will be tensors of the corresponding dtype. It's likely that you need to convert a DGLGraph to one of idtype int32 somewhere in your code.

ph-mehdi commented 2 years ago

Hello, thank you for answering the questions. I used this link to retrain jtvae. But I get the same error: DGLError: Expect argument "v" to have data type torch.int32 and device context cuda:0. But got torch.int64 and cuda:0.

I tried to change the DGL code. Do I have to switch from int64 to int32 everywhere? Or is there an easier solution?

mufeili commented 2 years ago

Did you change any code? If not, let me have a try myself.

ph-mehdi commented 2 years ago

I didn't change anything here. Thank you.

mufeili commented 2 years ago

PR #178 should have fixed the issue. Could you try installing from source and see if there are any further issues?

ph-mehdi commented 2 years ago

Thanks, I installed from source and the error is fixed. But I get the following warning while running:

DGLWarning: The input graph for the user-defined edge function does not contain valid edges
  return warnings.warn(message, category=category, stacklevel=1)

I do not know if this problem is simply due to how the model uses torch.rnn or something else.

mufeili commented 2 years ago

I observed that too and I think you can simply ignore that.
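
If the warning is noisy, it can be filtered with the standard `warnings` module. The sketch below matches on the message text quoted in this thread rather than on DGL's warning class, so it does not need to import DGL; the constant and helper names are made up for illustration.

```python
import warnings

# Message text copied from the warning quoted in this thread.
DGL_EDGE_WARNING = (
    "The input graph for the user-defined edge function "
    "does not contain valid edges"
)

def silence_dgl_edge_warning():
    # `message` is a regular expression matched against the start of the warning text.
    warnings.filterwarnings("ignore", message=DGL_EDGE_WARNING)

# Call once before running training or reconstruction.
silence_dgl_edge_warning()
```

Filtering by message keeps other DGL warnings visible, which seems safer than silencing the whole category.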