facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Periodic Ball Query nearest-neighbour search #1623

Open NicklasOsterbacka opened 10 months ago

NicklasOsterbacka commented 10 months ago

🚀 Feature

Nearest-neighbor search with Ball Query under periodic boundary conditions. #1012 is more open-ended, but related.

Motivation

Force fields are commonplace in computational chemistry and physics, where they drive molecular dynamics and Monte Carlo simulations of atomic systems. A key approximation is the truncation of interaction distances, typically implemented by considering only the neighbours of each atom within a cutoff radius, or its k nearest neighbours. Since atomic positions are analogous to point clouds, the Ball Query already implemented in PyTorch3D can be used for this, as can KNN with some post-processing.
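For a single non-periodic cloud, the "KNN with some post-processing" route can be sketched as: take the K nearest neighbours, then mark those beyond the cutoff as padding. This is a minimal pure-PyTorch stand-in that uses torch.cdist and topk instead of PyTorch3D's knn_points, so it runs without PyTorch3D; the function name knn_within_radius and the -1 padding convention are assumptions for illustration.

```python
import torch


def knn_within_radius(p: torch.Tensor, K: int, radius: float):
    """Hypothetical helper: K nearest neighbours of each point, then a radius cutoff.

    p: (P, 3) positions of one non-periodic point cloud (requires K <= P - 1).
    Returns squared distances (P, K) and indices (P, K), with -1 beyond the radius.
    """
    d2 = torch.cdist(p, p) ** 2          # (P, P) squared pairwise distances
    d2.fill_diagonal_(float("inf"))      # a point is not its own neighbour
    dists, idx = d2.topk(K, dim=-1, largest=False)  # K smallest, ascending
    idx = idx.masked_fill(dists > radius ** 2, -1)  # post-processing: apply the cutoff
    return dists, idx


p = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
dists, idx = knn_within_radius(p, K=2, radius=1.5)
```

One caveat of this route: if more than K points lie inside the radius, some of them are silently dropped, which a true ball query with a large enough K avoids.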

For solids and liquids, periodic boundary conditions are used to make simulations tractable: instead of the huge number of atoms in, e.g., a piece of metal, one simulates a much smaller number of atoms inside a periodic box that is repeated in every direction. However, Ball Query and KNN as currently implemented do not support this.
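As a small illustration of what "inside a periodic box" means for a general (possibly skewed) cell, positions can be mapped back into the cell through fractional coordinates. This is a sketch assuming the rows of the 3-by-3 cell are the lattice vectors, so that positions = fractional @ cell; wrap_into_cell is a hypothetical name.

```python
import torch


def wrap_into_cell(pos: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map Cartesian positions (P, 3) back into the periodic cell.

    Assumes the rows of `cell` (3, 3) are the lattice vectors, i.e. pos = frac @ cell.
    """
    frac = pos @ torch.linalg.inv(cell)  # Cartesian -> fractional coordinates
    frac = frac - torch.floor(frac)      # wrap each fractional coordinate into [0, 1)
    return frac @ cell                   # fractional -> Cartesian


# Skewed cell: the second lattice vector is tilted along x.
cell = torch.tensor([[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 1.0]])
wrapped = wrap_into_cell(torch.tensor([[1.2, 1.1, 0.3]]), cell)
```

Because torch.floor has zero gradient, the wrapped positions still pass gradients straight through to the input positions.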

Machine learning force fields, which are becoming more commonplace, would particularly benefit from having a fully PyTorch-compatible solution.

Pitch

A version of pytorch3d.ops.ball_query with support for periodic boundary conditions.

Suggested signature: pytorch3d.ops.ball_query_pbc(p: torch.Tensor, cell: torch.Tensor, lengths: Optional[torch.Tensor] = None, K: int = 100, radius: float = 5.0, neighbor_self: bool = False, return_nn: bool = True)

There are three important differences compared with the input signature of the non-periodic Ball Query: only one point cloud is considered, the cell containing the point cloud is given as an additional 3-by-3 input, and a boolean controls whether each point should be counted as its own neighbour.

Suggested return structure:
- deltas: Tensor of shape (N, P, K, D) for the N D-dimensional input point clouds, each containing up to P points with up to K neighbours. Contains the vectors pointing from each central point to its neighbours.
- idx: LongTensor of shape (N, P, K) giving the indices of the neighbours of each point in the point cloud.
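To make the proposed signature and return structure concrete, here is a brute-force reference sketch in plain PyTorch. It is not part of PyTorch3D and not a proposed CUDA kernel; it assumes the rows of cell are the lattice vectors, pads unused slots with zeros in deltas and -1 in idx, omits return_nn, and keeps the K closest images sorted by distance (a design choice of this sketch). It enumerates every periodic image that could fall within the radius, so the edge case of several images of the same point is handled, and it stays differentiable with respect to p.

```python
import itertools
import math
from typing import Optional, Tuple

import torch


def ball_query_pbc(
    p: torch.Tensor,
    cell: torch.Tensor,
    lengths: Optional[torch.Tensor] = None,
    K: int = 100,
    radius: float = 5.0,
    neighbor_self: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical brute-force periodic ball query (not part of PyTorch3D).

    p: (N, P, D) positions; cell: (D, D) or (N, D, D), rows = lattice vectors.
    Returns deltas (N, P, K, D) and idx (N, P, K); unused slots hold 0 and -1.
    """
    N, P, D = p.shape
    if cell.dim() == 2:
        cell = cell.expand(N, D, D)
    if lengths is None:
        lengths = torch.full((N,), P, dtype=torch.long, device=p.device)
    deltas = p.new_zeros(N, P, K, D)
    idx = torch.full((N, P, K), -1, dtype=torch.long, device=p.device)
    for n in range(N):
        L, c = int(lengths[n]), cell[n]
        inv = torch.linalg.inv(c)
        # Wrap into the cell: fractional coordinates f = x @ inv(cell), kept in [0, 1).
        f = p[n, :L] @ inv
        x = (f - torch.floor(f)) @ c
        # Lattice planes along vector i are 1 / |inv[:, i]| apart; after wrapping, fractional
        # differences lie in (-1, 1), so |shift_i| <= ceil(radius * |inv[:, i]|) + 1 suffices.
        reps = [math.ceil(radius * float(inv[:, i].norm())) + 1 for i in range(D)]
        shifts = torch.tensor(
            list(itertools.product(*[range(-m, m + 1) for m in reps])),
            dtype=p.dtype, device=p.device,
        )  # (S, D) integer image shifts
        offsets = shifts @ c  # (S, D) Cartesian image offsets
        # diff[i, s, j] = x_j + offsets[s] - x_i: vector from point i to image s of point j
        diff = x[None, None, :, :] + offsets[None, :, None, :] - x[:, None, None, :]
        dist2 = (diff ** 2).sum(-1)  # (L, S, L)
        mask = dist2 <= radius ** 2
        if not neighbor_self:
            identity = (shifts == 0).all(-1)  # the unshifted image
            eye = torch.eye(L, dtype=torch.bool, device=p.device)
            mask &= ~(eye[:, None, :] & identity[None, :, None])
        for i in range(L):
            sel = mask[i].nonzero(as_tuple=True)  # (image, neighbour) pairs in the ball
            order = dist2[i][sel].argsort()[:K]   # keep the K closest
            k = order.numel()
            deltas[n, i, :k] = diff[i][sel][order]
            idx[n, i, :k] = sel[1][order]
    return deltas, idx


# A lone point in a unit cube with radius 1.5: 6 images at distance 1, 12 at sqrt(2).
deltas, idx = ball_query_pbc(torch.tensor([[[0.2, 0.3, 0.4]]]), torch.eye(3), K=30, radius=1.5)
```

Memory grows as P^2 times the number of images, so this is only a correctness reference for small systems; a practical version would use cell lists or a dedicated kernel.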

For many force fields, the quantities of interest are the position differences, i.e. p2 - p1 for each central point p1 and its neighbours p2, rather than just the distances between the points.
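As an example of consuming the proposed (deltas, idx) output, a simple pair potential can be evaluated directly on the padded neighbour list. This sketch assumes padded slots are marked with idx == -1 and uses a Lennard-Jones potential purely as a stand-in; lj_energy is a hypothetical name. A pair potential only needs |p2 - p1|, but keeping the full vectors lets the same output feed angular terms and autograd.

```python
import torch


def lj_energy(deltas: torch.Tensor, idx: torch.Tensor,
              epsilon: float = 1.0, sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical example: Lennard-Jones energy per cloud from a padded neighbour list.

    deltas: (N, P, K, 3) vectors from each central point to its neighbours.
    idx:    (N, P, K) neighbour indices, -1 marking padding (assumed convention).
    Every pair appears twice (i -> j and j -> i), hence the factor 0.5.
    """
    valid = idx >= 0
    r2 = (deltas ** 2).sum(-1)
    r2 = torch.where(valid, r2, torch.ones_like(r2))  # avoid dividing by zero on padding
    s6 = (sigma ** 2 / r2) ** 3
    e = 4 * epsilon * (s6 ** 2 - s6)
    return 0.5 * torch.where(valid, e, torch.zeros_like(e)).sum(dim=(1, 2))


# One pair at the potential minimum r = 2**(1/6) * sigma, plus one padded slot per point.
r = 2.0 ** (1.0 / 6.0)
deltas = torch.zeros(1, 2, 2, 3)
deltas[0, 0, 0, 0] = r
deltas[0, 1, 0, 0] = -r
idx = torch.tensor([[[1, -1], [0, -1]]])
energy = lj_energy(deltas, idx)  # close to -epsilon
```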

An important edge case is when the point cloud is small enough that the search radius encloses several periodic images of the same point. In such cases, all of them should be returned and counted as neighbours. The appended image illustrates this. The central, filled circle is the sole point in the point cloud contained within the solid box, while the dashed boxes and the circles in them represent its periodic images. The big circle is the boundary of the neighbour search. Above, below, to the left, and to the right of the central point, two images each fall inside it. All of these should be counted and returned. (Each of them has the same index, repeated, but with deltas that account for periodicity.)
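The edge case can be checked by simply enumerating lattice shifts. This sketch assumes a 2-D square cell of side a and a lone point; the figure does not state its radius, so radius = 2.1 a is an assumption chosen so that two images fall inside along each axis, as described. periodic_image_deltas is a hypothetical name.

```python
import itertools
import math


def periodic_image_deltas(radius: float, a: float):
    """Hypothetical helper: displacements from a lone point to its own periodic images
    (2-D square cell of side `a`) that lie within `radius`, excluding the point itself."""
    m = math.ceil(radius / a)  # no image beyond m cells away can be inside the radius
    out = []
    for i, j in itertools.product(range(-m, m + 1), repeat=2):
        if (i, j) != (0, 0) and math.hypot(i * a, j * a) <= radius:
            out.append((i * a, j * a))
    return out


images = periodic_image_deltas(2.1, 1.0)
```

With these assumptions there are 12 images: the 8 axis-aligned ones (two on each side) plus the 4 diagonal neighbours at distance sqrt(2) a. All would share the same index, differing only in their deltas.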

[Image: periodicity]

NicklasOsterbacka commented 10 months ago

Additionally, gradients would be important for machine learning force fields, as the forces acting on each atom are the negative gradient of the model output (the energy) with respect to the atomic positions. Forgot to mention that in the issue!
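The forces-from-gradients point can be sketched with torch.autograd.grad. Since periodic ball query does not exist yet, this sketch swaps in the minimum-image convention, which only works for orthorhombic boxes with a cutoff below half the box length (exactly the limitation the requested feature would lift). The toy harmonic repulsion and the name pair_energy_min_image are assumptions for illustration.

```python
import torch


def pair_energy_min_image(pos: torch.Tensor, box: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Hypothetical toy energy: harmonic repulsion (cutoff - r)^2 for pairs within `cutoff`.

    pos: (P, 3); box: (3,) orthorhombic box lengths. Uses the minimum-image convention,
    which assumes cutoff < box.min() / 2.
    """
    delta = pos[None, :, :] - pos[:, None, :]       # delta[i, j] = x_j - x_i
    delta = delta - box * torch.round(delta / box)  # nearest periodic image
    r2 = (delta ** 2).sum(-1)
    mask = (r2 < cutoff ** 2) & ~torch.eye(pos.shape[0], dtype=torch.bool)
    r = torch.sqrt(torch.where(mask, r2, torch.ones_like(r2)))  # safe sqrt on excluded pairs
    e = torch.where(mask, (cutoff - r) ** 2, torch.zeros_like(r))
    return 0.5 * e.sum()  # each pair is counted twice


# Two atoms that only interact across the periodic boundary (separation 1.0 through it).
pos = torch.tensor([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]], requires_grad=True)
box = torch.tensor([10.0, 10.0, 10.0])
energy = pair_energy_min_image(pos, box, cutoff=2.0)
forces = -torch.autograd.grad(energy, pos)[0]  # F = -dE/dx
```

The resulting forces push the two atoms apart through the boundary, which only works if the neighbour search itself is differentiable with respect to the positions.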

moradpur commented 4 weeks ago

That would be an excellent feature for machine learning of crystals.