JIT compilation fails when using `functional.generalized_rspmm` with CUDA on Linux #3

Closed: sbonner0 closed this issue 2 years ago.

sbonner0 commented 2 years ago


This is most likely an error with TorchDrug itself, but whenever I try to run any of the examples from the README, the code crashes with the following error:

spmm.cuda.o.d -DTORCH_EXTENSION_NAME=spmm -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/miniconda3/envs/path/lib/python3.8/site-packages/torch/include/TH -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include/THC -isystem /opt/scp/software/CUDA/11.1.0/include -isystem /home/miniconda3/envs/path/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 --compiler-options '-fPIC' -O3 -std=c++14 -c /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torchdrug/layers/functional/extension/ -o spmm.cuda.o
FAILED: spmm.cuda.o
/opt/scp/software/CUDA/11.1.0/bin/nvcc --generate-dependencies-with-compile --dependency-output spmm.cuda.o.d -DTORCH_EXTENSION_NAME=spmm -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include/TH -isystem /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torch/include/THC -isystem /opt/scp/software/CUDA/11.1.0/include -isystem /home/user/miniconda3/envs/path/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 --compiler-options '-fPIC' -O3 -std=c++14 -c /home/user/miniconda3/envs/path/lib/python3.8/site-packages/torchdrug/layers/functional/extension/ -o spmm.cuda.o
/opt/software/CUDA/11.1.0/include/cuComplex.h: In function ‘float cuCabsf(cuFloatComplex)’:
/opt/software/CUDA/11.1.0/include/cuComplex.h:179:16: error: expected ‘)’ before numeric constant

This only occurs on a Linux GPU machine running CUDA 11.1 and GCC 10.3.

The conda env is as follows:

blas                      1.0                         mkl
boost                     1.74.0           py38hc10631b_3    conda-forge
boost-cpp                 1.74.0               h9359b55_0    conda-forge
brotlipy                  0.7.0           py38h497a2fe_1001    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2021.10.8            ha878542_0    conda-forge
cairo                     1.16.0            h3fc0475_1005    conda-forge
certifi                   2021.10.8        py38h578d9bd_1    conda-forge
cffi                      1.15.0           py38hd667e15_1
charset-normalizer        2.0.10             pyhd8ed1ab_0    conda-forge
colorama                  0.4.4              pyh9f0ad1d_0    conda-forge
cryptography              35.0.0           py38ha5dfef3_0    conda-forge
cudatoolkit               11.1.1               h6406543_8    conda-forge
cycler                    0.11.0             pyhd8ed1ab_0    conda-forge
decorator                 4.4.2                      py_0    conda-forge
easydict                  1.9                        py_0    conda-forge
fontconfig                2.13.1            hba837de_1005    conda-forge
freetype                  2.10.4               h0708190_1    conda-forge
glib                      2.69.1               h4ff587b_1
icu                       67.1                 he1b5a44_0    conda-forge
idna                      3.3                pyhd8ed1ab_0    conda-forge
intel-openmp              2021.4.0          h06a4308_3561
jinja2                    3.0.3              pyhd8ed1ab_0    conda-forge
joblib                    1.1.0              pyhd8ed1ab_0    conda-forge
jpeg                      9d                   h36c2ea0_0    conda-forge
kiwisolver                1.3.1            py38h2531618_0
ld_impl_linux-64          2.35.1               h7274673_9
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h5101ec6_17
libgfortran-ng            7.5.0               h14aa051_19    conda-forge
libgfortran4              7.5.0               h14aa051_19    conda-forge
libgomp                   9.3.0               h5101ec6_17
libiconv                  1.16                 h516909a_0    conda-forge
libpng                    1.6.37               h21135ba_2    conda-forge
libstdcxx-ng              9.3.0               hd4cf53a_17
libtiff                   4.0.10            hc3755c2_1005    conda-forge
libuuid                   2.32.1            h7f98852_1000    conda-forge
libuv                     1.42.0               h7f98852_0    conda-forge
libxcb                    1.13              h7f98852_1003    conda-forge
libxml2                   2.9.10               h68273f3_2    conda-forge
littleutils               0.2.2                      py_0    conda-forge
lz4-c                     1.9.3                h9c3ff4c_1    conda-forge
markupsafe                2.0.1            py38h497a2fe_0    conda-forge
matplotlib                3.2.2                         1    conda-forge
matplotlib-base           3.2.2            py38h5d868c9_1    conda-forge
mkl                       2021.4.0           h06a4308_640
mkl-service               2.4.0            py38h497a2fe_0    conda-forge
mkl_fft                   1.3.1            py38hd3c417c_0
mkl_random                1.2.2            py38h1abd341_0    conda-forge
ncurses                   6.3                  h7f8727e_2
networkx                  2.5.1              pyhd8ed1ab_0    conda-forge
ninja                     1.10.2               h4bd325d_0    conda-forge
numpy                     1.21.2           py38h20f2e39_0
numpy-base                1.21.2           py38h79a1101_0
ogb                       1.3.2              pyhd8ed1ab_0    conda-forge
olefile                   0.46               pyh9f0ad1d_1    conda-forge
openssl                   1.1.1m               h7f8727e_0
outdated                  0.2.1              pyhd8ed1ab_0    conda-forge
pandas                    1.2.5            py38h1abd341_0    conda-forge
pcre                      8.45                 h9c3ff4c_0    conda-forge
pillow                    6.2.1            py38h6b7be26_0    conda-forge
pip                       21.2.4           py38h06a4308_0
pixman                    0.38.0            h516909a_1003    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
pycairo                   1.20.1           py38hf61ee4a_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pyopenssl                 21.0.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.0.7              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1            py38h578d9bd_4    conda-forge
python                    3.8.12               h12debd9_0
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python_abi                3.8                      2_cp38    conda-forge
pytorch                   1.8.2           py3.8_cuda11.1_cudnn8.0.5_0    pytorch-lts
pytorch-scatter           2.0.8           py38_torch_1.8.0_cu111    pyg
pytz                      2021.3             pyhd8ed1ab_0    conda-forge
pyyaml                    5.4.1            py38h497a2fe_0    conda-forge
rdkit                     2020.09.5        py38h2bca085_0    conda-forge
readline                  8.1.2                h7f8727e_1
reportlab                 3.5.68           py38hadf75a6_0    conda-forge
requests                  2.27.1             pyhd8ed1ab_0    conda-forge
scikit-learn              1.0.2            py38h51133e4_1
scipy                     1.7.3            py38hc147768_0
setuptools                58.0.4           py38h06a4308_0
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlalchemy                1.3.23           py38h497a2fe_0    conda-forge
sqlite                    3.37.0               hc218d9a_0
threadpoolctl             3.0.0              pyh8a188c0_0    conda-forge
tk                        8.6.11               h1ccaba5_0
torchdrug                 0.1.2                  ha710097    milagraph
tornado                   6.1              py38h497a2fe_1    conda-forge
tqdm                      4.62.3             pyhd8ed1ab_0    conda-forge
typing_extensions         4.0.1              pyha770c72_0    conda-forge
urllib3                   1.26.8             pyhd8ed1ab_1    conda-forge
wheel                     0.37.1             pyhd3eb1b0_0
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.5                h7b6447c_0
yaml                      0.2.5                h516909a_0    conda-forge
zlib                      1.2.11               h7f8727e_4
zstd                      1.4.9                ha95c52a_0    conda-forge

Any ideas how to get this to run?

Many thanks!

KiddoZhu commented 2 years ago

Hi! This looks like a compilation error in the CUDA header files, not in TorchDrug.

I just skimmed through cuComplex.h, and I think the cause is that the type float2 (a.k.a. cuFloatComplex) is not recognized correctly by the compiler. My guess is that either your GPU hardware doesn't support complex numbers or we are missing some compilation flags that turn on the feature.

Since we don't need to compile spmm/rspmm for complex tensors, maybe you could try turning off compilation for complex tensors in the PyTorch JIT? I guess it might be some C++ macro, but I'm not sure exactly which one.

sbonner0 commented 2 years ago

Hey @KiddoZhu, thanks so much for the prompt response!

The GPU is just a V100, so it isn't anything exotic. I'll try tinkering with the JIT, though, and let you know how it goes. Is the PyTorch 1.8.2 dependency required, or could I also try a newer version?

KiddoZhu commented 2 years ago

We use V100s too, so this sounds odd to me. The code is mainly developed and tested on V100 + PyTorch 1.8.1 + CUDA 10.2. We also know it works on A100 + PyTorch 1.8.2 LTS + CUDA 11.1.

If you run this code with PyTorch 1.10 or newer, it will consume slightly more memory on GPU 0, and on some datasets the default batch size will cause an out-of-memory (OOM) error on a 32 GB V100. Other than that, we haven't seen any problems with newer PyTorch versions.
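To compare your own install against the combinations mentioned above, a small version check can help. This is a hypothetical helper, not part of TorchDrug: `TESTED_COMBOS` only restates the two (PyTorch, CUDA) pairs from this comment (the GPU model is not checked), and the 1.10 threshold mirrors the memory note. You would pass it `torch.__version__` and `torch.version.cuda` from your environment.

```python
import re

# (PyTorch, CUDA) pairs reported in this thread; hypothetical helper data.
TESTED_COMBOS = {
    ("1.8.1", "10.2"),  # V100 + PyTorch 1.8.1 + CUDA 10.2
    ("1.8.2", "11.1"),  # A100 + PyTorch 1.8.2 LTS + CUDA 11.1
}


def parse_version(version):
    """Turn '1.10.0+cu111' into (1, 10, 0); a missing patch number becomes 0."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(group) for group in match.groups(default="0"))


def check_environment(torch_version, cuda_version):
    """Return human-readable notes; an empty list means a reported-good combination."""
    notes = []
    torch_key = ".".join(str(n) for n in parse_version(torch_version))
    cuda_key = ".".join(str(n) for n in parse_version(cuda_version)[:2])
    if (torch_key, cuda_key) not in TESTED_COMBOS:
        notes.append(f"PyTorch {torch_key} + CUDA {cuda_key} is not one of the reported combinations")
    if parse_version(torch_version) >= (1, 10):
        notes.append("PyTorch >= 1.10: expect extra memory use on GPU 0; consider a smaller batch size")
    return notes
```

For example, `check_environment("1.8.2+cu111", "11.1")` returns an empty list, while a PyTorch 1.10 build gets both notes.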

sbonner0 commented 2 years ago

Hey @KiddoZhu, I managed to solve this: it seemed to be a strange mismatch between conda and the native CUDA libraries installed on the HPC system. I now have it running with the Python provided on the HPC, and everything is working.
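When a PyTorch JIT extension build mixes a conda environment with system CUDA, it can help to print every toolchain component involved side by side. The sketch below is hypothetical (the helper names are not from PyTorch or TorchDrug). It reads `torch.version.cuda` and `torch.utils.cpp_extension.CUDA_HOME`, then runs `nvcc --version` from that CUDA home and `gcc --version` from `PATH`. It assumes `nvcc` prints a `release X.Y` fragment, as recent toolkits do.

```python
import os
import re
import subprocess


def parse_nvcc_release(nvcc_output):
    """Extract 'X.Y' from `nvcc --version` text, e.g. '... release 11.1, V11.1.105'."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def run_version(executable):
    """Return the stdout of `<executable> --version`, or None if it cannot be run."""
    try:
        result = subprocess.run(
            [executable, "--version"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout


def cuda_versions_agree(torch_cuda, nvcc_release):
    """True when both version strings share the same major.minor."""
    if not torch_cuda or not nvcc_release:
        return False
    return torch_cuda.split(".")[:2] == nvcc_release.split(".")[:2]


def describe_build_toolchain():
    """Collect the versions a JIT CUDA extension build depends on."""
    import torch  # imported lazily so the helpers above work without PyTorch
    from torch.utils.cpp_extension import CUDA_HOME

    nvcc = os.path.join(CUDA_HOME, "bin", "nvcc") if CUDA_HOME else "nvcc"
    nvcc_out = run_version(nvcc)
    gcc_out = run_version("gcc")  # assumed host compiler found on PATH
    return {
        "torch": torch.__version__,
        "torch built with CUDA": torch.version.cuda,
        "CUDA_HOME": CUDA_HOME,
        "nvcc release": parse_nvcc_release(nvcc_out) if nvcc_out else None,
        "gcc": gcc_out.splitlines()[0] if gcc_out else None,
    }
```

Run `describe_build_toolchain()` in both the failing environment and the working one, then compare the results. Matching version numbers alone do not rule out a mismatch. In the report above, conda's `cudatoolkit` is 11.1.1 and the system CUDA is 11.1.0, which share the same major.minor.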

Thanks for your quick reply! I will now close the issue, but I have a quick question about model checkpointing: are checkpoints only saved at the end of each epoch? If so, can this be changed to save after every n update steps?

KiddoZhu commented 2 years ago

With TorchDrug's interface, it is hard to dump a checkpoint partway through an epoch. However, you can override the length of an epoch with the `batch_per_epoch` argument of `solver.train`, which achieves a similar effect.
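A minimal sketch of that workaround, assuming `solver` is a TorchDrug `core.Engine` and uses only its `train(num_epoch=..., batch_per_epoch=...)` and `save(checkpoint)` methods. The helper `train_with_checkpoints` and its file-name pattern are hypothetical. Each call to `train` runs one shortened "epoch" of at most `save_every` batches, and a checkpoint is saved after each one. Whether consecutive short epochs see different batches depends on how `Engine.train` samples data in your TorchDrug version, so check that before relying on it.

```python
def train_with_checkpoints(solver, total_steps, save_every,
                           path_pattern="model_step_{step}.pth"):
    """Train for `total_steps` batches in total, saving after every `save_every` batches.

    `solver` is assumed to expose Engine-style train() and save() methods.
    Returns the list of checkpoint paths passed to solver.save().
    """
    if save_every <= 0:
        raise ValueError("save_every must be positive")
    saved = []
    step = 0  # counts requested batches; a short dataset may yield fewer
    while step < total_steps:
        # Shorten the "epoch" so train() returns after at most `chunk` batches.
        chunk = min(save_every, total_steps - step)
        solver.train(num_epoch=1, batch_per_epoch=chunk)
        step += chunk
        path = path_pattern.format(step=step)
        solver.save(path)
        saved.append(path)
    return saved
```

For example, `train_with_checkpoints(solver, total_steps=250, save_every=100)` would save `model_step_100.pth`, `model_step_200.pth` and `model_step_250.pth`.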