RaulPPelaez closed this pull request 10 months ago.
I'm not sure this is the right fix. I think the real problem is in OpenMM-Torch, and this just masks the symptoms.
TorchForce used to store the name of the file containing the model. When you created a Context (and therefore a TorchForceImpl), it would load the file to create a torch::jit::Module. That meant every Context had its own copy of the model, and there was no possibility of interaction between them.
https://github.com/openmm/openmm-torch/pull/97 changed it to make TorchForce directly store the torch::jit::Module. The result is that all Contexts share the same model. This creates the possibility of one Context influencing another one. The bug seen here is one manifestation of that. It could also happen in much more subtle ways that lead to incorrect results.
I think the correct solution is to make TorchForceImpl clone the model. That will ensure that every Context again has its own independent copy.
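To make the isolation argument concrete, here is a minimal Python sketch, assuming only PyTorch. It stands in `.to(torch.float64)` for `.to(device)` so it runs without a GPU, and it copies the module with a `torch.jit.save`/`torch.jit.load` round trip as a Python analog of cloning (in C++, `torch::jit::Module::clone()`). The `Offset` model and the `clone_module` helper are hypothetical, not OpenMM-Torch code.

```python
import io

import torch


class Offset(torch.nn.Module):
    """Hypothetical toy model with a single buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("r0", torch.zeros(3))

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return torch.sum(positions - self.r0)


def clone_module(module: torch.jit.ScriptModule) -> torch.jit.ScriptModule:
    # Serialize and reload: the result shares no parameters or buffers with `module`.
    buffer = io.BytesIO()
    torch.jit.save(module, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


model = torch.jit.script(Offset())

# Independent copies, one per "context".
cloned_a, cloned_b = clone_module(model), clone_module(model)
cloned_a.to(torch.float64)  # converts only context A's copy

# One shared object, as when every Context holds the same module.
shared_a, shared_b = model, model
shared_a.to(torch.float64)  # context B's module is converted too

print(cloned_b.r0.dtype)  # torch.float32
print(shared_b.r0.dtype)  # torch.float64
```

With a shared object, a change made on behalf of one context silently applies to every other context; with per-context copies it does not.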
I think it will require the torch model code to handle the devices, as done by @RaulPPelaez here. How would it work for a simple pure PyTorch model, e.g.:
```python
import torch as pt
from openmm import Context, LocalEnergyMinimizer, Platform, System, VerletIntegrator
from openmmtorch import TorchForce

scale = 1.0e10
platform = "CUDA"
device = "cuda"

class Model(pt.nn.Module):
    def __init__(self, scale, device):
        super().__init__()
        self.device = device
        self.scale = scale
        self.r0 = pt.tensor([0.0, 0.0, 0.0], device=device)

    def forward(self, positions):
        positions = positions.to(self.device)  # <- without this line it will not work
        return self.scale * pt.sum(positions - self.r0)**2

model = pt.jit.script(Model(scale, device))
force = TorchForce(model)
system = System()
system.addForce(force)
for _ in range(2):
    system.addParticle(1)
platform = Platform.getPlatformByName(platform)
context = Context(system, VerletIntegrator(1), platform)
context.setPositions([[0, 0, 0], [1, 0, 0]])
LocalEnergyMinimizer.minimize(context)
```
Typically you have to define the device for some tensors in the constructor, and the forward method then expects the positions to be on that same device. If the positions instead arrive on the CPU rather than on CUDA, because LocalEnergyMinimizer has detected that the forces are large, the forward method needs code that moves all the tensors onto the same device. Or is there a different way to do this without the .to(device) lines?
@RaulPPelaez, I think @peastman is right. Each Context should have a separate copy of the Torch module, so it can be initialized once on a specific device and never change. This will ensure the isolation of the Contexts, and NNPOps won't need to handle device changes.
@sef43, the module shouldn't have explicit device assignments. Instead, create parameters and/or buffers, so PyTorch can move them to the right device. OpenMM-Torch already uses that mechanism (https://github.com/openmm/openmm-torch/blob/e9f2ae24f00138740ee6683ea4ccd476c268c183/platforms/cuda/src/CudaTorchKernels.cpp#L78).
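As a sketch of that suggestion, assuming only PyTorch, here is the model from the question rewritten with `r0` registered as a buffer instead of a tensor pinned to a device. Whoever owns the module (here the script itself; in OpenMM-Torch presumably the kernel behind the linked line) calls `.to(device)`, and the buffer follows. The fallback to CPU when CUDA is unavailable is only there so the sketch runs anywhere.

```python
import torch


class Model(torch.nn.Module):
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale
        # A buffer, not a device-pinned tensor: module.to(device) moves it.
        self.register_buffer("r0", torch.tensor([0.0, 0.0, 0.0]))

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # No explicit .to(device): positions and r0 are expected to share a device.
        return self.scale * torch.sum(positions - self.r0) ** 2


model = torch.jit.script(Model(1.0e10))
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)  # moves the r0 buffer along with the module
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=device)
energy = model(positions)
```

The model itself never mentions a device, so the same scripted module works unchanged on whichever device it is moved to.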
I believe I am missing something about the issue @peastman is describing. Instead of modifying NNPOps, I can make TorchForceImpl in OpenMM-Torch load the module anew (by serializing and reloading it) each time. For instance, I can change the getModule function in TorchForce so that it returns a new instance of the module on every call:
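```cpp
const torch::jit::Module TorchForce::getModule() const {
    // Serialize the stored module into an in-memory stream...
    std::stringstream output_stream;
    this->module.save(output_stream);
    // ...and load it back, so every caller gets an independent copy.
    return torch::jit::load(output_stream);
}
```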
This way, TorchForceImpl::initialize gets a freshly loaded module each time:
```cpp
void TorchForceImpl::initialize(ContextImpl& context) {
    auto module = owner.getModule();
    // Create the kernel.
    kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context);
    kernel.getAs<CalcTorchForceKernel>().initialize(context.getSystem(), owner, module);
}
```
As far as I understand, this is equivalent to the behavior of TorchForceImpl before https://github.com/openmm/openmm-torch/pull/97. This, however, results in the same error as in the original post.
EDIT: I made a mistake; the fix above does fix the error, and I now understand why.
I opened https://github.com/openmm/openmm-torch/pull/116 with the fix suggested by @peastman. The original reproducer passes on my machine using that instead of this PR.
Hence, while this PR does make SymmetryFunctions robust to changes of device, I am not sure it is worth merging.
Solves #112.
AFAIK, the error in #112 comes from OpenMM changing the device of just the positions, so that NNPOps, further down the line, is fed tensors on two different devices.
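The fix idea amounts to making the positions follow a tensor whose device never changes on its own. A minimal sketch, assuming only PyTorch; `SpeciesEnergy` and its toy energy are hypothetical and are not NNPOps' OptimizedTorchANI code.

```python
import torch


class SpeciesEnergy(torch.nn.Module):
    """Hypothetical stand-in for an ANI-style model; not the NNPOps implementation."""

    def __init__(self, atomic_numbers: torch.Tensor):
        super().__init__()
        # A buffer: it only moves when the whole module is moved with .to(device).
        self.register_buffer("atomic_numbers", atomic_numbers)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # Send the positions to wherever the atomic numbers live, so every
        # downstream operation sees tensors on a single device.
        positions = positions.to(self.atomic_numbers.device)
        weights = self.atomic_numbers.to(positions.dtype)
        return torch.sum(weights * torch.sum(positions**2, dim=1))


model = torch.jit.script(SpeciesEnergy(torch.tensor([1, 8])))
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
energy = model(positions)
```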
The original reproducer by @raimis:
This is solved by simply sending the positions to the same device as the tensor with the atomic numbers (which always stays on the same device). I did this by modifying OptimizedTorchANI:
In this PR, I also restructured SymmetryFunctions a bit. Instead of creating an implementation and storing it for the duration of the execution, this class now holds a map from devices to implementations, creating or fetching the one needed for the device where the positions are stored. I did this because the positions (and only the positions) suddenly changing device leaves us with an ambiguous decision. Do we:
In the first case, we simply move the positions back to the original device when required. In the second case, we must ensure every component can handle inputs whose device changes. On the other hand, this makes me think: OpenMM suddenly changing the device of the positions without properly informing NNPOps (perhaps by calling model.to(device)?) sounds to me like a bug in either OpenMM or OpenMM-Torch.
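The map from devices to implementations described for SymmetryFunctions can be sketched in plain Python, assuming PyTorch for the tensors. `_Impl` is a hypothetical placeholder for the real CPU/CUDA implementations, and `PerDeviceDispatcher` is an illustrative name, not the NNPOps C++ class.

```python
from typing import Dict

import torch


class _Impl:
    """Hypothetical per-device implementation (placeholder for CPU or CUDA kernels)."""

    def __init__(self, device: torch.device):
        self.device = device

    def __call__(self, positions: torch.Tensor) -> torch.Tensor:
        return torch.sum(positions**2)


class PerDeviceDispatcher:
    """Keeps one implementation per device and picks it from the positions' device."""

    def __init__(self):
        self._impls: Dict[torch.device, _Impl] = {}

    def __call__(self, positions: torch.Tensor) -> torch.Tensor:
        device = positions.device
        impl = self._impls.get(device)
        if impl is None:
            # Create the implementation lazily the first time a device is seen.
            impl = _Impl(device)
            self._impls[device] = impl
        return impl(positions)


dispatcher = PerDeviceDispatcher()
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
first = dispatcher(positions)
second = dispatcher(positions.clone())  # same device, so the cached implementation is reused
```

The trade-off is the one raised above: this makes the component tolerant of device changes, rather than preventing those changes at their source.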
Finally, I also applied the formatter.