isayevlab / AIMNet2

MIT License

Struggling With Autograd #15

Closed: corinwagen closed this issue 9 months ago

corinwagen commented 9 months ago

Hey all, congrats on a fantastic paper. I'm trying out some of the scripts included here, but I'm struggling to get autograd to work for computing the gradient and Hessian.

Here is my minimal reproducible example, using water as a test case:

import torch
import numpy as np

model_name = "models/aimnet2_wb97m-d3_ens.jpt"
#model_name = "models/aimnet2_b973c_ens.jpt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.jit.load(model_name, map_location=device)

input_atomic_numbers = [8, 1, 1]
input_coord = [
    [ 0.0000,  0.0000,  0.1178],
    [ 0.0000,  0.7555, -0.4712],
    [ 0.0000, -0.7555, -0.4712],
]
input_charge = 0

numbers = torch.as_tensor([input_atomic_numbers], device=device)
coord = torch.as_tensor(input_coord, dtype=torch.float, device=device).view(1, numbers.shape[1], 3)
charge = torch.as_tensor([input_charge], dtype=torch.float, device=device)
_in = dict(coord=coord, numbers=numbers, charge=charge)

with torch.jit.optimized_execution(False):
    _in['coord'].requires_grad_(True)
    _out = model(_in)
    e = _out['energy']
    f = torch.autograd.grad(e, _in['coord'])[0]

print(e)
print(f)

Energy works fine, but I get the following error from torch.autograd.grad():

Traceback (most recent call last):
  File "/Users/cwagen/.../example.py", line 27, in <module>
    f = torch.autograd.grad(e, _in['coord'])[0]
  File "/opt/miniconda3/envs/aimnet2_pytorch200/lib/python3.10/site-packages/torch/autograd/__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

This is confusing, because when I print _in before the energy call, coord does seem to require grad:

{'coord': tensor([[[ 0.0000,  0.0000,  0.1178],
         [ 0.0000,  0.7555, -0.4712],
         [ 0.0000, -0.7555, -0.4712]]], requires_grad=True), 'numbers': tensor([[8, 1, 1]]), 'charge': tensor([0.])}

Here's the output of conda list. I also tried the latest version of PyTorch with Python 3.11 and got the same error:

# packages in environment at /opt/miniconda3/envs/aimnet2_pytorch200:
#
# Name                    Version                   Build  Channel
bzip2                     1.0.8                h10d778d_5    conda-forge
ca-certificates           2023.11.17           h8857fd0_0    conda-forge
filelock                  3.13.1             pyhd8ed1ab_0    conda-forge
gmp                       6.3.0                h93d8f39_0    conda-forge
gmpy2                     2.1.2           py310hb691cb2_1    conda-forge
icu                       73.2                 hf5e326d_0    conda-forge
jinja2                    3.1.3              pyhd8ed1ab_0    conda-forge
libabseil                 20230802.1      cxx17_h048a20a_0    conda-forge
libblas                   3.9.0           20_osx64_openblas    conda-forge
libcblas                  3.9.0           20_osx64_openblas    conda-forge
libcxx                    16.0.6               hd57cbcb_0    conda-forge
libffi                    3.4.2                h0d85af4_5    conda-forge
libgfortran               5.0.0           13_2_0_h97931a8_1    conda-forge
libgfortran5              13.2.0               h2873a65_1    conda-forge
libhwloc                  2.9.3           default_h24e0189_1009    conda-forge
libiconv                  1.17                 hd75f5a5_2    conda-forge
liblapack                 3.9.0           20_osx64_openblas    conda-forge
libopenblas               0.3.25          openmp_hfef2a42_0    conda-forge
libprotobuf               4.24.4               h0ee05dc_0    conda-forge
libsqlite                 3.44.2               h92b6c6a_0    conda-forge
libuv                     1.46.0               h0c2f820_0    conda-forge
libxml2                   2.11.6               hc0ae0f7_0    conda-forge
libzlib                   1.2.13               h8a1eda9_5    conda-forge
llvm-openmp               17.0.6               hb6ac08f_0    conda-forge
markupsafe                2.1.3           py310h6729b98_1    conda-forge
mkl                       2022.2.1         h44ed08c_16952    conda-forge
mpc                       1.3.1                h81bd1dd_0    conda-forge
mpfr                      4.2.1                h0c69b56_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
ncurses                   6.4                  h93d8f39_2    conda-forge
networkx                  3.2.1              pyhd8ed1ab_0    conda-forge
numpy                     1.26.3          py310h4bfa8fc_0    conda-forge
openssl                   3.2.0                hd75f5a5_1    conda-forge
pip                       23.3.2             pyhd8ed1ab_0    conda-forge
python                    3.10.13         h00d2728_1_cpython    conda-forge
python_abi                3.10                    4_cp310    conda-forge
pytorch                   2.0.0           cpu_mkl_py310hed029b9_104    conda-forge
readline                  8.2                  h9e318b2_1    conda-forge
scipy                     1.11.4          py310h3f1db6d_0    conda-forge
setuptools                69.0.3             pyhd8ed1ab_0    conda-forge
sleef                     3.5.1                h6db0672_2    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
tbb                       2021.11.0            he51d815_0    conda-forge
tk                        8.6.13               h1abcd95_1    conda-forge
typing_extensions         4.9.0              pyha770c72_0    conda-forge
tzdata                    2023d                h0c530f3_0    conda-forge
wheel                     0.42.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h775f41a_0    conda-forge

I'm new to PyTorch, so let me know what I'm doing wrong!

SimonBoothroyd commented 9 months ago

@corinwagen I had the same issue with the ensemble models, but managed to get it to work by switching to a single model (e.g. models/aimnet2_wb97m-d3_1.jpt).
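Switching to a single model works because its energy output stays attached to coord. Below is a minimal sketch of a reusable helper, assuming the single model accepts the same input dict (coord, numbers, charge) and returns an 'energy' entry, as in the reproducer at the top of the issue; the name energy_and_forces is hypothetical and not part of the AIMNet2 repo.

```python
import torch

def energy_and_forces(model, coord, numbers, charge):
    """Return (energy, forces) for a model whose output keeps the autograd graph.

    Assumes `model` takes a dict with 'coord', 'numbers' and 'charge' and
    returns a dict containing 'energy' (hypothetical helper, not repo API).
    """
    # Work on a fresh leaf tensor so repeated calls don't share graph state.
    coord = coord.detach().clone().requires_grad_(True)
    out = model(dict(coord=coord, numbers=numbers, charge=charge))
    energy = out["energy"]
    # Forces are the negative gradient of the energy w.r.t. the coordinates;
    # summing first makes this work for a batch of molecules in one call.
    (grad,) = torch.autograd.grad(energy.sum(), coord)
    return energy.detach(), -grad
```

With a real model you would load e.g. models/aimnet2_wb97m-d3_1.jpt via torch.jit.load as in the original script, which also wraps the call in torch.jit.optimized_execution(False); keeping that wrapper around the helper call is probably the safest choice.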

corinwagen commented 9 months ago

Thanks, I appreciate it. Is the ensemble significantly more accurate, or are the single models a decent replacement?

zubatyuk commented 9 months ago

The compiled ensemble models return forces, but they do not keep the graph needed to calculate higher-order derivatives; this is for computational efficiency. I added the code for the ensemble models, so if you need the Hessian, you can recompile the ensemble with detach=False. https://github.com/isayevlab/AIMNet2/commit/fc671d8747ccb84a8630339bc59fc99fbbfba5a5
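For a Hessian, the energy must stay attached to coord, and the first gradient must itself be built with create_graph=True so it can be differentiated a second time. Below is a minimal sketch, assuming a model that keeps the graph (a single model, or an ensemble recompiled with detach=False) and the same input dict as the reproducer at the top of the issue. The name energy_hessian is hypothetical, and the row-by-row loop is just the simplest approach, not necessarily what AIMNet2's own scripts do.

```python
import torch

def energy_hessian(model, coord, numbers, charge):
    """Return the (3N, 3N) Hessian of the energy of a single molecule.

    Assumes `model` keeps the autograd graph and uses the dict interface
    from the reproducer (hypothetical helper, not repo API).
    """
    coord = coord.detach().clone().requires_grad_(True)
    energy = model(dict(coord=coord, numbers=numbers, charge=charge))["energy"].sum()
    # create_graph=True records the gradient computation itself,
    # so each gradient component can be differentiated again.
    (grad,) = torch.autograd.grad(energy, coord, create_graph=True)
    grad = grad.flatten()
    rows = []
    for i in range(grad.numel()):
        # d(grad_i)/d(coord) is row i of the Hessian; retain_graph keeps
        # the graph alive for the remaining rows.
        (row,) = torch.autograd.grad(grad[i], coord, retain_graph=True)
        rows.append(row.flatten())
    return torch.stack(rows)
```

This costs one extra backward pass per Cartesian coordinate (3N for N atoms); torch.autograd.functional.hessian is a built-in alternative that computes the same matrix from a scalar-valued function.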