openmm / openmm-torch

OpenMM plugin to define forces with neural networks

Positions from Torch GPU Tensors #12

Open · Olllom opened 4 years ago

Olllom commented 4 years ago

Hi @peastman,

We were thinking about ways to assign context positions directly from torch GPU tensors in order to speed up batch processing of force and energy calculations for small systems. The kernels in this plugin, such as copyInputs, already seem to provide a blueprint, and at first sight it looks fairly straightforward to expose something like setContextPositionsFromTensor(Context& context, torch::Tensor& positions) in the API. I believe this would be a worthwhile extension that could be useful to others. (Of course, it is not directly related to the main purpose of this plugin, but I am sure you have already thought about this.) Do you see any potential caveats or problems that I am currently overlooking?
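To make the idea concrete, a rough sketch of how this might be used from Python (the function name and import path are made up, of course, and `system` and `num_atoms` stand in for an existing setup):

```python
# Hypothetical usage of the proposed extension. Neither
# setContextPositionsFromTensor nor its import path exists in the plugin
# today; `system` and `num_atoms` are placeholders.
import torch
import openmm as mm
from openmmtorch import setContextPositionsFromTensor  # hypothetical

context = mm.Context(system, mm.VerletIntegrator(0.001))
positions = torch.rand(num_atoms, 3, device="cuda")  # already on the GPU

# Copy device-to-device, skipping the usual GPU -> host -> GPU roundtrip.
setContextPositionsFromTensor(context, positions)
energy = context.getState(getEnergy=True).getPotentialEnergy()
```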

Tagging @yaoyic

Thanks, Andreas

peastman commented 4 years ago

You have to be a bit careful about assigning positions. Take a look at https://github.com/openmm/openmm/blob/master/platforms/cuda/src/CudaKernels.cpp#L206-L255, which gets executed when you call setPositions() on a Context. There's a lot more happening there than just copying over one array, and you could easily put things into an inconsistent state that leads to wrong results.
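For reference, here is a minimal sketch, using only the public Python API, of the contract any fast path would have to preserve: changing positions must also invalidate cached energies, forces, and internal reordering state so that the next query recomputes everything. The two-particle system is purely illustrative.

```python
import openmm as mm
import openmm.unit as unit

# A toy system: two particles joined by a harmonic bond.
system = mm.System()
system.addParticle(1.0)
system.addParticle(1.0)
bond = mm.HarmonicBondForce()
bond.addBond(0, 1, 0.1, 1000.0)  # r0 = 0.1 nm, k = 1000 kJ/mol/nm^2
system.addForce(bond)

context = mm.Context(system, mm.VerletIntegrator(0.001))

# setPositions() does more than copy an array: it also invalidates cached
# state, which is why each getState() below reflects the new coordinates.
# Writing into device memory behind the Context's back would skip that
# bookkeeping.
for r in (0.10, 0.12):
    context.setPositions([mm.Vec3(0, 0, 0), mm.Vec3(r, 0, 0)] * unit.nanometer)
    print(context.getState(getEnergy=True).getPotentialEnergy())
```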

jchodera commented 3 years ago

@peastman: Has this been addressed?

peastman commented 3 years ago

We haven't added the suggested feature. Whether there would be any benefit to it is an open question. You can always set the positions with

```python
# .cpu() is required first when the tensor lives on the GPU
context.setPositions(positions.detach().cpu().numpy())
```

The question is whether this approach could reduce overhead by a meaningful amount. Someone would have to try implementing it and see.
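As a starting point, a rough benchmark of the baseline path might look like the following; `system` is assumed to be an existing openmm.System and `positions` a CUDA tensor of shape (num_atoms, 3) holding coordinates in nanometers:

```python
# Rough timing of the GPU -> host -> Context roundtrip this issue wants
# to avoid. `system` and `positions` are assumed to exist already.
import time
import torch
import openmm as mm
import openmm.unit as unit

context = mm.Context(system, mm.VerletIntegrator(0.001))

start = time.perf_counter()
for _ in range(1000):
    # GPU tensor -> host numpy -> Context
    context.setPositions(positions.detach().cpu().numpy())
    forces = context.getState(getForces=True).getForces(asNumpy=True)
    # Moving the forces back onto the GPU adds the reverse copy.
    forces = torch.from_numpy(
        forces.value_in_unit(unit.kilojoule_per_mole / unit.nanometer)).cuda()
elapsed = time.perf_counter() - start
print(f"{elapsed:.3f} s for 1000 evaluations")
```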

Olllom commented 3 years ago

In my opinion this could be really important for a lot of ML+MD applications: it would drastically reduce CPU<->GPU transfers in workflows that only need energy and force evaluations of classical potentials, with no integration.

I have not been able to dedicate time to this and probably will not in the near future. One thing that worries me a bit is interfacing OpenMM's SWIG wrappers with torch's pybind11 wrappers. A workaround would be to operate on the raw device pointers rather than on the tensors directly, but that seems less elegant.

That said, this feature would add a lot of flexibility to what can be done efficiently with OpenMM. Many projects in our group could greatly benefit from a tighter integration between OpenMM and PyTorch, and I would be extremely grateful if anybody else took a stab at it.

@franknoe

peastman commented 3 years ago

I think if you call .data_ptr() on the tensor, that will give you its device pointer. So we could implement this with a Python function that extracts the pointer from the tensor and passes it down to the C++ function. That way the C++ code wouldn't have to deal with any PyTorch objects directly, and users would still have a clean API.
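A minimal sketch of that layering, where setContextPositionsFromPointer stands in for the C++ binding that would still have to be written:

```python
# Sketch of the wrapper described above. setContextPositionsFromPointer is
# a hypothetical C++ binding that takes a raw device address; only the
# pointer crosses the language boundary, so the C++ side never touches a
# PyTorch object.
import torch

def set_context_positions_from_tensor(context, positions: torch.Tensor):
    # The C++ kernel would expect contiguous data already on the GPU.
    assert positions.is_cuda and positions.is_contiguous()
    # .data_ptr() returns the device address as a plain Python int.
    setContextPositionsFromPointer(context, positions.data_ptr())
```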