Exscientia / physicsml

A package for all physics based/related models
MIT License

Faster neighbor list #21

Closed: peastman closed this issue 4 months ago

peastman commented 4 months ago

I find that building neighbor lists is responsible for a significant fraction of the training time. Here is a trace of a couple of training steps for a tensornet model with one interaction layer (running on CPU, since I'm having trouble getting tracing to work on the GPU).

[Screenshot, 2024-03-05 9:04:51 AM: profiler trace of two training steps]

The part shown as "enumerate(DataLoader)" is spent almost entirely on building neighbor lists.

[Screenshot, 2024-03-05 9:05:50 AM]
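For anyone wanting to reproduce this kind of trace, here is a minimal sketch using `torch.profiler` on CPU. The dataset, model, and optimizer are hypothetical stand-ins, not the tensornet setup from physicsml. PyTorch's DataLoader iterator wraps each `__next__` call in a profiler range whose name starts with `enumerate(DataLoader)`, which is likely where that label in the trace comes from; any neighbor list building done in the dataset or collate step is counted inside it.

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the real dataset, model and optimizer.
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=16)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for step, (x, y) in enumerate(loader):
        # Label the optimizer step so it shows up as its own range in the trace
        with record_function("train_step"):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
        if step == 1:  # two steps are enough for a trace
            break

# Open trace.json in chrome://tracing or Perfetto to get a timeline view
prof.export_chrome_trace("trace.json")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```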

physicsml uses a voxel algorithm to build neighbor lists. That is an appropriate method for large systems with tens of thousands of atoms, but most molecular datasets consist of very small molecules with fewer than 100 atoms. For those, a simple O(n^2) routine that checks every pair is much faster.
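As an illustration of the O(n^2) approach, here is a minimal sketch for non-periodic molecules. The function name, the optional `batch` vector (molecule index per atom), and the convention that `vectors` holds `pos[j] - pos[i]` are all assumptions for this example, not physicsml's API. The pair search runs without autograd; only the selected pairs are recomputed with gradients enabled.

```python
from typing import Optional, Tuple

import torch


def brute_force_neighbor_list(
    pos: torch.Tensor, cutoff: float, batch: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical all-pairs neighbor list for small, non-periodic molecules.

    Returns (edge_index, vectors, distances) for ordered pairs i != j closer
    than `cutoff`, where vectors[k] = pos[j] - pos[i] for pair k = (i, j).
    """
    n = pos.shape[0]
    with torch.no_grad():
        # dist[i, j] = |pos[j] - pos[i]| for every ordered pair (n x n)
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).norm(dim=-1)
        mask = (dist < cutoff) & ~torch.eye(n, dtype=torch.bool, device=pos.device)
        if batch is not None:
            # Only pair atoms that belong to the same molecule in the batch
            mask &= batch.unsqueeze(0) == batch.unsqueeze(1)
        i, j = mask.nonzero(as_tuple=True)
    # Recompute only the selected pairs, with autograd enabled
    vectors = pos[j] - pos[i]
    distances = vectors.norm(dim=-1)
    return torch.stack([i, j]), vectors, distances


pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
edge_index, vectors, distances = brute_force_neighbor_list(pos, cutoff=1.5)
print(edge_index.tolist())  # [[0, 1], [1, 0]]
```

The n x n distance matrix is what makes this approach unsuitable for very large systems, but for molecules with fewer than 100 atoms it is a few thousand entries, which is cheap compared with the bookkeeping of a voxel grid.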

There are also a number of well-optimized neighbor list routines written in CUDA, such as those in NNPOps and TorchMD-Net. Using one of them could provide a further speedup.

The way the neighbor list routine is used also seems a bit odd to me. primitive_neighbor_list_torch() computes the displacements and distances, which are exactly the quantities needed for computing interactions, but it doesn't return them. Instead, it discards them and returns the cell shifts: the number of box vectors that must be added to each pair to apply periodic boundary conditions. Those shifts get passed through a few layers of code and eventually reach compute_lengths_and_vectors(), which uses them to recompute the displacements and distances. Returning the displacements and distances directly from the neighbor list, instead of only the cell shifts, would avoid this duplicate computation.
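To make the suggestion concrete, here is a minimal sketch of a periodic O(n^2) neighbor list that returns the displacement vectors and distances alongside the cell shifts, so nothing downstream has to rebuild them. It assumes an orthorhombic box given as a tensor of three box lengths and a cutoff smaller than half the shortest box length (so the minimum image convention finds every pair). The function name and return layout are hypothetical and do not reflect the signatures of primitive_neighbor_list_torch() or compute_lengths_and_vectors().

```python
import torch


def neighbor_list_pbc(pos: torch.Tensor, box: torch.Tensor, cutoff: float):
    """Hypothetical O(n^2) neighbor list in an orthorhombic periodic box.

    Assumes box has shape (3,) and cutoff < box.min() / 2 (minimum image).
    Returns edge_index, cell shifts (in units of box lengths), displacement
    vectors pos[j] - pos[i] + shifts * box, and their lengths.
    """
    n = pos.shape[0]
    with torch.no_grad():
        delta = pos.unsqueeze(0) - pos.unsqueeze(1)  # delta[i, j] = pos[j] - pos[i]
        shifts = -torch.round(delta / box)           # whole box lengths to add
        dist = (delta + shifts * box).norm(dim=-1)
        mask = (dist < cutoff) & ~torch.eye(n, dtype=torch.bool, device=pos.device)
        i, j = mask.nonzero(as_tuple=True)
        shifts = shifts[i, j]
    # Compute vectors and distances once, with autograd enabled, and return
    # them directly instead of making callers rebuild them from the shifts.
    vectors = pos[j] - pos[i] + shifts * box
    distances = vectors.norm(dim=-1)
    return torch.stack([i, j]), shifts, vectors, distances


box = torch.tensor([10.0, 10.0, 10.0])
pos = torch.tensor([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]])
edge_index, shifts, vectors, distances = neighbor_list_pbc(pos, box, cutoff=2.0)
print(distances.tolist())  # [1.0, 1.0]: the two atoms are neighbors across the boundary
```

The shifts are still returned in case some other code (for example, a stress calculation) needs them, but the quantities the interaction layers actually use come out of the same pass.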

wardhaddadin1 commented 4 months ago

Hi Peter!

Thanks for this. I've also checked this locally and found the same. Actually, I find that primitive_neighbor_list_torch() is extremely slow even on large systems. Here is a little benchmark:

With pos = 5 * torch.randn(1000, 3), I get:

If we include PBCs, it's:

Wow, I didn't realise it was so bad! When I chose the neighbour list for physicsml, I just used the one that was in the MACE repo (IIRC). They've since switched to the matscipy one (probably because it's much faster), but that one can't be compiled with TorchScript.

To fix this, we can:

Sorry about this! Will have a PR today to fix the NL computation.

Best, Ward