Closed: peastman closed this issue 4 months ago.
Hi Peter!
Thanks for this. I've also checked this locally and found the same. In fact, I find that primitive_neighbor_list_torch is extremely slow even on large systems. Here is a little benchmark:
On pos = 5 * torch.randn(1000, 3), I get:
- primitive_neighbor_list_torch from physicsml: time ~0.39 s (mean of 10 iters)
- compute_neighborlist from torch_nl: time ~0.07 s (mean of 10 iters)
- neighbour_list from matscipy.neighbours: time ~0.05 s (mean of 10 iters)

If we include PBCs, it's:
- primitive_neighbor_list_torch from physicsml: time ~0.4 s (mean of 10 iters)
- compute_neighborlist from torch_nl: time ~0.09 s (mean of 10 iters)
- neighbour_list from matscipy.neighbours: time ~0.09 s (mean of 10 iters)

Wow, I didn't realise it was so bad! When I chose the neighbour list for physicsml, I just used the one that was in the MACE repo (IIRC). They've since switched to the matscipy one (probably because it's much faster), but that one is not TorchScript-compatible.
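The benchmarking script itself isn't shown in this thread. One way to get comparable "mean of 10 iters" numbers is a small wall-clock harness like the sketch below. mean_time is a hypothetical helper, the single warm-up call is an assumption (it keeps one-off allocation or JIT costs out of the average), and torch.cdist is only a placeholder workload to be swapped for the neighbor list call being measured.

```python
import time
from typing import Callable

import torch


def mean_time(fn: Callable[[], object], n_iters: int = 10, n_warmup: int = 1) -> float:
    """Return the mean wall-clock time in seconds of fn() over n_iters calls."""
    # Warm-up calls are excluded from the measurement.
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    return (time.perf_counter() - start) / n_iters


# Same kind of random configuration as the benchmark above: 1000 atoms.
pos = 5 * torch.randn(1000, 3)

# Placeholder workload (all-pairs distance matrix); replace the lambda body
# with the neighbor list call to benchmark.
t = mean_time(lambda: torch.cdist(pos, pos))
print(f"time ~ {t:.3f} s (mean of 10 iters)")
```

For GPU timings, the function being timed would also need a torch.cuda.synchronize() call before the clock is read, since CUDA kernels launch asynchronously.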
To fix this, we can:
- use the torch-nl implementation when we have PBCs (from here: https://github.com/felixmusil/torch_nl).

Sorry about this! I'll have a PR up today to fix the NL computation.
Best, Ward
I find that building neighbor lists is responsible for a significant fraction of the training time. Here is a trace of a couple of training steps for a tensornet model with one interaction layer (running on CPU, since I'm having trouble getting tracing working with the GPU).
The part shown as "enumerate(DataLoader)" is almost entirely spent building neighbor lists.
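A trace like this can be collected with torch.profiler. DataLoader iteration shows up under an automatically added enumerate(DataLoader) label, so wrapping the neighbor list call in its own record_function region is one way to attribute that time explicitly. In the sketch below, build_neighbor_list is a hypothetical dense stand-in, not the physicsml routine, and the approach assumes num_workers=0: with worker processes, dataset code runs outside the profiled main process and would likely not appear in the trace.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def build_neighbor_list(pos: torch.Tensor, cutoff: float) -> torch.Tensor:
    # Hypothetical stand-in for the real neighbor list call.
    dist = torch.cdist(pos, pos)
    mask = (dist < cutoff) & ~torch.eye(len(pos)).bool()
    return mask.nonzero().T  # shape (2, E)


pos = 5 * torch.randn(100, 3)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    # Label the region so it appears by name in the table and the trace.
    with record_function("neighbor_list"):
        edge_index = build_neighbor_list(pos, cutoff=5.0)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# prof.export_chrome_trace("trace.json")  # open in Perfetto or chrome://tracing
```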
physicsml uses a voxel algorithm to build neighbor lists. That's an appropriate method for large systems with tens of thousands of atoms. But most molecular datasets consist of very small molecules with fewer than 100 atoms, and for those a simple O(n^2) routine is much faster.
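For the small-molecule case, an O(n^2) routine can be a handful of tensor operations. The sketch below (dense_neighbor_list is a hypothetical name; it assumes a single molecule without periodic boundary conditions) builds all pairwise displacements at once, selects pairs within the cutoff while excluding self-pairs, and returns the selected displacement vectors and distances, which stay differentiable with respect to the positions.

```python
import torch


def dense_neighbor_list(pos: torch.Tensor, cutoff: float):
    """All-pairs O(n^2) neighbor list for one non-periodic molecule.

    Returns (edge_index, vectors, distances): edge_index has shape (2, E) and
    vectors[k] = pos[j] - pos[i] for edge k = (i, j).
    """
    n = pos.shape[0]
    # disp[i, j] = pos[j] - pos[i], shape (n, n, 3).
    disp = pos[None, :, :] - pos[:, None, :]
    # Squared distances are only used to choose edges, so no gradient is needed.
    dist2 = (disp.detach() ** 2).sum(dim=-1)
    mask = (dist2 < cutoff**2) & ~torch.eye(n, device=pos.device).bool()
    i, j = mask.nonzero(as_tuple=True)
    # Reuse the displacements already computed for the selected pairs.
    vectors = disp[i, j]
    distances = vectors.norm(dim=-1)
    return torch.stack([i, j]), vectors, distances


pos = 5 * torch.randn(50, 3)
edge_index, vectors, distances = dense_neighbor_list(pos, cutoff=5.0)
```

Memory grows as n^2, which is negligible for molecules of under 100 atoms; that same quadratic growth is why cell or voxel methods win for systems with tens of thousands of atoms.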
There are also a number of well-optimized neighbor list routines written in CUDA, such as those in NNPOps and TorchMD-Net. Using one of those could provide a further speedup.
The way the neighbor list routine is used also seems a bit odd to me. primitive_neighbor_list_torch() computes the displacements and distances, which are the quantities needed for computing interactions, but it doesn't return them. Instead it discards them and returns the number of box vectors that need to be added to each pair to apply periodic boundary conditions. That gets passed around through a few layers of code and is eventually passed to compute_lengths_and_vectors(), which uses them to recompute the displacements and distances. Returning them directly from the neighbor list instead of the cell shifts would save some duplicate computation.
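To illustrate returning the geometry directly, here is a brute-force periodic sketch that hands back the displacement vectors and distances it has already computed, alongside the integer cell shifts. The function name, argument order, and return layout are hypothetical, not physicsml's API. It assumes cell rows are lattice vectors, positions are wrapped into the cell, and the cell is orthorhombic with a cutoff shorter than the box length, so checking images within ±1 cell in each direction is enough.

```python
import itertools

import torch


def periodic_neighbor_list(pos: torch.Tensor, cell: torch.Tensor, cutoff: float):
    """Brute-force PBC neighbor list that returns the geometry it computes.

    Returns (edge_index, unit_shifts, vectors, distances), where
    vectors[k] = pos[j] - pos[i] + unit_shifts[k] @ cell for edge k = (i, j).
    """
    n = pos.shape[0]
    # All 27 integer image shifts in {-1, 0, 1}^3, shape (S, 3).
    shifts = torch.tensor(
        list(itertools.product((-1, 0, 1), repeat=3)), dtype=pos.dtype, device=pos.device
    )
    offsets = shifts @ cell  # Cartesian offset of each image, shape (S, 3)
    # disp[s, i, j] = pos[j] - pos[i] + offsets[s], shape (S, n, n, 3).
    disp = pos[None, None, :, :] - pos[None, :, None, :] + offsets[:, None, None, :]
    dist2 = (disp.detach() ** 2).sum(dim=-1)
    # Exclude each atom paired with itself in the unshifted (0, 0, 0) image.
    is_zero_shift = (shifts == 0).all(dim=-1)
    self_pair = is_zero_shift[:, None, None] & torch.eye(n, device=pos.device).bool()[None]
    mask = (dist2 < cutoff**2) & ~self_pair
    s, i, j = mask.nonzero(as_tuple=True)
    # Select the displacements already computed instead of rebuilding them later.
    vectors = disp[s, i, j]
    distances = vectors.norm(dim=-1)
    return torch.stack([i, j]), shifts[s], vectors, distances


cell = 10.0 * torch.eye(3)
pos = 10.0 * torch.rand(20, 3)  # wrapped positions inside the cubic cell
edge_index, unit_shifts, vectors, distances = periodic_neighbor_list(pos, cell, cutoff=3.0)
```

Because vectors are built from pos and cell inside the graph, they remain differentiable with respect to both, so forces (and, with a differentiable cell, stresses) can still be obtained without reconstructing the displacements from the shifts downstream.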