openmm / openmm-torch

OpenMM plugin to define forces with neural networks

How to create a CompoundThermodynamicState from a torchForce object with global parameter #147

Open xiaowei-xie2 opened 2 months ago

xiaowei-xie2 commented 2 months ago

Hi,

I would like to run Hamiltonian REMD with custom-defined states, where each state is specified by a TorchForce object with a different global parameter value. However, I am having trouble creating a CompoundThermodynamicState to use with ReplicaExchangeSampler. I can create a GlobalParameterState, but when I use it to construct the CompoundThermodynamicState, it complains that the global parameter is not in the system. I have no trouble doing the same thing with an MM force field. Any idea what might be going wrong?

Thank you!

Here is the structure of the code I was using:

from openmm import MonteCarloBarostat, unit
from openmm.unit import bar, kelvin
from openmmtools.states import (GlobalParameterState, ThermodynamicState,
                                CompoundThermodynamicState)
from openmmtorch import TorchForce

force = TorchForce('model.pt')
force.addGlobalParameter('a', 0.5)
force.addGlobalParameter('b', 0.3)
force.setUsesPeriodicBoundaryConditions(True)

# define system
system = ...

# Remove MM constraints
while system.getNumConstraints() > 0:
  system.removeConstraint(0)

# Remove MM forces
while system.getNumForces() > 0:
  system.removeForce(0)

assert system.getNumConstraints() == 0
assert system.getNumForces() == 0

system.addForce(force)

barostat = MonteCarloBarostat(1*bar, 298.15*kelvin)
system.addForce(barostat)

class LambdaState(GlobalParameterState):
    a = GlobalParameterState.GlobalParameter('a', standard_value=1.0)
    b = GlobalParameterState.GlobalParameter('b', standard_value=1.0)

    def set_rest_parameters(self, value_a, value_b):
        """Set all defined lambda parameters to the given values.

        The undefined parameters (i.e. those set to None) remain undefined.

        Parameters
        ----------
        value_a : float
            The new value for parameter ``a``.
        value_b : float
            The new value for parameter ``b``.
        """
        lambda_functions = {'a': lambda a, b: value_a,
                            'b': lambda a, b: value_b,
                            }

        for parameter_name in self._parameters:
            if self._parameters[parameter_name] is not None:
                new_value = lambda_functions[parameter_name](value_a, value_b)
                setattr(self, parameter_name, new_value)

lambda_state = LambdaState(a=0.5, b=0.3)
print('lambda_state.a:', lambda_state.a)
print('lambda_state.b:', lambda_state.b)

thermostate = ThermodynamicState(system, temperature=298.15 * unit.kelvin)
compound_thermodynamic_state = CompoundThermodynamicState(thermostate, composable_states=[lambda_state])

And I am getting the following error:

lambda_state.a: 0.5
lambda_state.b: 0.3
Traceback (most recent call last):
  File "/scr/xie1/training_xtb_test/openmm_FEP_lambdastate_REMD.py", line 129, in <module>
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate, composable_states=[lambda_state])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xie1/miniconda3/lib/python3.12/site-packages/openmmtools/states.py", line 2790, in __init__
    self.set_system(self._standard_system, fix_state=True)
  File "/home/xie1/miniconda3/lib/python3.12/site-packages/openmmtools/states.py", line 2843, in set_system
    s.apply_to_system(system)
  File "/home/xie1/miniconda3/lib/python3.12/site-packages/openmmtools/states.py", line 3521, in apply_to_system
    raise self._GLOBAL_PARAMETER_ERROR(err_msg.format(parameter_name))
openmmtools.states.GlobalParameterError: Could not find global parameter a in the system.
peastman commented 2 months ago

Is this specific to TorchForce? What if you change

force = TorchForce('model.pt')

to

force = CustomBondForce('r')

leaving everything else the same, including the calls to addGlobalParameter()? Do you get the same error? If so, the problem isn't related to TorchForce, and you should probably ask at https://github.com/choderalab/openmmtools. On the other hand, if that works and the problem really is related to TorchForce, can you post a complete example with all the files needed to reproduce it?

xiaowei-xie2 commented 2 months ago

Thank you for the reply! I went ahead and tried your suggestion, and it seems the problem is specific to TorchForce. I created a simple example that reproduces the behavior, attached below. I also think I may have found a workaround (inspired by the openmmml package) by adding the following lines.

import copy
import openmm

cv = openmm.CustomCVForce("")
cv.addGlobalParameter("param_a", 1)
cv.addGlobalParameter("param_b", 1)
tempSystem = openmm.System()
tempSystem.addForce(force)
interactingVarNames = []
for idx, innerForce in enumerate(tempSystem.getForces()):
    name = f"allForce{idx+1}"
    cv.addCollectiveVariable(name, copy.deepcopy(innerForce))
    interactingVarNames.append(name)

assert len(interactingVarNames) > 0

interactingSum = "+".join(interactingVarNames)

cv.setEnergyFunction(f"({interactingSum})")

system.addForce(cv)

The attached openmm_files.tar.gz contains three files (mmforce.py, torchforce.py, and torchforce_workaround.py) and their corresponding outputs. You can see that only the force object changes between the files.

Please let me know whether my workaround is correct.

openmm_files.tar.gz

peastman commented 2 months ago

@mikemhenry @ijpulidos can you take a look at this? This error is happening because of an interaction between SWIG and openmmtools.

_get_system_controlled_parameters() tries to find the list of global parameters by looping over all forces and looking for methods called getNumGlobalParameters() and getGlobalParameterName().

for force_index in range(system.getNumForces()):
    force = system.getForce(force_index)
    try:
        n_global_parameters = force.getNumGlobalParameters()
    except AttributeError:
        continue
    for parameter_id in range(n_global_parameters):
        parameter_name = force.getGlobalParameterName(parameter_id)
        if parameter_name in searched_parameters:
            yield force, parameter_name, parameter_id

The problem is that SWIG can only return the correct Python Force subclass from getForce() for built-in classes. If the force was defined by a plugin, it just returns an instance of the abstract Force class. That refers only to the Python wrapper, of course; the C++ object it wraps has the correct class. The TorchForce Python wrapper provides static isinstance() and cast() methods for checking whether something is a wrapped TorchForce and casting it to the correct Python class.

The robust way of getting a list of all global parameters is to call getParameters() on a Context.