openmm / NNPOps

High-performance operations for neural network potentials

Make OptimizedTorchANI robust to changes to device between calls. #113

Closed: RaulPPelaez closed this 10 months ago

RaulPPelaez commented 11 months ago

Solves #112.

AFAIK, the error in #112 arises because OpenMM changes the device of only the positions, so NNPOps, further down the line, ends up being fed tensors on two different devices.

The original reproducer by @raimis:

import torch as pt
from NNPOps import OptimizedTorchANI
from openmm import Context, LocalEnergyMinimizer, Platform, System, VerletIntegrator
from openmmtorch import TorchForce
from torchani.models import ANI2x

scale = 1e10
platform = "CUDA"

class Model(pt.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.species = pt.tensor([1, 1]).unsqueeze(0)
        self.model = ANI2x(periodic_table_index=True)
        self.model = OptimizedTorchANI(self.model, self.species)
    def forward(self, positions):
        positions = positions.unsqueeze(0).to(pt.float32)
        return self.scale * self.model.forward((self.species, positions))[1]

force = TorchForce(pt.jit.script(Model(scale)))

system = System()
system.addForce(force)
for _ in range(2):
    system.addParticle(1)

platform = Platform.getPlatformByName(platform)
context = Context(system, VerletIntegrator(1), platform)

context.setPositions([[0, 0, 0], [1, 0, 0]])
LocalEnergyMinimizer.minimize(context)

It is solved by simply moving the positions to the same device as the tensor with the atomic numbers (which always stays on the same device). I did this by modifying OptimizedTorchANI:

    def forward(self, species_coordinates: Tuple[Tensor, Tensor],
                 cell: Optional[Tensor] = None,
                 pbc: Optional[Tensor] = None) -> SpeciesEnergies:

         species_coordinates = self.species_converter(species_coordinates)
+        species_coordinates = (
+            species_coordinates[0],
+            species_coordinates[1].to(species_coordinates[0].device),
+        )
         species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
         species_energies = self.neural_networks(species_aevs)
         species_energies = self.energy_shifter(species_energies)

         return species_energies

In this PR, I also restructured SymmetryFunctions a bit. Instead of creating an implementation and storing it for the duration of the execution, the class now holds a map from devices to implementations, creating or fetching the one it needs according to the device the positions are on. I did this because the positions (and only the positions) suddenly changing device leaves us with an ambiguous choice. Do we:

  1. Respect the original device the module was created with (as my fix above does)?
  2. Respect the device of "positions" by doing all computations on its device?

In other words, what should OptimizedTorchANI do when this assertion fails?
    def forward(
        self,
        species_coordinates: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> SpeciesEnergies:
        assert species_coordinates[0].device == species_coordinates[1].device

In the first case, we simply move the positions back to the module's device when required. In the second case, we must ensure every component can handle inputs whose device changes between calls. OTOH, this makes me think: OpenMM suddenly changing the device of the positions without informing NNPOps (perhaps by calling model.to(device)?) sounds to me like a bug in either OpenMM or OpenMM-Torch.
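The "map from devices to implementations" idea behind the SymmetryFunctions restructuring can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code in this PR: devices are represented as plain strings (PyTorch's `torch.device` is also hashable and would work as a key), and `PerDeviceImpls` and `make_impl` are hypothetical names standing in for whatever builds the device-specific state.

```python
from typing import Callable, Dict, Generic, TypeVar

Impl = TypeVar("Impl")


class PerDeviceImpls(Generic[Impl]):
    """Hypothetical sketch: one lazily created implementation per device."""

    def __init__(self, factory: Callable[[str], Impl]) -> None:
        self._factory = factory
        self._impls: Dict[str, Impl] = {}

    def get(self, device: str) -> Impl:
        # Build the implementation the first time a device is seen, then reuse it.
        if device not in self._impls:
            self._impls[device] = self._factory(device)
        return self._impls[device]


created = []


def make_impl(device: str) -> str:
    # Hypothetical factory; a real one would allocate device-specific buffers.
    created.append(device)
    return f"impl-on-{device}"


impls = PerDeviceImpls(make_impl)
impls.get("cuda:0")
impls.get("cpu")
impls.get("cuda:0")  # reused from the map, not rebuilt
```

The trade-off is the one described above: this supports option 2 (follow the positions' device) at the cost of keeping one implementation alive per device ever seen.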

Finally, I also applied the formatter.

peastman commented 11 months ago

I'm not sure this is the right fix. I think the real problem is in OpenMM-Torch, and this just masks the symptoms.

TorchForce used to store the name of the file containing the model. When you created a Context (and therefore a TorchForceImpl), it would load the file to create a torch::jit::Module. That meant every Context had its own copy of the model, and there was no possibility of interaction between them.

https://github.com/openmm/openmm-torch/pull/97 changed it to make TorchForce directly store the torch::jit::Module. The result is that all Contexts share the same model. This creates the possibility of one Context influencing another one. The bug seen here is one manifestation of that. It could also happen in much more subtle ways that lead to incorrect results.

I think the correct solution is to make TorchForceImpl clone the model. That will ensure that every Context again has its own independent copy.
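The difference between sharing one model and giving each Context its own clone can be illustrated with a toy Python analogue. Everything here is hypothetical: `ToyModel` only tracks a device string and mimics the in-place, self-returning behavior of `torch.nn.Module.to`, and `ToyContext` stands in for a Context/TorchForceImpl pair. It is not OpenMM or OpenMM-Torch code; the CPU move stands in for the minimizer fallback discussed in this thread.

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in for a torch::jit::Module; only tracks its device."""

    device: str = "cuda"

    def to(self, device: str) -> "ToyModel":
        # Like torch.nn.Module.to, mutate in place and return self.
        self.device = device
        return self


class ToyContext:
    """Hypothetical stand-in for a Context that owns or shares a model."""

    def __init__(self, model: ToyModel, clone: bool) -> None:
        # With clone=True, each context gets an independent deep copy.
        self.model = copy.deepcopy(model) if clone else model


model = ToyModel()
shared_a = ToyContext(model, clone=False)
shared_b = ToyContext(model, clone=False)
cloned_a = ToyContext(model, clone=True)
cloned_b = ToyContext(model, clone=True)

# One context moves its model, e.g. a CPU context created by a fallback path.
shared_a.model.to("cpu")
cloned_a.model.to("cpu")

print(shared_b.model.device)  # cpu  (the other sharing context was affected)
print(cloned_b.model.device)  # cuda (the cloned context is isolated)
```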

sef43 commented 11 months ago

Quoting @peastman: "I'm not sure this is the right fix. I think the real problem is in OpenMM-Torch, and this just masks the symptoms."

I think it will require the Torch model code to handle devices, as @RaulPPelaez does here. How would that work for a simple pure PyTorch model, e.g.:

import torch as pt
from openmm import Context, LocalEnergyMinimizer, Platform, System, VerletIntegrator
from openmmtorch import TorchForce

scale = 1.0e10
platform = "CUDA"
device = "cuda"

class Model(pt.nn.Module):
    def __init__(self, scale, device):
        super().__init__()
        self.device = device
        self.scale = scale
        self.r0 = pt.tensor([0.0, 0.0, 0.0], device=device)
    def forward(self, positions):
        positions = positions.to(self.device)  # <- without this line it will not work
        return self.scale * pt.sum(positions - self.r0)**2

model = pt.jit.script(Model(scale, device))
force = TorchForce(model)

system = System()
system.addForce(force)
for _ in range(2):
    system.addParticle(1)

platform = Platform.getPlatformByName(platform)
context = Context(system, VerletIntegrator(1), platform)

context.setPositions([[0, 0, 0], [1, 0, 0]])
LocalEnergyMinimizer.minimize(context)

Typically you have to define the device for some tensors in the constructor. In the forward method, you then expect the positions to be on the same device. If they instead arrive on the CPU rather than CUDA because LocalEnergyMinimizer has detected that the forces are large, you need code in the forward method that moves all the tensors onto the same device. Or is there a different way to do this without the .to(device) lines?

raimis commented 11 months ago

@RaulPPelaez, I think @peastman is right. Each context should have a separate copy of the Torch module, so it can be initialized once on a specific device and never change. This ensures the contexts are isolated, and NNPOps doesn't need to handle device changes.

raimis commented 11 months ago

@sef43, the module shouldn't have explicit device assignments. Instead, create parameters and/or buffers, so that PyTorch can move them to the right device. OpenMM-Torch already uses that mechanism (https://github.com/openmm/openmm-torch/blob/e9f2ae24f00138740ee6683ea4ccd476c268c183/platforms/cuda/src/CudaTorchKernels.cpp#L78).
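A minimal sketch of that suggestion, applied to the toy model from @sef43's comment: `r0` is registered with `register_buffer` instead of being created with a hard-coded device, so `Module.to(...)` moves and casts it along with the module and `forward` needs no `.to(device)` line. This assumes the caller (here, per @raimis, OpenMM-Torch) moves the module to the device it will run on; it does not by itself help if the positions arrive on a different device than the module. The dtype move below is only there so the example runs on CPU-only machines.

```python
import torch as pt


class Model(pt.nn.Module):
    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale
        # A buffer (not a plain attribute with a fixed device), so Module.to(...)
        # moves and casts it together with the rest of the module.
        self.register_buffer("r0", pt.zeros(3))

    def forward(self, positions: pt.Tensor) -> pt.Tensor:
        # No .to(device) here: r0 lives wherever the module was moved.
        return self.scale * pt.sum(positions - self.r0) ** 2


model = Model(2.0)
positions = pt.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
energy = model(positions)  # 2.0 * (0 + 1) ** 2

# Moving the module also moves the buffer; model.to("cuda") would work the same way.
model = model.to(pt.float64)
scripted = pt.jit.script(model)
```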

RaulPPelaez commented 10 months ago

I believe I am missing something about the issue @peastman is describing. Instead of modifying NNPOps, I can make TorchForceImpl in OpenMM-Torch load a fresh copy of the module each time, as it did when the module was loaded from a file. For instance, I can change the getModule function in TorchForce so that it returns a new instance of the module on each call:

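const torch::jit::Module TorchForce::getModule() const {
  // Serialize the stored module into an in-memory stream...
  std::stringstream output_stream;
  this->module.save(output_stream);
  // ...and deserialize it, giving the caller an independent copy.
  return torch::jit::load(output_stream);
}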

This way, TorchForceImpl::initialize gets a freshly loaded module each time:

void TorchForceImpl::initialize(ContextImpl& context) {
    auto module = owner.getModule();
    // Create the kernel.
    kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context);
    kernel.getAs<CalcTorchForceKernel>().initialize(context.getSystem(), owner, module);
}

As far as I understand, this is equivalent to the behavior of TorchForceImpl before https://github.com/openmm/openmm-torch/pull/97. However, this results in the same error as in the original post.

EDIT: I made a mistake; the fix above does fix the error, and it makes sense to me why.
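For comparison, the same serialize-and-reload copy can be sketched from Python with `torch.jit.save` and `torch.jit.load` on an in-memory buffer. This is only an analogue of the C++ getModule change, not OpenMM-Torch code; `fresh_copy` and the toy `Model` are hypothetical. Mutating one copy's buffer in place leaves the original and the other copy untouched, which is the per-Context isolation @peastman asked for.

```python
import io

import torch as pt


class Model(pt.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Stands in for any module state that one Context might mutate.
        self.register_buffer("r0", pt.zeros(3))

    def forward(self, positions: pt.Tensor) -> pt.Tensor:
        return pt.sum(positions - self.r0) ** 2


def fresh_copy(module: pt.jit.ScriptModule) -> pt.jit.ScriptModule:
    """Serialize a TorchScript module to memory and load it back as a new copy."""
    buffer = io.BytesIO()
    pt.jit.save(module, buffer)
    buffer.seek(0)  # rewind before reading the bytes back
    return pt.jit.load(buffer)


shared = pt.jit.script(Model())
copy_a = fresh_copy(shared)
copy_b = fresh_copy(shared)

copy_a.r0.add_(1.0)  # mutate one copy in place
```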

RaulPPelaez commented 10 months ago

I opened https://github.com/openmm/openmm-torch/pull/116 with the fix suggested by @peastman. The original reproducer passes on my machine using that instead of this PR.

Hence, while this PR does make SymmetryFunctions robust to device changes, I am not sure it is worth merging.