openmm / openmm-torch

OpenMM plugin to define forces with neural networks

CUDA out of memory #99

Closed: Yangxinsix closed this issue 1 year ago

Yangxinsix commented 1 year ago

I'm trying to follow the NNOPs tutorial, but it fails at the second step: it installs nothing and always crashes my Colab session.

I then followed the tutorial on my own laptop to create a PyTorch force field myself. The installation finally worked, but the simulation fails with a CUDA out-of-memory error before running even a single step.

I tried an extremely small model with only ~100 parameters, and I checked for an accumulating computational graph by running it more than 100 times, but I still get this error.

I'm not sure whether the plugin has a memory leak or whether OpenMM simply needs too much memory. Could you help me check?

Thanks a lot.

This is the error message:

Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 25, in forward
    input_dict = {"pairs": _2, "n_diff": _3, "n_dist": _4, "num_atoms": _5, "num_pairs": _6, "elems": elems}
    model = self.model
    output = (model).forward(input_dict, True, )
              ~~~~~~~~~~~~~~ <--- HERE
    _7 = (torch.detach(output["energy"]), torch.detach(output["forces"]))
    return _7
  File "code/__torch__/PaiNN/model.py", line 53, in forward
    _11 = getattr(update_layers, "1")
    _21 = getattr(update_layers, "2")
    _8 = (_00).forward(node_scalar, node_vector, edge0, edge_diff, edge_dist, )
          ~~~~~~~~~~~~ <--- HERE
...
    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)
               ~~~~~~~~ <--- HERE
RuntimeError: CUDA out of memory. Tried to allocate 4.96 GiB (GPU 0; 4.00 GiB total capacity; 903.65 MiB already allocated; 1.89 GiB free; 1.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This is my conda environment:

# packages in environment at /home/xinyang/miniconda3/envs/openmm:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
asap3                     3.12.12                  pypi_0    pypi
ase                       3.22.1                   pypi_0    pypi
asttokens                 2.2.1              pyhd8ed1ab_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                pyhd8ed1ab_3    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2022.12.7            ha878542_0    conda-forge
cffi                      1.15.1          py311h409f033_3    conda-forge
comm                      0.1.2              pyhd8ed1ab_0    conda-forge
contourpy                 1.0.7                    pypi_0    pypi
cudatoolkit               11.8.0              h37601d7_11    conda-forge
cudnn                     8.4.1.50             hed8a83a_0    conda-forge
cycler                    0.11.0                   pypi_0    pypi
debugpy                   1.6.6           py311hcafe171_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
executing                 1.2.0              pyhd8ed1ab_0    conda-forge
fonttools                 4.39.2                   pypi_0    pypi
icu                       70.1                 h27087fc_0    conda-forge
importlib-metadata        6.0.0              pyha770c72_0    conda-forge
importlib_metadata        6.0.0                hd8ed1ab_0    conda-forge
ipykernel                 6.21.3             pyh210e3f2_0    conda-forge
ipython                   8.11.0             pyh41d4057_0    conda-forge
jedi                      0.18.2             pyhd8ed1ab_0    conda-forge
jupyter_client            8.0.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.3.0           py311h38be061_0    conda-forge
kiwisolver                1.4.4                    pypi_0    pypi
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libblas                   3.9.0           16_linux64_openblas    conda-forge
libcblas                  3.9.0           16_linux64_openblas    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgfortran-ng            12.2.0              h69a702a_19    conda-forge
libgfortran5              12.2.0              h337968e_19    conda-forge
libhwloc                  2.9.0                hd6dc26d_0    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
liblapack                 3.9.0           16_linux64_openblas    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libopenblas               0.3.21          pthreads_h78a6416_3    conda-forge
libprotobuf               3.21.12              h3eb15da_0    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libsqlite                 3.40.0               h753d276_0    conda-forge
libstdcxx-ng              12.2.0              h46fd767_19    conda-forge
libuuid                   2.32.1            h7f98852_1000    conda-forge
libxml2                   2.10.3               hca2bb57_3    conda-forge
libzlib                   1.2.13               h166bdaf_4    conda-forge
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
magma                     2.6.2                hc72dce7_0    conda-forge
matplotlib                3.7.1                    pypi_0    pypi
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
nccl                      2.14.3.1             h0800d71_0    conda-forge
ncurses                   6.3                  h27087fc_1    conda-forge
nest-asyncio              1.5.6              pyhd8ed1ab_0    conda-forge
ninja                     1.11.1               h924138e_0    conda-forge
numpy                     1.24.2          py311h8e6699e_0    conda-forge
ocl-icd                   2.3.1                h7f98852_0    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openmm                    8.0.0           py311h59c6c42_0    conda-forge
openmm-torch              1.0             cuda112py311hfd30f1a_0    conda-forge
openssl                   3.1.0                h0b41bf4_0    conda-forge
packaging                 23.0               pyhd8ed1ab_0    conda-forge
painn                     1.0.0                    pypi_0    pypi
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    9.4.0                    pypi_0    pypi
pip                       23.0.1             pyhd8ed1ab_0    conda-forge
platformdirs              3.1.1              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.38             pyha770c72_0    conda-forge
prompt_toolkit            3.0.38               hd8ed1ab_0    conda-forge
psutil                    5.9.4           py311hd4cff14_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pygments                  2.14.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.0.9                    pypi_0    pypi
python                    3.11.0          he550d4f_1_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python_abi                3.11                    3_cp311    conda-forge
pytorch                   1.13.1          cuda112py311h13fee9e_200    conda-forge
pytorch-gpu               1.13.1          cuda112py311h7c68dbd_200    conda-forge
pyzmq                     25.0.1          py311hd6ccaeb_0    conda-forge
readline                  8.1.2                h0f457ee_0    conda-forge
scipy                     1.10.1                   pypi_0    pypi
setuptools                67.6.0             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
tbb                       2021.8.0             hf52228f_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
toml                      0.10.2                   pypi_0    pypi
tornado                   6.2             py311hd4cff14_1    conda-forge
traitlets                 5.9.0              pyhd8ed1ab_0    conda-forge
typing-extensions         4.5.0                hd8ed1ab_0    conda-forge
typing_extensions         4.5.0              pyha770c72_0    conda-forge
tzdata                    2022g                h191b570_0    conda-forge
wcwidth                   0.2.6              pyhd8ed1ab_0    conda-forge
wheel                     0.40.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zeromq                    4.3.4                h9c3ff4c_1    conda-forge
zipp                      3.15.0             pyhd8ed1ab_0    conda-forge

peastman commented 1 year ago

Which tutorial are you running? Can you provide the exact code that leads to the error?

Yangxinsix commented 1 year ago
from openmmtorch import TorchForce
import sys
import openmm
from openmm import LangevinMiddleIntegrator
from openmm.app import Simulation, StateDataReporter, Topology, Modeller
from openmm import unit
from openmm.app.element import Element

import torch
from torch import nn
from PaiNN.data import NeighborList
from PaiNN.model import PainnModel
from ase.io import read, write
import numpy as np

# create simulation system
atoms = read('work/dataset/corrected_ads_images.traj', 100)

pos = atoms.get_positions() / 10  # ASE positions are in Å; OpenMM uses nm
box_vectors = atoms.get_cell() / 10  # cell vectors, Å -> nm
elements = atoms.get_chemical_symbols()

# Create a topology object
topology = Topology()

# Add atoms to the topology
chain = topology.addChain()
res = topology.addResidue("mace_system", chain)
for i, (element, position) in enumerate(zip(elements, pos)):
    e = Element.getBySymbol(element)
    topology.addAtom(str(i), e, res)
# if there is a periodic box specified add it to the Topology
if np.all(atoms.pbc):
    topology.setPeriodicBoxVectors(vectors=box_vectors)

# Create a modeller object
modeller = Modeller(topology, pos)

# Create a system object
system = openmm.System()
if topology.getPeriodicBoxVectors() is not None:
    system.setDefaultPeriodicBoxVectors(*topology.getPeriodicBoxVectors())
for atom in topology.atoms():
    if atom.element is None:
        system.addParticle(0)
    else:
        system.addParticle(atom.element.mass)

# Wrapper model for simulation
class PainnOpenmm(nn.Module):
    def __init__(self, elements: torch.Tensor, model: PainnModel) -> None:
        super().__init__()
        self.neigh_list = NeighborList(model.cutoff)
        self.model = model
        self.register_buffer('elems', elements)

    def forward(self, positions: torch.Tensor, cell: torch.Tensor):
        print(f'Device of positions: {positions.device}')
        pairs, pair_diff, pair_dist = self.neigh_list(positions, cell)
        input_dict = {
            'pairs': pairs,
            'n_diff': pair_diff,
            'n_dist': pair_dist,
            'num_atoms': torch.tensor([positions.shape[0]], dtype=pairs.dtype, device=pairs.device),
            'num_pairs': torch.tensor([pairs.shape[0]], dtype=pairs.dtype, device=pairs.device),
            'elems': self.elems,
        }
        output = self.model(input_dict)

        return (output['energy'], output['forces'])

# load trained model
state_dict = torch.load('/work3/xinyang/work/models/ads_images/128_node_3_layer.pth')
model = PainnModel(
    num_interactions=state_dict['num_layer'],
    hidden_state_size=state_dict['node_size'],
    cutoff=state_dict['cutoff'],
    normalization=False,
)
model.load_state_dict(state_dict['model'])

# model deploy
elems = torch.from_numpy(atoms.get_atomic_numbers())
positions = torch.from_numpy(atoms.get_positions()).float()
cell = torch.from_numpy(atoms.cell[:]).float()

openmm_ff = PainnOpenmm(elements=elems, model=model)
openmm_ff.cuda()
torch.jit.script(openmm_ff).save('deployed_model')

# load force field
force = TorchForce('deployed_model')
force.setUsesPeriodicBoundaryConditions(True)
force.setOutputsForces(True)
system.addForce(force)

# set up initial parameters
temperature = 298.15 * unit.kelvin
frictionCoeff = 1 / unit.picosecond
timeStep = 1 * unit.femtosecond
integrator = LangevinMiddleIntegrator(temperature, frictionCoeff, timeStep)

# setup simulations
simulation = Simulation(topology, system, integrator)
simulation.context.setPositions(modeller.getPositions())
reporter = StateDataReporter(file=sys.stdout, reportInterval=1, step=True, time=True, potentialEnergy=True, temperature=True)
simulation.reporters.append(reporter)

All of the code above runs successfully. I also tested MD simulations with the model through ASE, and they run fine for more than 10 million steps. There is no OOM problem there, even when running the model on my own laptop with a 4 GB GPU.

The following two lines raise the CUDA out-of-memory error:

state = simulation.context.getState(getEnergy=True)
simulation.step(100)

I also ran the code above on a Tesla A100 GPU with 40 GB of memory and got the same error:

RuntimeError: CUDA out of memory. Tried to allocate 4.08 GiB (GPU 0; 39.43 GiB total capacity; 34.11 GiB already allocated; 724.31 MiB free; 37.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

So I suspect there is a memory leak somewhere.

sef43 commented 1 year ago

One possible cause of your error is units. ASE uses Angstroms for positions, while OpenMM uses nanometers. I notice you do pos = atoms.get_positions() / 10 in the setup; I assume this converts ASE Angstroms into OpenMM nanometers? In the forward method, however, you do no unit conversions, so the positions OpenMM passes into forward will be in nanometers. Is this what the model expects, or should they be converted into Angstroms? Also check the energy and force units; you may need to add conversions in the forward method. OpenMM uses kJ/mol for energy (and kJ/mol/nm for forces). What units does your model use?

(Edit: I initially wrote kcal/mol by mistake. OpenMM's units are listed here: http://docs.openmm.org/latest/userguide/theory/01_introduction.html#units)
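A minimal sketch of the conversions described above, assuming the model was trained through ASE and therefore works in Å and eV (returning forces in eV/Å), while OpenMM supplies nm and expects kJ/mol and kJ/mol/nm back. It is written with NumPy rather than TorchScript so that it stays self-contained; `wrap_for_openmm` and `harmonic_toy_model` are hypothetical names. In a real `TorchForce` module the same two scale factors would be applied inside `forward`.

```python
import numpy as np

# Conversion factors between ASE-style units (Å, eV) and OpenMM units (nm, kJ/mol).
ANGSTROM_PER_NM = 10.0
KJ_PER_MOL_PER_EV = 96.48533212  # 1 eV = e * N_A / 1000 kJ/mol


def wrap_for_openmm(model_ev_angstrom):
    """Adapt a model that works in Å/eV so it accepts and returns OpenMM units.

    `model_ev_angstrom` (hypothetical) takes positions and cell in Å and
    returns (energy in eV, forces in eV/Å).
    """
    def forward(positions_nm, cell_nm):
        # OpenMM passes coordinates in nm; the model expects Å.
        energy_ev, forces_ev_per_a = model_ev_angstrom(
            positions_nm * ANGSTROM_PER_NM, cell_nm * ANGSTROM_PER_NM
        )
        # eV -> kJ/mol
        energy_kj = energy_ev * KJ_PER_MOL_PER_EV
        # eV/Å -> kJ/mol/nm: convert the energy unit, then 1/Å = 10/nm
        forces_kj = forces_ev_per_a * KJ_PER_MOL_PER_EV * ANGSTROM_PER_NM
        return energy_kj, forces_kj

    return forward


def harmonic_toy_model(positions_a, cell_a):
    """Hypothetical stand-in model: E = 0.5 * k * sum(r^2), k in eV/Å^2."""
    k = 1.0
    energy = 0.5 * k * np.sum(positions_a ** 2)
    forces = -k * positions_a
    return energy, forces


openmm_model = wrap_for_openmm(harmonic_toy_model)
positions_nm = np.array([[0.1, 0.0, 0.0]])  # one atom 0.1 nm (= 1 Å) from the origin
cell_nm = np.eye(3) * 2.0
energy_kj, forces_kj = openmm_model(positions_nm, cell_nm)
```

Because forces must be the negative gradient of the energy with respect to the nm coordinates, the force factor is the energy factor times 10, not the energy factor alone.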

Yangxinsix commented 1 year ago

Thanks a lot for the explanation! The error was indeed due to units: with positions in nm, the same numerical cutoff covers ten times the real distance, so the constructed neighbor list became much larger and the model requested much more memory.
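To illustrate why the neighbor list grows so much, here is a small sketch that counts atom pairs within a fixed numerical cutoff for the same simple cubic lattice expressed in Å and in nm. The 2.5 Å spacing, the 8×8×8 size and the cutoff of 5.0 are made-up illustrative values, not taken from the PaiNN model, and the brute-force count ignores periodic images.

```python
import numpy as np


def count_pairs(positions, cutoff):
    """Count unique atom pairs closer than `cutoff` (brute force, no periodic images)."""
    diff = positions[:, None, :] - positions[None, :, :]
    d2 = np.sum(diff ** 2, axis=-1)
    # Upper triangle (k=1) drops self-pairs and counts each pair once.
    return int(np.count_nonzero(np.triu(d2 < cutoff ** 2, k=1)))


# Simple cubic lattice, 8 x 8 x 8 atoms, spacing 2.5 Å (illustrative values).
grid = np.arange(8) * 2.5
pos_angstrom = np.array(np.meshgrid(grid, grid, grid, indexing="ij")).reshape(3, -1).T
pos_nm = pos_angstrom / 10.0  # same geometry, expressed in nm

cutoff = 5.0  # the model's cutoff value, intended to be in Å

pairs_angstrom = count_pairs(pos_angstrom, cutoff)
pairs_nm = count_pairs(pos_nm, cutoff)
print(pairs_angstrom, pairs_nm)  # 5068 130816
```

In this small lattice every pair falls inside the cutoff once the coordinates are in nm, so the count saturates at N(N-1)/2. In a larger or periodic system the number of neighbors per atom grows roughly with the cube of the 10× scale factor, and the per-edge tensors in the model grow with it.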