awslabs / dgl-lifesci

Python package for graph neural networks in chemistry and biology
Apache License 2.0

JTVAE's `pretrain` script raises an error due to mismatched dtypes #164

Open siboehm opened 2 years ago

siboehm commented 2 years ago

Running `examples/generative_models/jtvae/pretrain.py` without any arguments (which should pretrain on ZINC) raises an error:

```
/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: The input graph for the user-defined edge function does not contain valid edges
  return warnings.warn(message, category=category, stacklevel=1)
Traceback (most recent call last):
  File "/home/simon/Documents/ETH/Masters_thesis/chemical_CPA/embeddings/jtvae/pretrain.py", line 192, in <module>
    main(args)
  File "/home/simon/Documents/ETH/Masters_thesis/chemical_CPA/embeddings/jtvae/pretrain.py", line 86, in main
    beta=0,
  File "/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/dgllife/model/model_zoo/jtvae.py", line 664, in forward
    word_loss, topo_loss, word_acc, topo_acc = self.decoder(batch_tree_graphs, tree_vec)
  File "/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/dgllife/model/model_zoo/jtvae.py", line 278, in forward
    reduce_func=fn.sum('h_nei', 'sum_h'))
  File "/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/dgl/heterograph.py", line 4653, in pull
    v = utils.prepare_tensor(self, v, 'v')
  File "/home/simon/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/dgl/utils/checks.py", line 35, in prepare_tensor
    name, g.idtype, g.device, F.dtype(data), F.context(data)))
dgl._ffi.base.DGLError: Expect argument "v" to have data type torch.int32 and device context cuda:0. But got torch.int64 and cuda:0.
```

Output of `conda list`:

```
# packages in environment at /home/simon/miniconda3/envs/jtvae_dgl:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 1_llvm conda-forge
argcomplete 1.12.3 pyhd8ed1ab_2 conda-forge
argon2-cffi 21.1.0 py37h5e8e339_2 conda-forge
arrow-cpp 2.0.0 py37hc02b082_15_cpu conda-forge
async_generator 1.10 py_0 conda-forge
attrs 21.2.0 pyhd8ed1ab_0 conda-forge
aws-c-common 0.4.59 h36c2ea0_1 conda-forge
aws-c-event-stream 0.1.6 had2084c_6 conda-forge
aws-checksums 0.1.10 h4e93380_0 conda-forge
aws-sdk-cpp 1.8.70 h57dc084_1 conda-forge
backcall 0.2.0 pyh9f0ad1d_0 conda-forge
backports 1.0 py_2 conda-forge
backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge
blas 2.112 mkl conda-forge
blas-devel 3.9.0 12_linux64_mkl conda-forge
bleach 4.1.0 pyhd8ed1ab_0 conda-forge
boost 1.68.0 py37h8619c78_1001 conda-forge
boost-cpp 1.68.0 h11c811c_1000 conda-forge
brotli 1.0.9 h7f98852_6 conda-forge
brotli-bin 1.0.9 h7f98852_6 conda-forge
bzip2 1.0.8 h7f98852_4 conda-forge
c-ares 1.18.1 h7f98852_0 conda-forge
ca-certificates 2021.10.8 ha878542_0 conda-forge
cairo 1.16.0 h18b612c_1001 conda-forge
certifi 2021.10.8 py37h89c1867_1 conda-forge
cffi 1.15.0 py37h036bc23_0 conda-forge
charset-normalizer 2.0.7 pypi_0 pypi
cloudpickle 2.0.0 pypi_0 pypi
colorama 0.4.4 pyh9f0ad1d_0 conda-forge
cudatoolkit 10.2.89 h8f6ccaa_9 conda-forge
cycler 0.11.0 pyhd8ed1ab_0 conda-forge
dbus 1.13.6 h48d8840_2 conda-forge
debugpy 1.5.1 py37hcd2ae1e_0 conda-forge
decorator 5.1.0 pyhd8ed1ab_0 conda-forge
defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge
dgl-cuda10.2 0.7.2 py37_0 dglteam
dgllife 0.2.8 pypi_0 pypi
entrypoints 0.3 py37hc8dfbb8_1002 conda-forge
expat 2.4.1 h9c3ff4c_0 conda-forge
fontconfig 2.13.1 he4413a7_1000 conda-forge
freetype 2.10.4 h0708190_1 conda-forge
future 0.18.2 pypi_0 pypi
gettext 0.19.8.1 h73d1719_1008 conda-forge
gflags 2.2.2 he1b5a44_1004 conda-forge
glib 2.70.0 h780b84a_1 conda-forge
glib-tools 2.70.0 h780b84a_1 conda-forge
glog 0.4.0 h49b9bf7_3 conda-forge
grpc-cpp 1.34.1 h2157cd5_4
gst-plugins-base 1.14.0 hbbd80ab_1
gstreamer 1.14.0 h28cd5cc_2
hyperopt 0.2.6 pypi_0 pypi
icu 58.2 hf484d3e_1000 conda-forge
idna 3.3 pypi_0 pypi
importlib-metadata 4.8.2 py37h89c1867_0 conda-forge
importlib_metadata 4.8.2 hd8ed1ab_0 conda-forge
importlib_resources 5.4.0 pyhd8ed1ab_0 conda-forge
ipykernel 6.5.0 py37h6531663_1 conda-forge
ipython 7.29.0 py37h6531663_2 conda-forge
ipython_genutils 0.2.0 py_1 conda-forge
ipywidgets 7.6.5 pyhd8ed1ab_0 conda-forge
jedi 0.18.0 py37h89c1867_3 conda-forge
jinja2 3.0.3 pyhd8ed1ab_0 conda-forge
joblib 1.1.0 pypi_0 pypi
jpeg 9d h36c2ea0_0 conda-forge
jsonschema 4.2.1 pyhd8ed1ab_0 conda-forge
jupyter 1.0.0 py37h89c1867_7 conda-forge
jupyter_client 6.1.12 pyhd8ed1ab_0 conda-forge
jupyter_console 6.4.0 pyhd8ed1ab_1 conda-forge
jupyter_core 4.9.1 py37h89c1867_1 conda-forge
jupyterlab_pygments 0.1.2 pyh9f0ad1d_0 conda-forge
jupyterlab_widgets 1.0.2 pyhd8ed1ab_0 conda-forge
kiwisolver 1.3.2 py37h2527ec5_1 conda-forge
krb5 1.19.2 h48eae69_3 conda-forge
lcms2 2.12 hddcbb42_0 conda-forge
ld_impl_linux-64 2.36.1 hea4e1c9_2 conda-forge
libblas 3.9.0 12_linux64_mkl conda-forge
libbrotlicommon 1.0.9 h7f98852_6 conda-forge
libbrotlidec 1.0.9 h7f98852_6 conda-forge
libbrotlienc 1.0.9 h7f98852_6 conda-forge
libcblas 3.9.0 12_linux64_mkl conda-forge
libcurl 7.80.0 h494985f_0 conda-forge
libedit 3.1.20191231 he28a2e2_2 conda-forge
libev 4.33 h516909a_1 conda-forge
libevent 2.1.10 h28343ad_4 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 11.2.0 h1d223b6_11 conda-forge
libgfortran-ng 11.2.0 h69a702a_11 conda-forge
libgfortran5 11.2.0 h5c6108e_11 conda-forge
libglib 2.70.0 h174f98d_1 conda-forge
libiconv 1.16 h516909a_0 conda-forge
liblapack 3.9.0 12_linux64_mkl conda-forge
liblapacke 3.9.0 12_linux64_mkl conda-forge
libnghttp2 1.43.0 ha19adfc_1 conda-forge
libnsl 2.0.0 h7f98852_0 conda-forge
libpng 1.6.37 h21135ba_2 conda-forge
libprotobuf 3.14.0 h780b84a_0 conda-forge
libsodium 1.0.18 h36c2ea0_1 conda-forge
libssh2 1.10.0 ha35d2d1_2 conda-forge
libstdcxx-ng 11.2.0 he4da1e4_11 conda-forge
libthrift 0.13.0 hfb8234f_6
libtiff 4.2.0 hbd63e13_2 conda-forge
libutf8proc 2.6.1 h7f98852_0 conda-forge
libuuid 2.32.1 h7f98852_1000 conda-forge
libuv 1.42.0 h7f98852_0 conda-forge
libwebp-base 1.2.1 h7f98852_0 conda-forge
libxcb 1.13 h7f98852_1004 conda-forge
libxml2 2.9.9 h13577e0_2 conda-forge
libzlib 1.2.11 h36c2ea0_1013 conda-forge
llvm-openmp 12.0.1 h4bd325d_1 conda-forge
lz4-c 1.9.3 h9c3ff4c_1 conda-forge
markupsafe 2.0.1 py37h5e8e339_1 conda-forge
matplotlib-base 3.4.3 py37h1058ff1_2 conda-forge
matplotlib-inline 0.1.3 pyhd8ed1ab_0 conda-forge
mistune 0.8.4 py37h5e8e339_1005 conda-forge
mkl 2021.4.0 h8d4b97c_729 conda-forge
mkl-devel 2021.4.0 ha770c72_730 conda-forge
mkl-include 2021.4.0 h8d4b97c_729 conda-forge
nbclient 0.5.8 pyhd8ed1ab_0 conda-forge
nbconvert 6.3.0 py37h89c1867_1 conda-forge
nbformat 5.1.3 pyhd8ed1ab_0 conda-forge
ncurses 6.2 h58526e2_4 conda-forge
nest-asyncio 1.5.1 pyhd8ed1ab_0 conda-forge
networkx 2.6.3 pyhd8ed1ab_1 conda-forge
notebook 6.4.5 pyha770c72_0 conda-forge
numpy 1.21.4 py37h31617e3_0 conda-forge
olefile 0.46 pyh9f0ad1d_1 conda-forge
openjpeg 2.4.0 hb52868f_1 conda-forge
openssl 3.0.0 h7f98852_2 conda-forge
orc 1.6.6 h7950760_1 conda-forge
packaging 21.0 pyhd8ed1ab_0 conda-forge
pandas 1.3.4 py37he8f5f7f_1 conda-forge
pandoc 2.16.1 h7f98852_0 conda-forge
pandocfilters 1.5.0 pyhd8ed1ab_0 conda-forge
parquet-cpp 1.5.1 1 conda-forge
parso 0.8.2 pyhd8ed1ab_0 conda-forge
pcre 8.45 h9c3ff4c_0 conda-forge
pexpect 4.8.0 py37hc8dfbb8_1 conda-forge
pickleshare 0.7.5 py37hc8dfbb8_1002 conda-forge
pillow 8.2.0 py37h4600e1f_1 conda-forge
pip 21.3.1 pyhd8ed1ab_0 conda-forge
pixman 0.38.0 h516909a_1003 conda-forge
prometheus_client 0.12.0 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.22 pyha770c72_0 conda-forge
prompt_toolkit 3.0.22 hd8ed1ab_0 conda-forge
pthread-stubs 0.4 h36c2ea0_1001 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pyarrow 2.0.0 py37h9425694_15_cpu conda-forge
pycairo 1.20.1 py37hfff247e_1 conda-forge
pycparser 2.21 pyhd8ed1ab_0 conda-forge
pygments 2.10.0 pyhd8ed1ab_0 conda-forge
pyparsing 3.0.6 pyhd8ed1ab_0 conda-forge
pyqt 5.6.0 py37h13b7fb3_1008 conda-forge
pyrsistent 0.18.0 py37h5e8e339_0 conda-forge
python 3.7.12 hf930737_100_cpython conda-forge
python-dateutil 2.8.2 pyhd8ed1ab_0 conda-forge
python_abi 3.7 2_cp37m conda-forge
pytorch 1.10.0 py3.7_cuda10.2_cudnn7.6.5_0 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2021.3 pyhd8ed1ab_0 conda-forge
pyzmq 22.3.0 py37h336d617_1 conda-forge
qt 5.6.3 h8bf5577_3
qtconsole 5.2.0 pyhd8ed1ab_0 conda-forge
qtpy 1.11.2 pyhd8ed1ab_0 conda-forge
rdkit 2018.09.3 py37h9c20d5c_0 conda-forge
re2 2020.11.01 h58526e2_0 conda-forge
readline 8.1 h46c0cb4_0 conda-forge
requests 2.26.0 pypi_0 pypi
scikit-learn 1.0.1 pypi_0 pypi
scipy 1.7.2 py37hf2a6cf1_0 conda-forge
send2trash 1.8.0 pyhd8ed1ab_0 conda-forge
setuptools 59.1.1 py37h89c1867_0 conda-forge
sip 4.18.1 py37hf484d3e_1000 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
snappy 1.1.8 he1b5a44_3 conda-forge
sqlite 3.36.0 h9cd32fc_2 conda-forge
tbb 2021.4.0 h4bd325d_1 conda-forge
terminado 0.12.1 py37h89c1867_1 conda-forge
testpath 0.5.0 pyhd8ed1ab_0 conda-forge
threadpoolctl 3.0.0 pypi_0 pypi
tk 8.6.11 h27826a3_1 conda-forge
tornado 6.1 py37h5e8e339_2 conda-forge
tqdm 4.62.3 pyhd8ed1ab_0 conda-forge
traitlets 5.1.1 pyhd8ed1ab_0 conda-forge
typing_extensions 3.10.0.2 pyha770c72_0 conda-forge
urllib3 1.26.7 pypi_0 pypi
wcwidth 0.2.5 pyh9f0ad1d_2 conda-forge
webencodings 0.5.1 py_1 conda-forge
wheel 0.37.0 pyhd8ed1ab_1 conda-forge
widgetsnbextension 3.5.2 py37h89c1867_0 conda-forge
xorg-kbproto 1.0.7 h7f98852_1002 conda-forge
xorg-libice 1.0.10 h7f98852_0 conda-forge
xorg-libsm 1.2.3 hd9c2040_1000 conda-forge
xorg-libx11 1.7.2 h7f98852_0 conda-forge
xorg-libxau 1.0.9 h7f98852_0 conda-forge
xorg-libxdmcp 1.1.3 h7f98852_0 conda-forge
xorg-libxext 1.3.4 h7f98852_1 conda-forge
xorg-libxrender 0.9.10 h7f98852_1003 conda-forge
xorg-renderproto 0.11.1 h7f98852_1002 conda-forge
xorg-xextproto 7.3.0 h7f98852_1002 conda-forge
xorg-xproto 7.0.31 h7f98852_1007 conda-forge
xz 5.2.5 h516909a_1 conda-forge
zeromq 4.3.4 h9c3ff4c_1 conda-forge
zipp 3.6.0 pyhd8ed1ab_0 conda-forge
zlib 1.2.11 h36c2ea0_1013 conda-forge
zstd 1.4.9 ha95c52a_0 conda-forge
```

siboehm commented 2 years ago

I've "fixed" this in https://github.com/siboehm/dgl-lifesci/tree/jtvae by setting the ID dtype (`idtype`) of the whole graph to int64.

I doubt that this addresses the root cause, but training runs fine now.

mufeili commented 2 years ago

Thanks for the update and I'm glad you solved that.

siboehm commented 2 years ago

The pretraining script in this repo is still broken though. Shouldn't we keep this issue open?

mufeili commented 2 years ago

Sure. I'll check it later. Thanks.