torchmd / torchmd-net

Training neural network potentials
MIT License
326 stars 73 forks

Failed to train ET on SPICE with torchmd-net-2.0 #302

Closed: AndChenCM closed this 7 months ago

AndChenCM commented 7 months ago

Dear developers,

I am trying to train ET on SPICE with the newest torchmd-net-2.0, but the training loss becomes NaN at the very beginning. I tracked down where the NaN happened and found that the edge_vec output from the OptimizedDistance module here contains zeros, which causes NaN in edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1). I also found that the input positions contain many zeros, but I have not had time to inspect this further. I believe I never encountered this problem with the previous Distance module when training ET on SPICE. My immediate thought was to change the OptimizedDistance module back to the Distance module, but my attempt failed because I could not find a torch_cluster build compatible with the new torchmd-net environment.
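For context, the failure mode is easy to demonstrate in isolation (a minimal sketch, not torchmd-net code): the norm of a zero edge vector is zero, and dividing by it yields NaN.

```python
import torch

# Two atoms at identical positions produce a zero edge vector;
# its norm is zero, and dividing by that norm yields NaN.
edge_vec = torch.zeros(1, 3)
norm = torch.norm(edge_vec, dim=1).unsqueeze(1)
print(norm)             # tensor([[0.]])
print(edge_vec / norm)  # tensor([[nan, nan, nan]])
```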

Have you ever encountered this problem or could you reproduce this? I followed the instructions on this page to install torchmd-net 2.0 from source, and below is my current conda environment:

name: torchmd-net
channels:
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - _sysroot_linux-64_curr_repodata_hack=3=h69a702a_14
  - annotated-types=0.6.0=pyhd8ed1ab_0
  - aom=3.8.1=h59595ed_0
  - appdirs=1.4.4=pyh9f0ad1d_0
  - ase=3.22.1=pyhd8ed1ab_1
  - atk-1.0=2.38.0=hd4edc92_1
  - binutils_impl_linux-64=2.40=hf600244_0
  - binutils_linux-64=2.40=hbdbef99_2
  - blinker=1.7.0=pyhd8ed1ab_0
  - blosc=1.21.5=h0f2a231_0
  - brotli=1.1.0=hd590300_1
  - brotli-bin=1.1.0=hd590300_1
  - brotli-python=1.1.0=py311hb755f60_1
  - brunsli=0.1=h9c3ff4c_0
  - bzip2=1.0.8=hd590300_5
  - c-ares=1.27.0=hd590300_0
  - c-blosc2=2.13.2=hb4ffafa_0
  - ca-certificates=2024.2.2=hbcca054_0
  - cached-property=1.5.2=hd8ed1ab_1
  - cached_property=1.5.2=pyha770c72_1
  - cairo=1.18.0=h3faef2a_0
  - captum=0.6.0=pyhd8ed1ab_0
  - certifi=2024.2.2=pyhd8ed1ab_0
  - charls=2.4.2=h59595ed_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - click=8.1.7=unix_pyh707e725_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - contourpy=1.2.0=py311h9547e67_0
  - cuda-cccl=12.0.90=ha770c72_1
  - cuda-cccl-impl=2.0.1=ha770c72_1
  - cuda-cccl_linux-64=12.0.90=ha770c72_1
  - cuda-cudart=12.0.107=hd3aeb46_8
  - cuda-cudart-dev=12.0.107=hd3aeb46_8
  - cuda-cudart-dev_linux-64=12.0.107=h59595ed_8
  - cuda-cudart-static=12.0.107=hd3aeb46_8
  - cuda-cudart-static_linux-64=12.0.107=h59595ed_8
  - cuda-cudart_linux-64=12.0.107=h59595ed_8
  - cuda-driver-dev=12.0.107=hd3aeb46_8
  - cuda-driver-dev_linux-64=12.0.107=h59595ed_8
  - cuda-libraries-dev=12.0.0=ha770c72_1
  - cuda-nvcc=12.0.76=hba56722_12
  - cuda-nvcc-dev_linux-64=12.0.76=ha770c72_1
  - cuda-nvcc-impl=12.0.76=h59595ed_1
  - cuda-nvcc-tools=12.0.76=h59595ed_1
  - cuda-nvcc_linux-64=12.0.76=hba56722_12
  - cuda-nvrtc=12.0.76=hd3aeb46_2
  - cuda-nvrtc-dev=12.0.76=hd3aeb46_2
  - cuda-nvtx=12.0.76=h59595ed_1
  - cuda-opencl=12.0.76=h59595ed_0
  - cuda-opencl-dev=12.0.76=ha770c72_0
  - cuda-profiler-api=12.0.76=ha770c72_0
  - cuda-version=12.0=hffde075_2
  - cudnn=8.8.0.121=h264754d_4
  - cycler=0.12.1=pyhd8ed1ab_0
  - dav1d=1.2.1=hd590300_0
  - docker-pycreds=0.4.0=py_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - expat=2.5.0=hcb278e6_1
  - filelock=3.13.1=pyhd8ed1ab_0
  - flake8=7.0.0=pyhd8ed1ab_0
  - flask=3.0.2=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_1
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.49.0=py311h459d7ec_0
  - freetype=2.12.1=h267a509_2
  - fribidi=1.0.10=h36c2ea0_0
  - fsspec=2024.2.0=pyhca7485f_0
  - gcc=11.4.0=h7baecda_2
  - gcc_impl_linux-64=11.4.0=h7aa1c59_5
  - gcc_linux-64=11.4.0=hfd045f2_2
  - gdk-pixbuf=2.42.10=h829c605_4
  - gettext=0.21.1=h27087fc_0
  - giflib=5.2.1=h0b41bf4_3
  - gitdb=4.0.11=pyhd8ed1ab_0
  - gitpython=3.1.42=pyhd8ed1ab_0
  - gmp=6.3.0=h59595ed_0
  - gmpy2=2.1.2=py311h6a5fa03_1
  - graphite2=1.3.13=h58526e2_1001
  - graphviz=9.0.0=h78e8752_1
  - gtk2=2.24.33=h7f000aa_3
  - gts=0.7.6=h977cf35_4
  - gxx=11.4.0=h7baecda_2
  - gxx_impl_linux-64=11.4.0=h7aa1c59_5
  - gxx_linux-64=11.4.0=hfc1ae95_2
  - h5py=3.10.0=nompi_py311hebc2b07_101
  - harfbuzz=8.3.0=h3d44ed6_0
  - hdf5=1.14.3=nompi_h4f84152_100
  - icu=73.2=h59595ed_0
  - idna=3.6=pyhd8ed1ab_0
  - imagecodecs=2024.1.1=py311h089f87a_0
  - imageio=2.34.0=pyh4b66e23_0
  - importlib-metadata=7.0.1=pyha770c72_0
  - importlib_metadata=7.0.1=hd8ed1ab_0
  - iniconfig=2.0.0=pyhd8ed1ab_0
  - isodate=0.6.1=pyhd8ed1ab_0
  - itsdangerous=2.1.2=pyhd8ed1ab_0
  - jinja2=3.1.3=pyhd8ed1ab_0
  - joblib=1.3.2=pyhd8ed1ab_0
  - jxrlib=1.1=hd590300_3
  - kernel-headers_linux-64=3.10.0=h4a8ded7_14
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.5=py311h9547e67_1
  - krb5=1.21.2=h659d440_0
  - lark-parser=0.12.0=pyhd8ed1ab_0
  - lazy_loader=0.3=pyhd8ed1ab_0
  - lcms2=2.16=hb7c19ff_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - lerc=4.0.0=h27087fc_0
  - libabseil=20230802.1=cxx17_h59595ed_0
  - libaec=1.1.2=h59595ed_1
  - libavif16=1.0.4=h1dcd450_0
  - libblas=3.9.0=21_linux64_openblas
  - libbrotlicommon=1.1.0=hd590300_1
  - libbrotlidec=1.1.0=hd590300_1
  - libbrotlienc=1.1.0=hd590300_1
  - libcblas=3.9.0=21_linux64_openblas
  - libcublas=12.0.1.189=hd3aeb46_3
  - libcublas-dev=12.0.1.189=hd3aeb46_3
  - libcufft=11.0.0.21=hd3aeb46_2
  - libcufft-dev=11.0.0.21=hd3aeb46_2
  - libcufile=1.5.0.59=hd3aeb46_1
  - libcufile-dev=1.5.0.59=hd3aeb46_1
  - libcurand=10.3.1.50=hd3aeb46_1
  - libcurand-dev=10.3.1.50=hd3aeb46_1
  - libcurl=8.5.0=hca28451_0
  - libcusolver=11.4.2.57=hd3aeb46_2
  - libcusolver-dev=11.4.2.57=hd3aeb46_2
  - libcusparse=12.0.0.76=hd3aeb46_2
  - libcusparse-dev=12.0.0.76=hd3aeb46_2
  - libdeflate=1.19=hd590300_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=hd590300_2
  - libexpat=2.5.0=hcb278e6_1
  - libffi=3.4.2=h7f98852_5
  - libgcc-devel_linux-64=11.4.0=h922705a_105
  - libgcc-ng=13.2.0=h807b86a_5
  - libgd=2.3.3=h119a65a_9
  - libgfortran-ng=13.2.0=h69a702a_5
  - libgfortran5=13.2.0=ha4646dd_5
  - libglib=2.78.4=h783c2da_0
  - libgomp=13.2.0=h807b86a_5
  - libhwloc=2.9.3=default_h554bfaf_1009
  - libiconv=1.17=hd590300_2
  - libjpeg-turbo=3.0.0=hd590300_1
  - liblapack=3.9.0=21_linux64_openblas
  - libllvm14=14.0.6=hcd5def8_4
  - libmagma=2.7.2=h173bb3b_2
  - libmagma_sparse=2.7.2=h173bb3b_1
  - libnghttp2=1.58.0=h47da74e_1
  - libnpp=12.0.0.30=hd3aeb46_1
  - libnpp-dev=12.0.0.30=hd3aeb46_1
  - libnsl=2.0.1=hd590300_0
  - libnvjitlink=12.0.76=hd3aeb46_2
  - libnvjitlink-dev=12.0.76=hd3aeb46_2
  - libnvjpeg=12.0.0.28=h59595ed_1
  - libnvjpeg-dev=12.0.0.28=ha770c72_1
  - libopenblas=0.3.26=pthreads_h413a1c8_0
  - libpng=1.6.43=h2797004_0
  - libprotobuf=4.25.1=hf27288f_2
  - librsvg=2.56.3=he3f83f7_1
  - libsanitizer=11.4.0=h4dcbe23_5
  - libsqlite=3.45.1=h2797004_0
  - libssh2=1.11.0=h0841786_0
  - libstdcxx-devel_linux-64=11.4.0=h922705a_105
  - libstdcxx-ng=13.2.0=h7e041cc_5
  - libtiff=4.6.0=ha9c0a0a_2
  - libtorch=2.1.2=cuda120_h2aa5df7_301
  - libuuid=2.38.1=h0b41bf4_0
  - libuv=1.47.0=hd590300_0
  - libwebp=1.3.2=h658648e_1
  - libwebp-base=1.3.2=hd590300_0
  - libxcb=1.15=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libxml2=2.12.5=h232c23b_0
  - libzlib=1.2.13=hd590300_5
  - libzopfli=1.0.3=h9c3ff4c_0
  - lightning=2.1.4=pyhd8ed1ab_0
  - lightning-utilities=0.10.1=pyhd8ed1ab_0
  - llvm-openmp=17.0.6=h4dfa4b3_0
  - llvmlite=0.42.0=py311ha6695c7_1
  - lz4-c=1.9.4=hcb278e6_0
  - magma=2.7.2=h51420fd_1
  - markupsafe=2.1.5=py311h459d7ec_0
  - matplotlib-base=3.8.3=py311h54ef318_0
  - mccabe=0.7.0=pyhd8ed1ab_0
  - mkl=2023.2.0=h84fe81f_50496
  - mpc=1.3.1=hfe3b2da_0
  - mpfr=4.2.1=h9458935_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - nccl=2.20.3.1=h3a97aeb_0
  - ncurses=6.4=h59595ed_2
  - networkx=3.2.1=pyhd8ed1ab_0
  - nnpops=0.6=cuda120py311hcbe25e9_7
  - numba=0.59.0=py311h96b013e_1
  - numpy=1.26.4=py311h64a7726_0
  - ocl-icd=2.3.2=hd590300_0
  - openjpeg=2.5.2=h488ebb8_0
  - openssl=3.2.1=hd590300_0
  - opt_einsum=3.3.0=pyhc1e730c_2
  - packaging=23.2=pyhd8ed1ab_0
  - pandas=2.2.1=py311h320fe9a_0
  - pango=1.52.0=ha41ecd1_0
  - pathtools=0.1.2=py_1
  - patsy=0.5.6=pyhd8ed1ab_0
  - pcre2=10.42=hcad00b1_0
  - pillow=10.2.0=py311ha6c5da5_0
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.2=h59595ed_0
  - pluggy=1.4.0=pyhd8ed1ab_0
  - protobuf=4.25.1=py311h46cbc50_0
  - psutil=5.9.8=py311h459d7ec_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pycodestyle=2.11.1=pyhd8ed1ab_0
  - pydantic=2.6.3=pyhd8ed1ab_0
  - pydantic-core=2.16.3=py311h46250e7_0
  - pyflakes=3.2.0=pyhd8ed1ab_0
  - pynndescent=0.5.11=pyhca7485f_0
  - pyparsing=3.1.1=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - pytest=8.0.2=pyhd8ed1ab_0
  - python=3.11.8=hab00c5b_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python-tzdata=2024.1=pyhd8ed1ab_0
  - python_abi=3.11=4_cp311
  - pytorch=2.1.2=cuda120_py311h25b6552_301
  - pytorch-lightning=2.1.3=pyhd8ed1ab_0
  - pytorch_geometric=2.4.0=pyhd8ed1ab_0
  - pytz=2024.1=pyhd8ed1ab_0
  - pywavelets=1.4.1=py311h1f0f07a_1
  - pyyaml=6.0.1=py311h459d7ec_1
  - rav1e=0.6.6=he8a937b_2
  - rdflib=7.0.0=pyhd8ed1ab_0
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - scikit-image=0.22.0=py311h320fe9a_2
  - scikit-learn=1.4.1.post1=py311hc009520_0
  - scipy=1.12.0=py311h64a7726_2
  - sentry-sdk=1.40.6=pyhd8ed1ab_0
  - setproctitle=1.3.3=py311h459d7ec_0
  - setuptools=65.3.0=pyhd8ed1ab_1
  - setuptools-scm=6.3.2=pyhd8ed1ab_0
  - setuptools_scm=6.3.2=hd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - sleef=3.5.1=h9b69904_2
  - smmap=5.0.0=pyhd8ed1ab_0
  - snappy=1.1.10=h9fff704_0
  - statsmodels=0.14.1=py311h1f0f07a_0
  - svt-av1=1.8.0=h59595ed_0
  - sympy=1.12=pypyh9d50eac_103
  - sysroot_linux-64=2.17=h4a8ded7_14
  - tbb=2021.11.0=h00ab1b0_1
  - threadpoolctl=3.3.0=pyhc1e730c_0
  - tifffile=2024.2.12=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - tomli=2.0.1=pyhd8ed1ab_0
  - torchani=2.2.4=cuda120py311he2766f7_3
  - torchmetrics=1.2.1=pyhd8ed1ab_0
  - tqdm=4.66.2=pyhd8ed1ab_0
  - trimesh=4.1.7=pyhd8ed1ab_0
  - typing-extensions=4.10.0=hd8ed1ab_0
  - typing_extensions=4.10.0=pyha770c72_0
  - tzdata=2024a=h0c530f3_0
  - urllib3=2.2.1=pyhd8ed1ab_0
  - wandb=0.16.3=pyhd8ed1ab_0
  - werkzeug=3.0.1=pyhd8ed1ab_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.7=h8ee46fc_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxrender=0.9.11=hd590300_0
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - zfp=1.0.1=h59595ed_0
  - zipp=3.17.0=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - zlib-ng=2.0.7=h0b41bf4_0
  - zstd=1.5.5=hfc55251_0

I used mamba 1.5.5 to install all the packages. The test was run on an A100-80G with the command CUDA_VISIBLE_DEVICES=0 python scripts/train.py --conf examples/ET-SPICE.yaml. Note that the only changes I made to the default ET-SPICE.yaml were setting version: 1.1.3 and max_gradient: 50.94. The SPICE dataset was not pre-downloaded.

Any help on this would be greatly appreciated. Thank you!

Ming-an

RaulPPelaez commented 7 months ago

The line that you shared is there precisely to avoid NaN when the distance is zero: it replaces zeros with ones so that 1/0 never happens. It is natural for edge_vec to contain zeros; these can come from self-interactions (the (i, i) pair, if you like) and also from "unused" pairs. The code is supposed to ignore those.
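The guard described above can be sketched as follows (a hedged illustration with made-up variable names, not the actual torchmd-net source): self-interaction pairs are masked out and their norms replaced by ones so the division stays finite. Note that two distinct atoms at the exact same position would still slip through this mask, which is the failure mode discussed later in this issue.

```python
import torch

def normalize_edges(edge_index, edge_vec):
    # Mask out self-interaction (i, i) pairs.
    mask = edge_index[0] != edge_index[1]
    # Default norm of 1 for masked-out pairs, so 1/0 never happens.
    norm = torch.ones(edge_vec.shape[0], 1)
    norm[mask] = torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
    return edge_vec / norm

edge_index = torch.tensor([[0, 0, 1], [0, 1, 0]])  # first pair is a self edge
edge_vec = torch.tensor([[0.0, 0.0, 0.0],
                         [2.0, 0.0, 0.0],
                         [-2.0, 0.0, 0.0]])
print(normalize_edges(edge_index, edge_vec))  # finite everywhere, no NaN
```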

I cannot reproduce your issue on a 4090 (sadly I do not have access to an A100), so I am inclined to believe this is either environment or system dependent.

Could you check if this behavior also happens in a fresh environment with just "conda install torchmd-net" installed?

BTW, with "max_gradient" are you referring to the "gradient_clipping" option?

RaulPPelaez commented 7 months ago

BTW I noticed the SPICE 1.1.4 version was missing in the dataset class. I fixed it here: https://github.com/torchmd/torchmd-net/pull/303

AndChenCM commented 7 months ago

The line that you shared is there precisely to avoid NaN when the distance is zero: it replaces zeros with ones so that 1/0 never happens. It is natural for edge_vec to contain zeros; these can come from self-interactions (the (i, i) pair, if you like) and also from "unused" pairs. The code is supposed to ignore those.

Previously I first found that the x and vec outputs from the representation module contain NaN. If edge_vec is not the problem, should I check the neighbor embedding module or other places?

I cannot reproduce your issue on a 4090 (sadly I do not have access to an A100), so I am inclined to believe this is either environment or system dependent.

Unfortunately I do not have a 4090 : (. To check if it is an environment issue, could you provide a yaml for me to test on?

Could you check if this behavior also happens in a fresh environment with just "conda install torchmd-net" installed?

Sure. I ran mamba install torchmd-net in a fresh environment, plus mamba install wandb. It gives me the following environment, which is nearly the same as my previous one: env-torchmd.txt

Then I used the default ET-SPICE.yaml with version 1.1.3 and the command CUDA_VISIBLE_DEVICES=0 torchmd-train --conf examples/ET-SPICE-test.yaml --log-dir outputs/spice-test --wandb-use True --wandb-name spice-test-ET --wandb-project nnp to run ET on SPICE. I copied my terminal output here: output.txt

Similar to my previous run, wandb does not record any metrics even though training has been going on for a while. I then shut it down and used pdb to trace the intermediate outputs; the NaN is still there.

BTW, with "max_gradient" are you referring to the "gradient_clipping" option?

By max_gradient I mean filtering out extra-large forces, as specified in TensorNet-SPICE.yaml, not gradient_clipping.

AndChenCM commented 7 months ago

BTW, I also tested it with TensorNet, and the results are the same, so this is not a model-dependent behaviour.

AndChenCM commented 7 months ago

I think I might have some more clues. After applying the mask, edge_vec[mask] still contains zero vectors, which causes the following computation to produce NaN values. Is it related to the warning in my terminal output saying You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high') which will trade-off precision for performance?
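One way to test the coincident-atom hypothesis is to scan each sample for atoms closer than some tolerance before it reaches the model (an illustrative helper, not a torchmd-net API; any such pair produces a zero edge vector and, after normalization, NaN):

```python
import torch

def find_coincident_atoms(pos, tol=1e-10):
    """Return index pairs (i, j), i < j, of atoms closer than tol.

    pos is assumed to have shape (n_atoms, 3). Illustrative only.
    """
    dist = torch.cdist(pos, pos)   # (n, n) pairwise distances
    i, j = torch.where(dist < tol)
    keep = i < j                   # skip the diagonal and mirrored pairs
    return list(zip(i[keep].tolist(), j[keep].tolist()))

pos = torch.zeros(4, 3)              # degenerate sample: all atoms at origin
print(find_coincident_atoms(pos))    # all 6 distinct pairs flagged
```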

Besides, it is weird that wandb does not record anything. Usually I would expect at least a "training loss" to show in the terminal.

RaulPPelaez commented 7 months ago

I am still unable to reproduce.
The tensor core message you see is just a hint that you can trade off a little accuracy in exchange for performance with torch.set_float32_matmul_precision('medium' | 'high'): https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
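For completeness, the setting that warning refers to is a single global call; it only controls the float32 matmul speed/precision trade-off and is unrelated to the NaNs here:

```python
import torch

# Allow tensor-core (TF32-style) float32 matmuls; "highest" is the default.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # "high"
```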

Would you share a specific input to the model that produces NaNs for you? I am talking about a set of positions + atomic numbers.

RaulPPelaez commented 7 months ago

Also, does it also happen if you run the forward pass with the model on the CPU? (i.e. calling model = model.to("cpu") and similarly with the model inputs)

RaulPPelaez commented 7 months ago

Additionally, could you confirm the tests pass in your machine?

cd tests
pytest -x -s -v test*py
AndChenCM commented 7 months ago

Additionally, could you confirm the tests pass in your machine?

cd tests
pytest -x -s -v test*py

I cannot pass the first test. Here is the output: errors.txt

AndChenCM commented 7 months ago

I am still unable to reproduce. The tensor core message you see is just a hint that you can trade off a little accuracy in exchange for performance with torch.set_float32_matmul_precision('medium' | 'high'): https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html

Would you share a specific input to the model that produces NaNs for you? I am talking about a set of positions + atomic numbers.

I set the batch size to 1, and the NaN occurs at the third sample. The input atomic numbers are

tensor([ 8,  8,  8,  8,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  7,  6,
         6,  6,  6,  6,  6,  6,  6, 16,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1], device='cuda:0')

and the input pos is all zeros:

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0', requires_grad=True). 

Emmmm, I am not sure what causes the coordinates to be all zeros in this sample; the SPICE processing was left entirely at the program's defaults.

RaulPPelaez commented 7 months ago

I see. Our tests do not cover the case of two atoms being at the exact same location. I am unsure whether this is supposed to be a valid input. @guillemsimeon, do you think this would make sense in some situation?

I am convinced now that the issue somehow lies in SPICE providing a bogus sample. Could it simply be that your local download of the dataset got corrupted? Try removing the dataset_root folder to force the code to download the dataset again.

AndChenCM commented 7 months ago

The SPICE 1.1.3 dataset was freshly downloaded by the code before training yesterday. I can try a fresh download of SPICE 1.1.4 to see if the same issue happens.

AndChenCM commented 7 months ago

I just checked the raw SPICE-1.1.3.hdf5; there seem to be no conformations with all-zero coordinates.
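A scan like the following could automate that check (a sketch that assumes the published SPICE HDF5 layout, where each top-level group stores its geometries in a conformations dataset of shape (n_conf, n_atoms, 3); treat the key names as an assumption):

```python
import h5py
import numpy as np

def find_all_zero_conformations(path):
    """Report (molecule, conformation index) pairs whose coordinates are all zero."""
    bad = []
    with h5py.File(path, "r") as f:
        for name, group in f.items():
            conformations = group["conformations"][...]
            for k, xyz in enumerate(conformations):
                if not np.any(xyz):  # every coordinate is exactly zero
                    bad.append((name, k))
    return bad
```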

AndChenCM commented 7 months ago

It seems that the previously processed dataset somehow got corrupted. Re-downloading the dataset fixed the issue. Sorry for the trouble!