openmm / openmm-ml

High level API for using machine learning models in OpenMM simulations

Accessing the force as an attribute in an MLPotentialImpl #76

Closed: annamherz closed this 2 months ago

annamherz commented 2 months ago

Hello! I need to access just the TorchForce that the MLPotential creates for the system. Would it be possible to store the force in an attribute so it can be accessed separately? For example, in the MACE potential, something like this:

        # Create the TorchForce and add it to the System.
        force = openmmtorch.TorchForce(module)
        force.setForceGroup(forceGroup)
        force.setUsesPeriodicBoundaryConditions(isPeriodic)
        self.force = force
        system.addForce(force)

That way, I could get the force for my hypothetical system using:

        from copy import deepcopy

        mace_system = deepcopy(system)  # system being an OpenMM System created earlier

        macepotimpl = models.macepotential.MACEPotentialImpl(name="mace-off23-small", modelPath="")
        macepotimpl.addForces(
            modeller.topology,  # the modeller topology used to create the system above
            mace_system,
            ml_atoms,  # indices of the ligand
            0,  # force group
            periodic=True,
        )

        ml_force = deepcopy(macepotimpl.force)  # the attribute proposed above

Alternatively, is there some other way to get the force from this that I missed?

Thank you so much!
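Without changing MACEPotentialImpl, one possible workaround is to note how many forces the System holds before calling addForces() and then collect whatever was appended afterwards. This is a minimal sketch that assumes forces are only appended to the end of the System (which is what openmm.System.addForce does) and uses nothing beyond getNumForces() and getForce(); the helper name `forces_added_by` is hypothetical.

```python
def forces_added_by(system, add_forces):
    """Return the forces that calling ``add_forces(system)`` appends to ``system``.

    ``system`` is expected to be an openmm.System; only getNumForces() and
    getForce() are used. Assumes forces are only appended (never inserted or
    removed), so every index at or beyond the old count belongs to a new force.
    """
    start = system.getNumForces()  # number of forces before the call
    add_forces(system)  # e.g. a wrapper around MLPotentialImpl.addForces
    return [system.getForce(i) for i in range(start, system.getNumForces())]
```

With the names from the question, this could be called as `new_forces = forces_added_by(mace_system, lambda s: macepotimpl.addForces(modeller.topology, s, ml_atoms, 0))`. For a plugin force such as TorchForce, the objects returned this way may be generic Force wrappers rather than TorchForce instances.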

peastman commented 2 months ago

You can call getForces() on the System to get all the forces it contains, and then pick out the TorchForce with something like this:

        force = [f for f in system.getForces() if isinstance(f, TorchForce)][0]

annamherz commented 2 months ago

I see, thank you! For some reason, though, getForces() doesn't return a TorchForce after it has been added to the system, but I can find it using:

        torch_force = [f for f in mace_system.getForces() if f.getName() == "TorchForce"][0] 
        print(torch_force)
        print(type(torch_force))

where, after being added to the system, its type appears to be <class 'openmm.openmm.Force'> rather than <class 'openmmtorch.TorchForce'>, which it was before. I'm not sure whether this is important, but either way I think it will now work for what I need the force for afterwards. Thank you!

peastman commented 2 months ago

Oh right, sorry about that. It's a quirk of the SWIG-generated wrappers: getForces() only knows about the Force classes in the main OpenMM package, so those are returned as the correct subclasses, while forces from other packages are returned as generic Force objects. You can use TorchForce.isinstance(f) to check whether f is a TorchForce, and TorchForce.cast(f) to cast it to the correct class.
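Following that suggestion, a small helper can filter a System's forces with the plugin class's static isinstance() and cast() methods. This is a sketch: the name `find_plugin_forces` is hypothetical, and taking the class as a parameter assumes that other plugin force classes expose the same isinstance()/cast() pair that TorchForce does.

```python
def find_plugin_forces(system, force_cls):
    """Return every force in ``system`` that is a ``force_cls``, cast to that class.

    ``force_cls`` is expected to be a SWIG-wrapped plugin class such as
    openmmtorch.TorchForce, providing static ``isinstance(force)`` and
    ``cast(force)`` methods. getForces() hands back plugin forces as generic
    openmm.Force objects, so each match is cast back to the plugin class.
    """
    return [force_cls.cast(f) for f in system.getForces() if force_cls.isinstance(f)]
```

For example, with `from openmmtorch import TorchForce`, the call `find_plugin_forces(mace_system, TorchForce)[0]` should give an object of type openmmtorch.TorchForce instead of openmm.openmm.Force. Unlike the `f.getName() == "TorchForce"` check, this does not depend on the force's name, which can be changed with setName().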

annamherz commented 2 months ago

That makes sense, thank you :blush: