wengong-jin / hgraph2graph

Hierarchical Generation of Molecular Graphs using Structural Motifs
MIT License

RuntimeError in the polymers/ folder #1

Open goznxn opened 4 years ago

goznxn commented 4 years ago

Hi Wengong,

Thank you for sharing this wonderful work.

I am trying to use the code in the polymers folder to generate molecules based on the ZINC250K data, which I got from the GitHub repository of your junction tree work. I first ran get_vocab.py to build the vocabulary and then preprocess.py to produce the training data.

When I run vae_train.py, I get the following output and error:

Namespace(anneal_rate=0.9, atom_vocab=<poly_hgraph.vocab.Vocab object at 0x2b5a6fb51e10>, batch_size=20, beta=0.3, clip_norm=20.0, depthG=20, depthT=20, diterG=5, diterT=1, dropout=0.0, embed_size=250, epoch=20, hidden_size=250, latent_size=24, load_epoch=-1, lr=0.001, print_iter=50, rnn_type='LSTM', save_dir='models/', save_iter=-1, train='train_processed/', vocab='zinc_vocab.txt')
/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/_reduction.py:46: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
Model #Params: 5742K
Traceback (most recent call last):
  File "vae_train.py", line 80, in <module>
    loss, kl_div, wacc, iacc, tacc, sacc = model(batch, beta=beta)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/gpfs/scratchfs1/zhg19014/motifgeneration/hgraph2graph/polymers/poly_hgraph/hgnn.py", line 76, in forward
    loss, wacc, iacc, tacc, sacc = self.decoder((root_vecs, tree_vecs, graph_vecs), graphs, tensors, orders)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/gpfs/scratchfs1/zhg19014/motifgeneration/hgraph2graph/polymers/poly_hgraph/decoder.py", line 254, in forward
    topo_scores = self.get_topo_score(src_tree_vecs, batch_idx, topo_vecs)
  File "/gpfs/scratchfs1/zhg19014/motifgeneration/hgraph2graph/polymers/poly_hgraph/decoder.py", line 137, in get_topo_score
    return self.topoNN( torch.cat([topo_vecs, topo_cxt], dim=-1) ).squeeze(-1)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/zhg19014/.conda/envs/hgraph2graph/lib/python3.6/site-packages/torch/nn/functional.py", line 1406, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [872 x 500], m2: [274 x 250] at /opt/conda/conda-bld/pytorch_1556653183467/work/aten/src/THC/generic/THCTensorMathBlas.cu:268
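The sizes in the error can be checked against the printed Namespace. The traceback shows that `F.linear` here calls `torch.addmm(bias, input, weight.t())`, so `m1` should be the input batch and `m2` should be `weight.t()`, i.e. `[in_features x out_features]`. Under that reading, the first layer of `topoNN` expects 274 = 250 + 24 (`hidden_size + latent_size`) input features, while the concatenation `torch.cat([topo_vecs, topo_cxt], dim=-1)` produced 500 = 250 + 250. That would suggest `topo_cxt` is 250 wide rather than 24, but this is only an inference from the shapes, not something I confirmed in the code. The sketch below just redoes that arithmetic in plain Python. The helper names `linear_output_shape` and `name_width` are hypothetical and are not part of hgraph2graph.

```python
"""Shape arithmetic behind the size-mismatch error (a sketch, not hgraph2graph code).

Assumption: m1 is the input to nn.Linear and m2 is weight.t(), with shape
[in_features x out_features], as the addmm call in the traceback suggests.
"""

HIDDEN_SIZE = 250  # hidden_size from the printed Namespace
LATENT_SIZE = 24   # latent_size from the printed Namespace

M1 = (872, 500)  # input to the first topoNN layer: [rows x fed_width]
M2 = (274, 250)  # weight.t() of that layer: [in_features x out_features]


def linear_output_shape(input_shape, weight_t_shape):
    """Hypothetical helper: mimic the matmul shape check that addmm performs."""
    rows, fed_width = input_shape
    in_features, out_features = weight_t_shape
    if fed_width != in_features:
        raise ValueError(
            f"size mismatch: input has {fed_width} columns, layer expects {in_features}"
        )
    return (rows, out_features)


def name_width(width, hidden=HIDDEN_SIZE, latent=LATENT_SIZE):
    """Hypothetical helper: list which hidden/latent sums equal a feature width."""
    candidates = {
        "hidden_size + hidden_size": hidden + hidden,
        "hidden_size + latent_size": hidden + latent,
        "latent_size + latent_size": latent + latent,
    }
    return [label for label, value in candidates.items() if value == width]


if __name__ == "__main__":
    print("fed width", M1[1], "=", name_width(M1[1]))
    print("expected width", M2[0], "=", name_width(M2[0]))
    try:
        linear_output_shape(M1, M2)
    except ValueError as err:
        print(err)
```

If the context vectors had the expected `latent_size` width, the input would be `[872 x 274]` and the check would pass, producing an output of `[872 x 250]`.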

Thanks.