facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Custom distance function in KNN #1012

Status: Open. ParthGoyal1508 opened this issue 2 years ago.

ParthGoyal1508 commented 2 years ago

🚀 Feature

Support for a custom distance function when computing k nearest neighbours.


Motivation

It would be great if we could use a custom distance function, instead of the Euclidean distance, when computing k nearest neighbours.

Pitch

In molecular dynamics simulations, the system is simulated in a fixed-size box, so k nearest neighbours must be computed with periodic boundary conditions (PBC) and the minimum image convention. It would be convenient to be able to pass a custom distance function for this.

For example:

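Box dimensions: (3, 3, 3)

Point 1 coordinates: (0.2, 0.2, 0.2)
Point 2 coordinates: (2.7, 2.7, 2.7)

Euclidean distance: 4.33
PBC distance: 0.866

Each coordinate differs by 2.5 directly, but only by 3 - 2.5 = 0.5 across the periodic boundary, so the distances are sqrt(3 × 2.5²) ≈ 4.33 and sqrt(3 × 0.5²) ≈ 0.866.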
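The two numbers in this example can be reproduced with a minimum-image distance in PyTorch. This is a sketch, not PyTorch3D API: `pbc_distance` is a hypothetical helper, and it assumes an orthorhombic (axis-aligned) box whose side lengths are given as a tensor `box`; a triclinic cell would need a different wrapping step.

```python
import torch


def pbc_distance(a: torch.Tensor, b: torch.Tensor, box: torch.Tensor) -> torch.Tensor:
    """Minimum-image distance between points a and b, shapes (..., 3).

    Hypothetical helper; assumes an orthorhombic box with side lengths `box`, shape (3,).
    """
    delta = b - a
    # Shift each component by a whole number of box lengths so that it lies
    # in [-box/2, box/2]: the minimum image convention.
    delta = delta - box * torch.round(delta / box)
    return delta.norm(dim=-1)


box = torch.tensor([3.0, 3.0, 3.0])
p1 = torch.tensor([0.2, 0.2, 0.2])
p2 = torch.tensor([2.7, 2.7, 2.7])
print(f"Euclidean: {(p2 - p1).norm().item():.3f}")     # Euclidean: 4.330
print(f"PBC: {pbc_distance(p1, p2, box).item():.3f}")  # PBC: 0.866
```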


bottler commented 2 years ago

The distance calculation is hard-coded in the CUDA and C++ code. If we allowed users to pass a distance function written in Python, the calculation would be much slower, so I don't think that's a good idea.
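To make that trade-off concrete, here is roughly what a "pass any Python callable" KNN looks like in pure PyTorch. `knn_custom` is hypothetical (not part of PyTorch3D); it materializes the full (N, P1, P2) distance matrix, plus an (N, P1, P2, D) intermediate inside the callable, before selecting the K smallest entries with `torch.topk`.

```python
from typing import Callable, Tuple

import torch


def knn_custom(
    p1: torch.Tensor,
    p2: torch.Tensor,
    K: int,
    dist_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Brute-force KNN with a user-supplied distance (hypothetical, not PyTorch3D API).

    p1: (N, P1, D) queries, p2: (N, P2, D) references. dist_fn must broadcast
    over the leading dims and reduce the last one. Returns (dists, idx), each (N, P1, K).
    """
    # Views of shape (N, P1, 1, D) and (N, 1, P2, D) broadcast to every pair,
    # so dist_fn produces the whole (N, P1, P2) matrix in memory.
    d = dist_fn(p1[:, :, None, :], p2[:, None, :, :])
    # Keep the K smallest distances per query point, in ascending order.
    dists, idx = torch.topk(d, K, dim=-1, largest=False, sorted=True)
    return dists, idx


def sq_euclidean(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a - b materializes an (N, P1, P2, D) tensor before the reduction.
    return ((a - b) ** 2).sum(dim=-1)


torch.manual_seed(0)
p1 = torch.rand(2, 5, 3)
p2 = torch.rand(2, 7, 3)
dists, idx = knn_custom(p1, p2, K=3, dist_fn=sq_euclidean)
print(dists.shape, idx.shape)  # torch.Size([2, 5, 3]) torch.Size([2, 5, 3])
```

Memory grows with P1 × P2 and the work is split across several separate kernel launches, whereas a fused KNN kernel can keep only a running top K per query point.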

However, it would be quite easy to make a new version of the current implementation that includes periodic boundary conditions, taking the box boundaries as an extra input. Some questions:

ParthGoyal1508 commented 2 years ago

Hey! Thanks for the answer. I will try to make a new version of the current implementation as suggested above.

Thanks!
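A PBC version that takes the box as an extra input is easiest to validate against a brute-force reference. The sketch below is one such reference in pure PyTorch, under stated assumptions: `knn_pbc` is hypothetical (not a PyTorch3D function), the box is orthorhombic with side lengths `box`, and it returns squared distances to mirror `knn_points`, whose `dists` output holds squared distances for the default L2 norm. It scales as P1 × P2, so it is meant for testing a kernel on small inputs, not for large systems.

```python
import torch


def knn_pbc(p1: torch.Tensor, p2: torch.Tensor, box: torch.Tensor, K: int):
    """Brute-force KNN under periodic boundaries (hypothetical reference).

    p1: (N, P1, 3), p2: (N, P2, 3), box: (3,) orthorhombic side lengths.
    Returns (dists, idx), each (N, P1, K); dists are squared minimum-image distances.
    """
    delta = p1[:, :, None, :] - p2[:, None, :, :]  # (N, P1, P2, 3) displacements
    delta = delta - box * torch.round(delta / box)  # wrap to the nearest image
    d2 = (delta * delta).sum(dim=-1)                # (N, P1, P2) squared distances
    dists, idx = torch.topk(d2, K, dim=-1, largest=False, sorted=True)
    return dists, idx


box = torch.tensor([3.0, 3.0, 3.0])
p1 = torch.tensor([[[0.2, 0.2, 0.2]]])
p2 = torch.tensor([[[2.7, 2.7, 2.7], [1.5, 1.5, 1.5]]])
dists, idx = knn_pbc(p1, p2, box, K=1)
# With plain Euclidean distance, index 1 would be nearest (5.07 vs 18.75 squared);
# through the periodic boundary, index 0 is only 0.75 squared (0.866) away.
print(idx.item(), round(dists.item(), 3))  # 0 0.75
```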

bottler commented 2 years ago

It sounds like adapting our calculation would be the right thing to do.

The gradient matters if you need the derivative of the dists output of knn_points, e.g. if you call torch.autograd.grad or backward on something computed from dists. Internally, that is when the backward function of class _knn_points is used.
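A PBC kernel would therefore need its own backward as well. For a single pair the math is short: with the wrapped displacement delta = (a - b) - L * round((a - b) / L), the squared distance is sum(delta²), and since `round` is piecewise constant its gradient is 2 * delta with respect to a and -2 * delta with respect to b. It is undefined exactly at half-box separations, where the nearest image switches. The sketch writes that backward by hand in a `torch.autograd.Function` and checks it against finite differences with `torch.autograd.gradcheck`; the class `PBCSqDist` is hypothetical and assumes an orthorhombic box.

```python
import torch


class PBCSqDist(torch.autograd.Function):
    """Squared minimum-image distance between paired points a[i] and b[i].

    Hypothetical op with a hand-written backward, the way a C++/CUDA kernel
    would need one. Assumes an orthorhombic box with side lengths `box`.
    """

    @staticmethod
    def forward(ctx, a, b, box):
        delta = a - b
        delta = delta - box * torch.round(delta / box)  # minimum image
        ctx.save_for_backward(delta)
        return (delta * delta).sum(dim=-1)

    @staticmethod
    def backward(ctx, grad_out):
        (delta,) = ctx.saved_tensors
        # round() has zero derivative almost everywhere, so
        # d/da sum(delta^2) = 2 * delta and d/db = -2 * delta.
        grad_a = 2.0 * delta * grad_out.unsqueeze(-1)
        return grad_a, -grad_a, None  # no gradient w.r.t. the box


torch.manual_seed(0)
box = torch.tensor([3.0, 3.0, 3.0], dtype=torch.float64)
a = (3 * torch.rand(4, 3, dtype=torch.float64)).requires_grad_()
b = (3 * torch.rand(4, 3, dtype=torch.float64)).requires_grad_()
# gradcheck compares the analytic backward with finite differences (needs float64).
print(torch.autograd.gradcheck(PBCSqDist.apply, (a, b, box)))  # True
```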

gkioxari commented 2 years ago

This is a great point! I have also recently hacked the KNN C++/CUDA implementation to use distance functions other than L2. In my case it was a simple L1 distance, but we should consider adding some standard distance functions. I have marked this as an enhancement, and we can try to get to it this half.
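An L1 search like this can be prototyped in pure PyTorch and later used as a reference when testing a kernel. `knn_l1` is a hypothetical helper built on `torch.cdist` with `p=1.0` plus `torch.topk`; it materializes the full P1 × P2 distance matrix, so it is only meant for small inputs. The example points are chosen so that L1 and L2 rank the neighbours differently.

```python
import torch


def knn_l1(p1: torch.Tensor, p2: torch.Tensor, K: int):
    """Brute-force K nearest neighbours under the L1 (Manhattan) distance.

    Hypothetical helper. p1: (N, P1, D), p2: (N, P2, D).
    Returns (dists, idx), each (N, P1, K), sorted by increasing distance.
    """
    d = torch.cdist(p1, p2, p=1.0)  # (N, P1, P2) pairwise L1 distances
    dists, idx = torch.topk(d, K, dim=-1, largest=False, sorted=True)
    return dists, idx


p1 = torch.tensor([[[0.0, 0.0, 0.0]]])
p2 = torch.tensor([[[1.0, 1.0, 0.0], [1.5, 0.0, 0.0], [3.0, 0.0, 0.0]]])
dists, idx = knn_l1(p1, p2, K=2)
# L1 ranks index 1 first (1.5 < 2.0); L2 would rank index 0 first (1.414 < 1.5).
print(idx.tolist(), dists.tolist())  # [[[1, 0]]] [[[1.5, 2.0]]]
```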