This PR adds the cell key to the input dictionary passed to the MACE model. MACE does not use cell when computing energies, so with the default setting, torch._C._set_graph_executor_optimize(True), the current implementation runs without errors. However, when openmm-ml is combined with other code that sets this flag to False internally (as in my case), TorchScript raises a KeyError for the missing key:
simulation.context.getState(getEnergy=True)
File "/home/joaomorado/miniconda3/envs/fes-ml-aev/lib/python3.11/site-packages/openmm/openmm.py", line 7977, in getState
state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openmm.OpenMMException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__/openmmml/models/macepotential.py", line 41, in forward
inputDict = {"ptr": ptr, "node_attrs": node_attrs, "batch": batch, "pbc": pbc, "positions": positions0, "edge_index": edgeIndex, "shifts": shifts}
model = self.model
_5 = (model).forward(inputDict, False, False, False, False, False, False, )
~~~~~~~~~~~~~~ <--- HERE
returnEnergyType = self.returnEnergyType
energy = _5[returnEnergyType]
File "code/__torch__/mace/modules/models.py", line 96, in forward
total_energy = torch.add(e0, inter_e)
node_energy = torch.add(node_e0, node_inter_es0)
_28 = _3(inter_e, data["positions"], displacement0, data["cell"], training, compute_force, compute_virials, compute_stress, compute_hessian, )
~~~~~~~~~~~~ <--- HERE
forces, virials, stress, hessian, = _28
output = annotate(Dict[str, Optional[Tensor]], {"energy": total_energy, "node_energy": node_energy, "interaction_energy": inter_e, "forces": forces, "virials": virials, "stress": stress, "hessian": hessian, "displacement": displacement0, "node_feats": node_feats_out})
Traceback of TorchScript, original code (most recent call last):
File "/home/joaomorado/repos/openmm-ml-myfork/openmmml/models/macepotential.py", line 377, in forward
# Predict the energy.
energy = self.model(inputDict, compute_force=False)[
~~~~~~~~~~ <--- HERE
self.returnEnergyType
]
File "/home/joaomorado/miniconda3/envs/fes-ml-aev/lib/python3.11/site-packages/mace/modules/models.py", line 401, in forward
positions=data["positions"],
displacement=displacement,
cell=data["cell"],
~~~~~~~~~~~ <--- HERE
training=training,
compute_force=compute_force,
RuntimeError: KeyError: cell
The following code reproduces the issue:
import torch, torch._C
import openmm as mm
import openmm.app as app
from openmmml import MLPotential
# Disable the graph executor optimization
torch._C._set_graph_executor_optimize(False)
# Standard script
pdb = app.PDBFile('alanine-dipeptide-explicit.pdb')
ff = app.ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
mmSystem = ff.createSystem(pdb.topology, nonbondedMethod=app.PME)
potential = MLPotential('mace-off23-small')
system = potential.createSystem(pdb.topology)
mlAtoms = [a.index for a in next(pdb.topology.chains()).atoms()]
mixedSystem = potential.createMixedSystem(pdb.topology, mmSystem, mlAtoms, interpolate=True)
platform = mm.Platform.getPlatformByName("CUDA")
mixedContext = mm.Context(mixedSystem, mm.VerletIntegrator(0.001), platform)
simulation = app.Simulation(pdb.topology, system, mm.VerletIntegrator(0), platform)
simulation.context.setPositions(pdb.positions)
simulation.context.getState(getEnergy=True)
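The fix itself is small. As a minimal sketch in plain Python (no torch; `build_input_dict` and `toy_forward` are hypothetical stand-ins, not openmm-ml or MACE functions, and the zero matrix is only a placeholder value), the snippet below mimics a model that reads `data["cell"]` unconditionally, as the MACE `forward` in the traceback does, and shows that supplying the key avoids the `KeyError`:

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]


def build_input_dict(positions: Matrix, cell: Optional[Matrix] = None) -> Dict[str, Matrix]:
    """Hypothetical helper: assemble the dictionary handed to the model."""
    data = {"positions": positions}
    if cell is not None:
        # Provide a "cell" entry even though the energy does not depend on it.
        data["cell"] = cell
    return data


def toy_forward(data: Dict[str, Matrix]) -> float:
    """Stand-in for a model that looks up data["cell"] unconditionally."""
    _cell = data["cell"]  # raises KeyError if the key is absent
    # Toy "energy": sum of squared coordinates; the cell is never used.
    return sum(x * x for row in data["positions"] for x in row)


positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
zero_cell = [[0.0] * 3 for _ in range(3)]  # placeholder cell (assumption)

# Without the key, the unconditional lookup fails.
try:
    toy_forward(build_input_dict(positions))
    missing_key_error = False
except KeyError:
    missing_key_error = True

# With the key present, the same call succeeds.
energy = toy_forward(build_input_dict(positions, cell=zero_cell))
```

The sketch says nothing about why the optimizing graph executor tolerates the missing key; it only illustrates that passing the key makes the lookup safe regardless of that setting.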
Additionally, this PR changes the code so that inputDict is recreated on each forward pass. MACE does not currently modify this dictionary, so reusing it works, but rebuilding it guards against future changes in MACE's implementation. NequIP models, for example, modify the dictionary they are passed, which causes errors after the first forward pass.
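To illustrate the hazard of reusing the dictionary, here is a small sketch in plain Python. All names (`mutating_model`, `CachedInputs`, `FreshInputs`) are hypothetical and only mimic the pattern; they are not taken from openmm-ml, MACE, or NequIP. The toy model removes a key from the dictionary it receives, so a wrapper that builds the dictionary once fails on the second call, while a wrapper that rebuilds it on every call keeps working:

```python
from typing import Dict, List


def mutating_model(data: Dict[str, List[float]]) -> float:
    """Hypothetical model that modifies its input dictionary in place."""
    attrs = data.pop("node_attrs")  # removes the key from the caller's dict
    return sum(a * x for a, x in zip(attrs, data["positions"]))


class CachedInputs:
    """Builds the static part of the dictionary once and reuses it."""

    def __init__(self, node_attrs: List[float]) -> None:
        self.input_dict: Dict[str, List[float]] = {"node_attrs": node_attrs}

    def forward(self, positions: List[float]) -> float:
        self.input_dict["positions"] = positions
        return mutating_model(self.input_dict)


class FreshInputs:
    """Rebuilds the dictionary on every forward pass."""

    def __init__(self, node_attrs: List[float]) -> None:
        self.node_attrs = node_attrs

    def forward(self, positions: List[float]) -> float:
        input_dict = {"node_attrs": self.node_attrs, "positions": positions}
        return mutating_model(input_dict)


node_attrs = [1.0, 2.0]
positions = [3.0, 4.0]

cached = CachedInputs(node_attrs)
first = cached.forward(positions)  # works: the key is still present
try:
    cached.forward(positions)  # fails: "node_attrs" was popped on call one
    cached_failed = False
except KeyError:
    cached_failed = True

fresh = FreshInputs(node_attrs)
fresh_results = [fresh.forward(positions), fresh.forward(positions)]
```

Rebuilding a small dictionary per call costs very little compared with a model evaluation, which is why the defensive version is a reasonable default.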