openmm / openmm-ml

High level API for using machine learning models in OpenMM simulations

Update inputDict of MACEForce #86

Closed JMorado closed 3 months ago

JMorado commented 3 months ago

This PR adds the cell key to the input dictionary passed to the MACE model. MACE models do not use cell when calculating energies, so with the default setting, torch._C._set_graph_executor_optimize(True), the current implementation runs fine. However, when openmm-ml is combined with other code that internally sets this flag to False (as in my case), TorchScript fails on the missing key:

    simulation.context.getState(getEnergy=True)
  File "/home/joaomorado/miniconda3/envs/fes-ml-aev/lib/python3.11/site-packages/openmm/openmm.py", line 7977, in getState
    state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openmm.OpenMMException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/openmmml/models/macepotential.py", line 41, in forward
    inputDict = {"ptr": ptr, "node_attrs": node_attrs, "batch": batch, "pbc": pbc, "positions": positions0, "edge_index": edgeIndex, "shifts": shifts}
    model = self.model
    _5 = (model).forward(inputDict, False, False, False, False, False, False, )
          ~~~~~~~~~~~~~~ <--- HERE
    returnEnergyType = self.returnEnergyType
    energy = _5[returnEnergyType]
  File "code/__torch__/mace/modules/models.py", line 96, in forward
    total_energy = torch.add(e0, inter_e)
    node_energy = torch.add(node_e0, node_inter_es0)
    _28 = _3(inter_e, data["positions"], displacement0, data["cell"], training, compute_force, compute_virials, compute_stress, compute_hessian, )
                                                        ~~~~~~~~~~~~ <--- HERE
    forces, virials, stress, hessian, = _28
    output = annotate(Dict[str, Optional[Tensor]], {"energy": total_energy, "node_energy": node_energy, "interaction_energy": inter_e, "forces": forces, "virials": virials, "stress": stress, "hessian": hessian, "displacement": displacement0, "node_feats": node_feats_out})

Traceback of TorchScript, original code (most recent call last):
  File "/home/joaomorado/repos/openmm-ml-myfork/openmmml/models/macepotential.py", line 377, in forward

                # Predict the energy.
                energy = self.model(inputDict, compute_force=False)[
                         ~~~~~~~~~~ <--- HERE
                    self.returnEnergyType
                ]
  File "/home/joaomorado/miniconda3/envs/fes-ml-aev/lib/python3.11/site-packages/mace/modules/models.py", line 401, in forward
            positions=data["positions"],
            displacement=displacement,
            cell=data["cell"],
                 ~~~~~~~~~~~ <--- HERE
            training=training,
            compute_force=compute_force,
RuntimeError: KeyError: cell
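
The traceback shows the lookup data["cell"] being executed even though the value does not affect the energy. As a rough illustration of the fix, the sketch below always includes a cell entry in the dictionary. It uses plain Python lists instead of tensors, shows only a subset of the keys, and the names build_input_dict and energy_like_model are hypothetical stand-ins, not openmm-ml or MACE functions. Filling in a zero matrix when no cell is given is also an assumption made for the example.

```python
from typing import Any, Dict, List, Optional

Matrix = List[List[float]]


def build_input_dict(
    positions: Matrix,
    edge_index: List[List[int]],
    shifts: Matrix,
    cell: Optional[Matrix] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble a model input that always has a "cell" key."""
    if cell is None:
        # Assumption for this sketch: use a 3x3 zero matrix when no box is given.
        cell = [[0.0] * 3 for _ in range(3)]
    return {
        "positions": positions,
        "edge_index": edge_index,
        "shifts": shifts,
        "cell": cell,
    }


def energy_like_model(data: Dict[str, Any]) -> float:
    """Hypothetical stand-in for a model that looks up "cell" but ignores its value."""
    _unused_cell = data["cell"]  # raises KeyError if the key is missing
    return sum(x for p in data["positions"] for x in p)
```

With the key present the stand-in model runs. Deleting the key reproduces a KeyError of the same kind as the one in the traceback.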

The following code reproduces the issue:

import torch, torch._C
import openmm as mm
import openmm.app as app
from openmmml import MLPotential

# Disable the graph executor optimization
torch._C._set_graph_executor_optimize(False)

# Standard setup: classical force field plus the MACE-OFF23 potential
pdb = app.PDBFile('alanine-dipeptide-explicit.pdb')
ff = app.ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
mmSystem = ff.createSystem(pdb.topology, nonbondedMethod=app.PME)
potential = MLPotential('mace-off23-small')
system = potential.createSystem(pdb.topology)
# Atoms of the first chain are treated with the ML potential in the mixed system
mlAtoms = [a.index for a in next(pdb.topology.chains()).atoms()]
mixedSystem = potential.createMixedSystem(pdb.topology, mmSystem, mlAtoms, interpolate=True)
platform = mm.Platform.getPlatformByName("CUDA")
mixedContext = mm.Context(mixedSystem, mm.VerletIntegrator(0.001), platform)
# Evaluating the energy of the pure ML system triggers the TorchScript error
simulation = app.Simulation(pdb.topology, system, mm.VerletIntegrator(0), platform)
simulation.context.setPositions(pdb.positions)
simulation.context.getState(getEnergy=True)

Additionally, this PR changes the code so that inputDict is recreated on every forward pass. MACE does not currently modify this dictionary internally, so reusing it works, but recreating it guards against breakage if MACE's implementation changes. NequIP models, for example, do modify the dictionary they receive, which causes problems after the first forward pass.
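
To illustrate why recreating the dictionary matters, here is a minimal pure-Python sketch. The function mutating_model and the classes CachedInput and FreshInput are hypothetical and exist only for this example. They mimic a model that modifies the dictionary it receives, and are not taken from openmm-ml, MACE, or NequIP.

```python
from typing import Dict, List


def mutating_model(data: Dict[str, List[float]]) -> float:
    """Hypothetical stand-in for a model that modifies its input dictionary."""
    # Remove the raw positions and store a derived quantity in their place.
    positions = data.pop("positions")
    data["scaled"] = [2.0 * p for p in positions]
    return sum(data["scaled"])


class CachedInput:
    """Builds the dictionary once and reuses it on every call."""

    def __init__(self) -> None:
        self.input_dict = {"positions": [1.0, 2.0]}

    def forward(self) -> float:
        # After the first call "positions" has been popped, so a second call fails.
        return mutating_model(self.input_dict)


class FreshInput:
    """Rebuilds the dictionary on every call."""

    def forward(self) -> float:
        input_dict = {"positions": [1.0, 2.0]}
        return mutating_model(input_dict)
```

In this sketch FreshInput can be called repeatedly, while the second call to CachedInput.forward raises a KeyError because the first call removed the positions entry.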

peastman commented 3 months ago

Looks good, thanks!