openmm / openmm-torch

OpenMM plugin to define forces with neural networks

OpenMMException when using TorchForce #137

Closed: hongshuh closed this issue 7 months ago

hongshuh commented 8 months ago

I am using TorchForce to simulate a box of 258 water molecules. The Torch model produces an energy output with the correct shape of 1, but when I add the TorchForce to the system, I get a tensor shape error.

The code is as follows:

from openmm import unit
from openmm.app import Simulation
from openmmtools import integrators, testsystems
from openmmtorch import TorchForce

BOX_SCALE = 2
DT = 2

waterbox = testsystems.WaterBox(
    box_edge=2 * unit.nanometers,
    model='tip3p')
[topology, system, positions] = [waterbox.topology, waterbox.system, waterbox.positions]

p_num = positions.shape[0] // 3  # number of water molecules (3 atoms each)
timestep = DT * unit.femtoseconds
temperature = 300 * unit.kelvin
chain_length = 10
friction = 1. / unit.picosecond
num_mts = 5
num_yoshidasuzuki = 5

integrator = integrators.NoseHooverChainVelocityVerletIntegrator(system,
                                                                    temperature,
                                                                    friction,
                                                                    timestep, chain_length, num_mts, num_yoshidasuzuki)
for f in system.getForces():
    print(f)

# Remove all classical forces so that only the neural network potential remains
while system.getNumForces() > 0:
    system.removeForce(0)
    print('Remove forces in the system')
force = TorchForce('model.pt')
force.setOutputsForces(False)  # the model returns only the energy; forces come from autograd
system.addForce(force)
simulation = Simulation(topology, system, integrator)
simulation.context.setPositions(positions)
simulation.context.setVelocitiesToTemperature(temperature)
simulation.minimizeEnergy(tolerance=1*unit.kilojoule/unit.mole)

And here is the error message:

Traceback (most recent call last):
  File "/home/hongshuh/Denoise-Pretrain-ML-Potential/run_nnp_water.py", line 79, in <module>
    simulation.minimizeEnergy(tolerance=1*unit.kilojoule/unit.mole)
  File "/home/hongshuh/anaconda3/envs/openmm-ml/lib/python3.9/site-packages/openmm/app/simulation.py", line 137, in minimizeEnergy
    mm.LocalEnergyMinimizer.minimize(self.context, tolerance, maxIterations)
  File "/home/hongshuh/anaconda3/envs/openmm-ml/lib/python3.9/site-packages/openmm/openmm.py", line 17208, in minimize
    return _openmm.LocalEnergyMinimizer_minimize(context, tolerance, maxIterations)
openmm.OpenMMException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The size of tensor a (20898) must match the size of tensor b (3) at non-singleton dimension 1
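For reference, the openmm-torch README describes the expected model as a serialized TorchScript module whose `forward` receives the particle positions as an `(N, 3)` tensor in nanometers and returns the potential energy in kJ/mol; with `setUsesPeriodicBoundaryConditions(True)`, `forward` also receives the 3x3 box vectors (not used in this sketch). A minimal sketch of that interface, using a hypothetical harmonic `ToyPotential` as a stand-in for the real network and reloading the file with `torch.jit.load` instead of `TorchForce`:

```python
import os
import tempfile

import torch


class ToyPotential(torch.nn.Module):
    """Hypothetical stand-in for the real network: harmonic restraint to the origin."""

    def __init__(self, k: float = 1.0):
        super().__init__()
        self.k = k

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # positions: (N, 3) tensor in nm; returns a 0-dim energy tensor (kJ/mol)
        return torch.sum(positions ** 2) * self.k


# Script (rather than trace) so the module does not depend on one example input
module = torch.jit.script(ToyPotential(k=2.0))
path = os.path.join(tempfile.mkdtemp(), "model.pt")
module.save(path)

# TorchForce(path) would load this file; here it is reloaded with plain torch
loaded = torch.jit.load(path)
energy = loaded(torch.rand(7, 3))
print(energy.shape)  # torch.Size([])
```

Returning a 0-dim tensor keeps the energy unambiguous regardless of how many atoms are passed in.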
RaulPPelaez commented 8 months ago

To me this looks like an internal error in the model. I have seen this error before, but I do not remember exactly what caused it. The error probably originates during autograd backpropagation.

Could you share more details about the model and your environment (the output of conda list)?
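As far as I understand openmm-torch, with `setOutputsForces(False)` the forces are obtained by backpropagating the returned energy to the positions (F = -dE/dx), so a shape bug in the model's backward pass surfaces as an error like the one above. A minimal sketch for reproducing that step outside OpenMM; `ToyPairPotential` and `autograd_forces` are hypothetical, and several atom counts are tried because the real water box is much larger than a small test input:

```python
import torch


class ToyPairPotential(torch.nn.Module):
    """Hypothetical stand-in: sum of inverse squared pair distances (no cutoff, no PBC)."""

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        n = positions.shape[0]
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)  # (N, N, 3)
        r2 = (diff ** 2).sum(-1)                                  # (N, N)
        mask = ~torch.eye(n, dtype=torch.bool, device=positions.device)
        # Each pair appears twice in the off-diagonal entries, hence the / 2
        return r2[mask].reciprocal().sum() / 2


def autograd_forces(model: torch.nn.Module, n_atoms: int) -> torch.Tensor:
    # Mimic the autograd step: energy -> backward -> forces = -dE/dx
    positions = torch.rand(n_atoms, 3, dtype=torch.float32, requires_grad=True)
    energy = model(positions)
    energy.backward()
    return -positions.grad


model = torch.jit.script(ToyPairPotential())
for n in (10, 50, 774):
    forces = autograd_forces(model, n)
    assert forces.shape == (n, 3), f"bad force shape {tuple(forces.shape)} for {n} atoms"
```

Swapping the real (scripted or traced) model in for `ToyPairPotential` would show whether the backward pass alone reproduces the shape mismatch.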

hongshuh commented 8 months ago

I'm using an EGNN model, modified to handle periodic boundary conditions, to predict the energy; the PBC support is the only change from the original EGNN model. I save the computation graph as follows:

import torch

force_model = NeuralPotential(config).to(device)
force_model.load_pretrained_ckpt()
# Trace with a random 10-atom example input
module = torch.jit.trace(force_model, torch.rand(10, 3).to(device))
module.save('model.pt')
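Because `torch.jit.trace` records only the operations executed for the single 10-atom example, it may be worth comparing the traced module against the eager model at other atom counts before handing `model.pt` to TorchForce. A minimal sketch of such a check; `compare_traced_to_eager` and the trace-friendly toy model are hypothetical, and a real EGNN would also need its neighbor/graph construction exercised at each size:

```python
import torch


class ToyPairPotential(torch.nn.Module):
    """Hypothetical stand-in for the real network (no Python control flow)."""

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)
        r2 = (diff ** 2).sum(-1) + 1.0  # +1 avoids division by zero on the diagonal
        return r2.reciprocal().sum()


def compare_traced_to_eager(model, traced, sizes):
    """Map each atom count to whether traced and eager energies agree."""
    agree = {}
    for n in sizes:
        x = torch.rand(n, 3)
        with torch.no_grad():
            agree[n] = torch.allclose(traced(x), model(x), rtol=1e-5, atol=1e-6)
    return agree


model = ToyPairPotential()
traced = torch.jit.trace(model, torch.rand(10, 3))  # 10-atom example, as above
print(compare_traced_to_eager(model, traced, sizes=(10, 100, 774)))
```

Agreement at the trace size alone says little; a mismatch at any other size suggests the trace baked in something size- or value-dependent.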

I've checked that the output tensor looks like this: tensor([[0.7504]], device='cuda:2', grad_fn=<AddmmBackward0>)

My environment can successfully run simulations with TorchANI's ANI-2x, so I believe the setup is correct.

# packages in environment at /home/hongshuh/anaconda3/envs/openmm-ml:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
aom                       3.5.0                h27087fc_0    conda-forge
appdirs                   1.4.4                    pypi_0    pypi
ase                       3.22.1             pyhd8ed1ab_1    conda-forge
asttokens                 2.2.1                    pypi_0    pypi
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd4edc92_1    conda-forge
aws-c-auth                0.7.3                he2921ad_3    conda-forge
aws-c-cal                 0.6.2                hc309b26_0    conda-forge
aws-c-common              0.9.0                hd590300_0    conda-forge
aws-c-compression         0.2.17               h4d4d85c_2    conda-forge
aws-c-event-stream        0.3.2                h2e3709c_0    conda-forge
aws-c-http                0.7.12               hc865f51_1    conda-forge
aws-c-io                  0.13.32              h019f825_2    conda-forge
aws-c-mqtt                0.9.5                h3a0376c_1    conda-forge
aws-c-s3                  0.3.14               h1678ad6_3    conda-forge
aws-c-sdkutils            0.1.12               h4d4d85c_1    conda-forge
aws-checksums             0.1.17               h4d4d85c_1    conda-forge
aws-crt-cpp               0.23.0               h40cdbb9_5    conda-forge
aws-sdk-cpp               1.10.57             h6f6b8fa_21    conda-forge
backcall                  0.2.0                    pypi_0    pypi
black                     23.7.0                   pypi_0    pypi
blas                      2.117                  openblas    conda-forge
blas-devel                3.9.0           17_linux64_openblas    conda-forge
blinker                   1.6.2              pyhd8ed1ab_0    conda-forge
blosc                     1.21.4               h0f2a231_0    conda-forge
brotli                    1.0.9                h166bdaf_9    conda-forge
brotli-bin                1.0.9                h166bdaf_9    conda-forge
brotli-python             1.0.9            py39h5a03fae_9    conda-forge
brunsli                   0.1                  h9c3ff4c_0    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.19.1               hd590300_0    conda-forge
c-blosc2                  2.10.2               hb4ffafa_0    conda-forge
ca-certificates           2023.11.17           hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cairo                     1.16.0            h0c91306_1017    conda-forge
captum                    0.6.0              pyhd8ed1ab_0    conda-forge
certifi                   2023.11.17         pyhd8ed1ab_0    conda-forge
cftime                    1.6.2            py39h2ae25f5_1    conda-forge
charls                    2.4.2                h59595ed_0    conda-forge
charset-normalizer        3.2.0              pyhd8ed1ab_0    conda-forge
click                     8.1.7           unix_pyh707e725_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
comm                      0.2.0                    pypi_0    pypi
contourpy                 1.1.0            py39h7633fee_0    conda-forge
coverage                  7.3.0            py39hd1e30aa_0    conda-forge
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h4ba93d1_12    conda-forge
cudnn                     8.8.0.121            h838ba91_2    conda-forge
cutlass                   2.9.1                hed8a83a_0    conda-forge
cycler                    0.11.0             pyhd8ed1ab_0    conda-forge
dav1d                     1.2.1                hd590300_0    conda-forge
debugpy                   1.8.0                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
diffusers                 0.25.1                   pypi_0    pypi
exceptiongroup            1.1.3              pyhd8ed1ab_0    conda-forge
executing                 1.2.0                    pypi_0    pypi
expat                     2.5.0                hcb278e6_1    conda-forge
filelock                  3.12.3             pyhd8ed1ab_0    conda-forge
flask                     2.3.3              pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 hab24e00_0    conda-forge
fontconfig                2.14.2               h14ed4e7_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.42.1           py39hd1e30aa_0    conda-forge
freetype                  2.12.1               hca18f0e_1    conda-forge
fribidi                   1.0.10               h36c2ea0_0    conda-forge
fsspec                    2023.12.2                pypi_0    pypi
gdk-pixbuf                2.42.10              h6b639ba_2    conda-forge
gettext                   0.21.1               h27087fc_0    conda-forge
giflib                    5.2.1                h0b41bf4_3    conda-forge
gmp                       6.2.1                h58526e2_0    conda-forge
gmpy2                     2.1.2            py39h376b7d2_1    conda-forge
graphite2                 1.3.13            h58526e2_1001    conda-forge
graphviz                  8.1.0                h28d9a01_0    conda-forge
gtk2                      2.24.33              h90689f9_2    conda-forge
gts                       0.7.6                h977cf35_4    conda-forge
h5py                      3.9.0           nompi_py39h680ca82_101    conda-forge
harfbuzz                  8.1.1                h3d44ed6_1    conda-forge
hdf4                      4.2.15               h501b40f_6    conda-forge
hdf5                      1.14.1          nompi_h4f84152_100    conda-forge
huggingface-hub           0.20.3                   pypi_0    pypi
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.4                pyhd8ed1ab_0    conda-forge
imagecodecs               2023.8.12        py39he027151_0    conda-forge
imageio                   2.31.1             pyh24c5eb1_0    conda-forge
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib-resources       6.0.1              pyhd8ed1ab_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
importlib_resources       6.0.1              pyhd8ed1ab_0    conda-forge
iniconfig                 2.0.0              pyhd8ed1ab_0    conda-forge
ipykernel                 6.27.1                   pypi_0    pypi
ipython                   8.14.0                   pypi_0    pypi
isodate                   0.6.1              pyhd8ed1ab_0    conda-forge
itsdangerous              2.1.2              pyhd8ed1ab_0    conda-forge
jax                       0.4.14             pyhd8ed1ab_1    conda-forge
jaxlib                    0.4.14          cuda112py39h15d8236_201    conda-forge
jedi                      0.19.0                   pypi_0    pypi
jinja2                    3.1.2              pyhd8ed1ab_1    conda-forge
joblib                    1.3.2              pyhd8ed1ab_0    conda-forge
jupyter-client            8.6.0                    pypi_0    pypi
jupyter-core              5.5.0                    pypi_0    pypi
jxrlib                    1.1                  h7f98852_2    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.5            py39h7633fee_0    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
lazy_loader               0.3                pyhd8ed1ab_0    conda-forge
lcms2                     2.15                 haa2dc70_1    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20230125.3      cxx17_h59595ed_0    conda-forge
libaec                    1.0.6                hcb278e6_1    conda-forge
libavif                   0.11.1               h8182462_2    conda-forge
libblas                   3.9.0           17_linux64_openblas    conda-forge
libbrotlicommon           1.0.9                h166bdaf_9    conda-forge
libbrotlidec              1.0.9                h166bdaf_9    conda-forge
libbrotlienc              1.0.9                h166bdaf_9    conda-forge
libcblas                  3.9.0           17_linux64_openblas    conda-forge
libcurl                   8.2.1                hca28451_0    conda-forge
libdeflate                1.18                 h0b41bf4_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.1.0               he5830b7_0    conda-forge
libgd                     2.3.3                h74d50f4_7    conda-forge
libgfortran-ng            13.1.0               h69a702a_0    conda-forge
libgfortran5              13.1.0               h15d22d2_0    conda-forge
libglib                   2.76.4               hebfc3b9_0    conda-forge
libgomp                   13.1.0               he5830b7_0    conda-forge
libgrpc                   1.56.2               h3905398_1    conda-forge
libhwloc                  2.9.2           default_h554bfaf_1009    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
libjpeg-turbo             2.1.5.1              h0b41bf4_0    conda-forge
liblapack                 3.9.0           17_linux64_openblas    conda-forge
liblapacke                3.9.0           17_linux64_openblas    conda-forge
libllvm14                 14.0.6               hcd5def8_4    conda-forge
libmagma                  2.7.1                h09159a4_4    conda-forge
libmagma_sparse           2.7.1                hc72dce7_4    conda-forge
libmlir14                 14.0.6               he0ac6c6_0    conda-forge
libnetcdf                 4.9.2           nompi_h7e745eb_109    conda-forge
libnghttp2                1.52.0               h61bc06f_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libopenblas               0.3.23          pthreads_h80387f5_0    conda-forge
libpng                    1.6.39               h753d276_0    conda-forge
libprotobuf               4.23.3               hd1fb520_0    conda-forge
librsvg                   2.56.3               h98fae49_0    conda-forge
libsqlite                 3.43.0               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-ng              13.1.0               hfd8a6a1_0    conda-forge
libtiff                   4.5.1                h8b53f26_1    conda-forge
libtool                   2.4.7                h27087fc_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.44.2               hd590300_1    conda-forge
libwebp                   1.3.1                hbf2b3c1_0    conda-forge
libwebp-base              1.3.1                hd590300_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxml2                   2.11.5               h232c23b_1    conda-forge
libzip                    1.10.1               h2629f0a_0    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
libzopfli                 1.0.3                h9c3ff4c_0    conda-forge
lit                       16.0.6             pyh1a96a4e_2    conda-forge
llvm-openmp               16.0.6               h4dfa4b3_0    conda-forge
llvmlite                  0.40.1           py39h174d805_0    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
lzo                       2.10              h516909a_1000    conda-forge
magma                     2.7.1                ha770c72_4    conda-forge
markupsafe                2.1.3            py39hd1e30aa_0    conda-forge
matplotlib-base           3.7.2            py39h0126182_0    conda-forge
matplotlib-inline         0.1.6                    pypi_0    pypi
mdtraj                    1.9.9            py39h031bd0f_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
ml_dtypes                 0.2.0            py39h40cae4c_1    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.0                hb012696_0    conda-forge
mpiplus                   v0.0.2             pyhd8ed1ab_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
mypy-extensions           1.0.0                    pypi_0    pypi
nccl                      2.18.5.1             h0800d71_0    conda-forge
ncurses                   6.4                  hcb278e6_0    conda-forge
nest-asyncio              1.5.8                    pypi_0    pypi
netcdf4                   1.6.4           nompi_py39h4218a78_101    conda-forge
networkx                  3.1                pyhd8ed1ab_0    conda-forge
nnpops                    0.6             cuda112py39h3ef88e3_1    conda-forge
nomkl                     3.0                           0  
normflows                 1.7.3                    pypi_0    pypi
nose                      1.3.7                   py_1006    conda-forge
numba                     0.57.1           py39hb75a051_0    conda-forge
numexpr                   2.8.4           py39h8825413_101    conda-forge
numpy                     1.22.4                   pypi_0    pypi
ocl-icd                   2.3.1                h7f98852_0    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openblas                  0.3.23          pthreads_h855a84d_0    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openmm                    8.0.0            py39h7d85326_1    conda-forge
openmm-ml                 1.1                pyhd8ed1ab_0    conda-forge
openmm-torch              1.1             cuda112py39h2f84e51_0    conda-forge
openmmtools               0.23.1             pyhd8ed1ab_0    conda-forge
openssl                   3.2.0                hd590300_1    conda-forge
opt_einsum                3.3.0              pyhd8ed1ab_1    conda-forge
packaging                 23.1               pyhd8ed1ab_0    conda-forge
pandas                    2.1.0            py39hddac248_0    conda-forge
pango                     1.50.14              ha41ecd1_2    conda-forge
parallel-hashmap          1.33                 hca92ed8_0    conda-forge
parso                     0.8.3                    pypi_0    pypi
pathspec                  0.11.2                   pypi_0    pypi
patsy                     0.5.3              pyhd8ed1ab_0    conda-forge
pcre2                     10.40                hc3806b6_0    conda-forge
pdbfixer                  1.9                pyh1a96a4e_0    conda-forge
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pillow                    10.0.0           py39haaeba84_0    conda-forge
pip                       23.2.1             pyhd8ed1ab_0    conda-forge
pixman                    0.40.0               h36c2ea0_0    conda-forge
platformdirs              3.10.0             pyhd8ed1ab_0    conda-forge
pluggy                    1.3.0              pyhd8ed1ab_0    conda-forge
pooch                     1.7.0              pyha770c72_3    conda-forge
prompt-toolkit            3.0.39                   pypi_0    pypi
psutil                    5.9.5            py39h72bdee0_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0                    pypi_0    pypi
pure-eval                 0.2.2                    pypi_0    pypi
py-cpuinfo                9.0.0              pyhd8ed1ab_0    conda-forge
pyg-lib                   0.2.0           cuda112py39hcc031ee_2    conda-forge
pygments                  2.16.1                   pypi_0    pypi
pymbar                    4.0.2                hf3d152e_0    conda-forge
pymbar-core               4.0.2            py39h0f8d45d_0    conda-forge
pyparsing                 3.0.9              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytables                  3.8.0            py39hb8e3aad_2    conda-forge
pytest                    7.4.0              pyhd8ed1ab_0    conda-forge
pytest-cov                4.1.0              pyhd8ed1ab_0    conda-forge
python                    3.9.18          h0755675_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-tzdata             2023.3             pyhd8ed1ab_0    conda-forge
python-wget               3.2                        py_0    conda-forge
python_abi                3.9                      3_cp39    conda-forge
pytorch                   2.0.0           cuda112py39ha9a2dba_301    conda-forge
pytorch_cluster           1.6.1            py39hcc031ee_2    conda-forge
pytorch_geometric         2.3.1              pyhd8ed1ab_0    conda-forge
pytorch_scatter           2.1.1           cuda112py39hcc031ee_1    conda-forge
pytorch_sparse            0.6.17           py39hcc031ee_1    conda-forge
pytz                      2023.3             pyhd8ed1ab_0    conda-forge
pywavelets                1.4.1            py39h389d5f1_0    conda-forge
pyyaml                    6.0.1            py39hd1e30aa_0    conda-forge
pyzmq                     25.1.1                   pypi_0    pypi
rdflib                    7.0.0              pyhd8ed1ab_0    conda-forge
re2                       2023.03.02           h8c504da_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
regex                     2023.12.25               pypi_0    pypi
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
s2n                       1.3.49               h06160fa_0    conda-forge
safetensors               0.4.2                    pypi_0    pypi
scikit-image              0.21.0           py39h3d6467e_0    conda-forge
scikit-learn              1.3.0            py39hc236052_0    conda-forge
scipy                     1.11.2           py39h6183b62_0    conda-forge
seaborn                   0.13.0                   pypi_0    pypi
setuptools                65.3.0             pyhd8ed1ab_1    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
snappy                    1.1.10               h9fff704_0    conda-forge
stack-data                0.6.2                    pypi_0    pypi
statsmodels               0.14.0           py39h0f8d45d_1    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
tabulate                  0.9.0              pyhd8ed1ab_1    conda-forge
tbb                       2021.10.0            h00ab1b0_0    conda-forge
threadpoolctl             3.2.0              pyha21a80b_0    conda-forge
tifffile                  2023.8.30          pyhd8ed1ab_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
tokenize-rt               5.2.0                    pypi_0    pypi
toml                      0.10.2             pyhd8ed1ab_0    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torch-nl                  0.3                      pypi_0    pypi
torchani                  2.2.3           cuda112py39h2d59c9c_3    conda-forge
tornado                   6.4                      pypi_0    pypi
tqdm                      4.66.1             pyhd8ed1ab_0    conda-forge
traitlets                 5.9.0                    pypi_0    pypi
trimesh                   3.23.5             pyhd8ed1ab_0    conda-forge
triton                    2.0.0           cuda112py39hafd0abc_1    conda-forge
typing-extensions         4.7.1                hd8ed1ab_0    conda-forge
typing_extensions         4.7.1              pyha770c72_0    conda-forge
tzdata                    2023c                h71feb2d_0    conda-forge
unicodedata2              15.0.0           py39hb9d737c_0    conda-forge
urllib3                   2.0.4              pyhd8ed1ab_0    conda-forge
wandb                     0.15.9                   pypi_0    pypi
wcwidth                   0.2.6                    pypi_0    pypi
werkzeug                  2.3.7              pyhd8ed1ab_0    conda-forge
wheel                     0.41.2             pyhd8ed1ab_0    conda-forge
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.1.1                hd590300_0    conda-forge
xorg-libsm                1.2.4                h7391055_0    conda-forge
xorg-libx11               1.8.6                h8ee46fc_0    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h0b41bf4_2    conda-forge
xorg-libxrender           0.9.11               hd590300_0    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h0b41bf4_1003    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zfp                       1.0.0                h27087fc_3    conda-forge
zipp                      3.16.2             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zlib-ng                   2.0.7                h0b41bf4_0    conda-forge
zstd                      1.5.5                hfc55251_0    conda-forge
RaulPPelaez commented 8 months ago

Try calling torch.jit.script instead of torch.jit.trace.

hongshuh commented 8 months ago

When using torch.jit.script, I get a different error, which does not show up when using trace:

Traceback (most recent call last):
  File "/home/hongshuh/Denoise-Pretrain-ML-Potential/md_wrapper.py", line 103, in <module>
    module = torch.jit.script(force_model) 
  File "/home/hongshuh/anaconda3/envs/openmm-ml/lib/python3.9/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/home/hongshuh/anaconda3/envs/openmm-ml/lib/python3.9/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/hongshuh/anaconda3/envs/openmm-ml/lib/python3.9/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/home/hongshuh/anaconda3/envs/openmm-ml/lib/python3.9/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
Module 'EGNN' has no attribute '_modules' :
  File "/home/hongshuh/Denoise-Pretrain-ML-Potential/models/egnn.py", line 244
        for i in range(0, self.n_layers):
            if self.use_pbc:
                h, x, edge_attr = self._modules["gcl_%d" % i](h, edge_index, x, edge_attr=edge_attr, cell_size=cell_size, batch=batch)
                                  ~~~~~~~~~~~~~ <--- HERE
            else:
                h, x, _ = self._modules["gcl_%d" % i](h, edge_index, x, edge_attr=edge_attr)
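TorchScript cannot compile the string lookup `self._modules["gcl_%d" % i]`. A common TorchScript-compatible pattern is to store the layers in a `torch.nn.ModuleList` and iterate over it directly. A minimal sketch; `GCLStub` and `EGNNStub` are hypothetical and much simpler than the real EGNN layers, which also take `edge_index`, `edge_attr`, `cell_size`, and `batch`:

```python
import torch


class GCLStub(torch.nn.Module):
    """Hypothetical stand-in for one EGNN graph convolution layer."""

    def __init__(self, hidden: int):
        super().__init__()
        self.lin = torch.nn.Linear(hidden, hidden)

    def forward(self, h: torch.Tensor, x: torch.Tensor):
        return torch.relu(self.lin(h)), x


class EGNNStub(torch.nn.Module):
    def __init__(self, hidden: int = 16, n_layers: int = 4):
        super().__init__()
        # ModuleList replaces add_module("gcl_%d" % i, ...) plus self._modules[...] lookups
        self.layers = torch.nn.ModuleList([GCLStub(hidden) for _ in range(n_layers)])

    def forward(self, h: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:  # TorchScript unrolls iteration over a ModuleList
            h, x = layer(h, x)
        return h.sum()


eager = EGNNStub()
scripted = torch.jit.script(eager)  # compiles, unlike the self._modules[...] version
h, x = torch.rand(5, 16), torch.rand(5, 3)
print(torch.allclose(scripted(h, x), eager(h, x)))
```

Note that this renames the submodules, so `state_dict` keys change from `gcl_0.*` to `layers.0.*`; a pretrained checkpoint would need its keys remapped before loading.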
RaulPPelaez commented 8 months ago

I am not sure this is related to openmm-torch. Does the following run successfully?

import torch

force_model = NeuralPotential(config).to(device)
force_model.load_pretrained_ckpt()
force_model = torch.jit.script(force_model).to("cuda")
N = 10
pos = torch.rand(N, 3, dtype=torch.float32, device="cuda", requires_grad=True)
z = torch.randint(low=0, high=10, size=(N,), dtype=torch.long, device="cuda")
energy = force_model(z, pos)  # Modify to call your model in the expected way
energy.backward()
force = -pos.grad  # forces are the negative energy gradient
hongshuh commented 7 months ago

There is an error at this line:

force_model = torch.jit.script(force_model).to("cuda")

The error might be due to EGNN using add_module to register its convolution layers, which conflicts with torch.jit.script. Instead, I ran the following using trace:

import torch

force_model = NeuralPotential(config).to(device)
force_model.load_pretrained_ckpt()
module = torch.jit.trace(force_model, torch.rand(10, 3).to(device))
N = 10
pos = torch.rand(N, 3, dtype=torch.float32, device="cuda", requires_grad=True)
energy = module(pos)
energy.backward()
force = -pos.grad  # forces are the negative energy gradient

The force tensor has the expected shape of [10, 3].

RaulPPelaez commented 7 months ago

Tracing just runs your model with an example input and records the operations performed on it. I can see that process failing in combination with autograd and, potentially, dynamic shapes (such as a varying number of atoms or neighbors). Your model must be fully compatible with TorchScript to run in openmm-torch.
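One limitation of tracing is easy to reproduce: a branch that depends on tensor values is evaluated once during tracing, and only the branch taken for the example input is recorded (PyTorch emits a `TracerWarning`), whereas `torch.jit.script` compiles both branches. A minimal sketch with a hypothetical `SignDependent` module:

```python
import warnings

import torch


class SignDependent(torch.nn.Module):
    """Hypothetical module whose branch depends on the input values."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x * 2.0
        return -x


model = SignDependent()
pos, neg = torch.ones(4, 3), -torch.ones(4, 3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning about the frozen branch
    traced = torch.jit.trace(model, pos)  # records only the "x * 2.0" branch
scripted = torch.jit.script(model)        # keeps the if/else

print(traced(neg)[0, 0].item())    # -2.0: the wrong branch is replayed
print(scripted(neg)[0, 0].item())  # 1.0, same as model(neg)
```

The traced module silently returns a wrong result rather than raising, which is why a model that only works when traced is risky inside a simulation.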

hongshuh commented 7 months ago

Thanks. It might be because of the EGNN model itself; I will try another model.

RaulPPelaez commented 7 months ago

I am not familiar with this particular model or the implementation you are using. I suggest you open an issue in their repo asking for TorchScript support. This year there have been many TorchScript improvements on the PyTorch side. As a last resort, you could try setting up an environment with PyTorch 2.1 and Python 3.11.