openmm / openmm-ml

High level API for using machine learning models in OpenMM simulations

Failure to remove CMAPTorsionForce from the ML region #68

Closed: JMorado closed this issue 9 months ago

JMorado commented 10 months ago

Hi,

When removing all bonded interactions between atoms within the ML region, the CMAP case (CMAPTorsionForce) does not seem to be handled. This situation arises, for example, when using the Amber ff19SB force field. If I execute the following script:

diala_solvated.zip

import openmm
import openmm.unit as unit
import openmm.app as app
from openmmml import MLPotential

inpcrd = app.AmberInpcrdFile("diala_solvated.inpcrd")
prmtop = app.AmberPrmtopFile("diala_solvated.prmtop")

mm_system = prmtop.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=0.9*unit.nanometer, constraints=app.HBonds)
ml_atoms = [atom.index for res in list(prmtop.topology.residues())[:3] for atom in list(res.atoms())]

potential = MLPotential("ani2x")
ml_system = potential.createMixedSystem(prmtop.topology, mm_system, ml_atoms)

I get the following KeyError:

Traceback (most recent call last):
  File "/home/joaomorado/software/test/test/ani_simulation.py", line 23, in <module>
    ml_system = potential.createMixedSystem(prmtop.topology, mm_system, ml_atoms)
  File "/home/joaomorado/software/openmm/openmmml/mlpotential.py", line 247, in createMixedSystem
    newSystem = self._removeBonds(system, atoms, True, removeConstraints)
  File "/home/joaomorado/software/openmm-ml/openmmml/mlpotential.py", line 390, in _removeBonds
    torsionAtoms = [int(torsion.attrib[p]) for p in ('p1', 'p2', 'p3', 'p4')]
  File "/home/joaomorado/software/openmm-ml/openmmml/mlpotential.py", line 390, in <listcomp>
    torsionAtoms = [int(torsion.attrib[p]) for p in ('p1', 'p2', 'p3', 'p4')]
KeyError: 'p1'
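The traceback is consistent with the CMAP torsions in the serialized System not carrying p1..p4 attributes. A minimal sketch with Python's standard xml.etree.ElementTree reproduces the same KeyError. The XML below is a simplified, hand-written fragment that assumes the layout this issue describes (p1..p4 for ordinary torsions, a1..a4 and b1..b4 for CMAP torsions); it is not verbatim OpenMM serializer output.

```python
import xml.etree.ElementTree as ET

# Simplified, hand-written fragment (assumed layout, not verbatim XmlSerializer output):
# a periodic torsion uses p1..p4, while a CMAP torsion uses a1..a4 and b1..b4.
SYSTEM_XML = """
<System>
  <Forces>
    <Force type="PeriodicTorsionForce">
      <Torsions>
        <Torsion p1="0" p2="1" p3="2" p4="3"/>
      </Torsions>
    </Force>
    <Force type="CMAPTorsionForce">
      <Torsions>
        <Torsion map="0" a1="0" a2="1" a3="2" a4="3" b1="1" b2="2" b3="3" b4="4"/>
      </Torsions>
    </Force>
  </Forces>
</System>
"""

root = ET.fromstring(SYSTEM_XML)


def read_p_labels(torsion):
    """Read atom indices the way the original loop does (p1..p4 only)."""
    return [int(torsion.attrib[p]) for p in ("p1", "p2", "p3", "p4")]


failures = []
for torsion in root.findall("./Forces/Force/Torsions/Torsion"):
    try:
        read_p_labels(torsion)
    except KeyError as err:
        # Only the CMAP entry lacks 'p1', so only it ends up here.
        failures.append(err.args[0])

print(failures)  # ['p1']
```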

In principle, resolving this issue should involve changing from:

https://github.com/openmm/openmm-ml/blob/bc414b4e80d9ec0bbae3c356eeb66d5a82a19352/openmmml/mlpotential.py#L388-L392

to:

for torsions in root.findall('./Forces/Force/Torsions'):
    for torsion in torsions.findall('Torsion'):
        torsionAtomsLabels = ('p1', 'p2', 'p3', 'p4') if all(p in torsion.attrib for p in ('p1', 'p2', 'p3', 'p4')) else ('a1', 'a2', 'a3', 'a4', 'b1', 'b2', 'b3', 'b4')
        torsionAtoms = [int(torsion.attrib[p]) for p in torsionAtomsLabels]
        if shouldRemove(torsionAtoms):
            torsions.remove(torsion)
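A self-contained sketch of the proposed loop, run on a simplified, hand-written fragment (assumed layout, not verbatim serializer output). Here `ml_atoms` and `should_remove` are hypothetical stand-ins for the atom set and the shouldRemove() helper inside _removeBonds; the sketch assumes a term is dropped only when all of its atoms lie in the ML region, and the real helper may differ.

```python
import xml.etree.ElementTree as ET

# Simplified, hand-written System fragment (assumed layout, not verbatim XmlSerializer output).
SYSTEM_XML = """
<System>
  <Forces>
    <Force type="PeriodicTorsionForce">
      <Torsions>
        <Torsion p1="0" p2="1" p3="2" p4="3"/>
        <Torsion p1="3" p2="4" p3="5" p4="6"/>
      </Torsions>
    </Force>
    <Force type="CMAPTorsionForce">
      <Torsions>
        <Torsion map="0" a1="0" a2="1" a3="2" a4="3" b1="1" b2="2" b3="3" b4="4"/>
        <Torsion map="0" a1="4" a2="5" a3="6" a4="7" b1="5" b2="6" b3="7" b4="8"/>
      </Torsions>
    </Force>
  </Forces>
</System>
"""

P_LABELS = ("p1", "p2", "p3", "p4")
CMAP_LABELS = ("a1", "a2", "a3", "a4", "b1", "b2", "b3", "b4")
ml_atoms = {0, 1, 2, 3, 4}  # hypothetical ML region


def should_remove(term_atoms):
    """Hypothetical stand-in for shouldRemove(): True if every atom is in the ML region."""
    return all(a in ml_atoms for a in term_atoms)


root = ET.fromstring(SYSTEM_XML)
for torsions in root.findall("./Forces/Force/Torsions"):
    # findall() returns a new list, so removing children while iterating over it is safe.
    for torsion in torsions.findall("Torsion"):
        labels = P_LABELS if all(p in torsion.attrib for p in P_LABELS) else CMAP_LABELS
        torsion_atoms = [int(torsion.attrib[p]) for p in labels]
        if should_remove(torsion_atoms):
            torsions.remove(torsion)

# Count the torsions left in each force after the ML-region terms are dropped.
remaining = {
    force.get("type"): len(force.findall("./Torsions/Torsion"))
    for force in root.findall("./Forces/Force")
}
print(remaining)  # {'PeriodicTorsionForce': 1, 'CMAPTorsionForce': 1}
```

Under these assumptions, both the ordinary and the CMAP torsion that lie entirely inside the ML region are removed, while the terms that reach outside it are kept, and no KeyError is raised.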

I am happy to open a PR with this fix, but I am wondering whether there are any other cases to consider.

Many thanks.

peastman commented 10 months ago

A PR would be great. You can probably simplify all(p in torsion.attrib for p in ('p1', 'p2', 'p3', 'p4')) to just 'p1' in torsion.attrib. Checking just one is enough to tell us which labels to use.
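A minimal sketch of the simplification suggested above, assuming (as in this issue) that a torsion element either has all of p1..p4 or is a CMAP entry with a1..a4 and b1..b4. The helper name `torsion_atom_labels` and the two hand-written elements are hypothetical, for illustration only.

```python
import xml.etree.ElementTree as ET


def torsion_atom_labels(torsion):
    """Hypothetical helper: pick attribute labels by checking a single attribute."""
    # 'p1' present means an ordinary p1..p4 torsion; otherwise assume the CMAP layout.
    if "p1" in torsion.attrib:
        return ("p1", "p2", "p3", "p4")
    return ("a1", "a2", "a3", "a4", "b1", "b2", "b3", "b4")


# Hand-written elements in the assumed layout (not verbatim serializer output).
periodic = ET.fromstring('<Torsion p1="0" p2="1" p3="2" p4="3"/>')
cmap = ET.fromstring(
    '<Torsion map="0" a1="0" a2="1" a3="2" a4="3" b1="1" b2="2" b3="3" b4="4"/>'
)

for torsion in (periodic, cmap):
    print([int(torsion.attrib[p]) for p in torsion_atom_labels(torsion)])
# [0, 1, 2, 3]
# [0, 1, 2, 3, 1, 2, 3, 4]
```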

Thanks for catching this!