getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

[Question] What's the best way to do a minimum neighbour distance excluding self? #194

Closed SamPruden closed 3 years ago

SamPruden commented 3 years ago

I have a collection of 3D points, represented as [batch, n, 3] PyTorch tensors. I want to obtain a [batch, n] tensor containing the squared distance from each point to its nearest neighbour. I will need to be able to autodiff through this.

The standard approach to nearest-neighbour problems of this sort is to build a squared-distance matrix. However, its diagonal is all zeros, so a naive use of min will always return 0. I'm looking for some notion of "minimum excluding the diagonal", or "secondmost minimum".

At the moment, I do this by using argKmin with k = 2, then, back in PyTorch land, taking the second index for each point and using those "secondmost minimum" indices to look up the points and compute squared distances. This is obviously fairly inefficient.
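
For concreteness, here is a simplified, unbatched sketch of that approach (placeholder names, and a single [n, 3] cloud rather than my real batched shapes):

import torch
from pykeops.torch import LazyTensor

x = torch.randn(1000, 3, requires_grad=True)  # a single cloud of n = 1000 points

x_i = LazyTensor(x[:, None, :])  # symbolic "i" points, shape (n, 1, 3)
x_j = LazyTensor(x[None, :, :])  # symbolic "j" points, shape (1, n, 3)
D_ij = ((x_i - x_j) ** 2).sum(-1)  # symbolic (n, n) squared-distance matrix

# With k = 2, the closest "neighbour" of each point is the point itself,
# so the second column holds the actual nearest neighbour:
indices = D_ij.argKmin(2, dim=1)[:, 1]

# Back in PyTorch land: differentiable squared distances via advanced indexing.
sqdists = ((x - x[indices]) ** 2).sum(-1)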

Another approach might be to add large values to the diagonal, but I don't know of an efficient way to do this symbolically.

Is there a good way to achieve this fully in keops?

harrydobbs commented 3 years ago

Hi there,

You could potentially apply a filter to remove any zero values. An example I saw is as follows:

import numpy as np
from pykeops.numpy import LazyTensor
from pykeops.numpy.cluster import from_matrix

# x, x_centroids, x_ranges, sigma and eps are assumed to come from a preliminary
# clustering / sorting step, as in the KeOps block-sparse reduction tutorial.
D = np.sum((x_centroids[:, None, :] - x_centroids[None, :, :]) ** 2, 2)
keep = D < (4 * sigma + np.sqrt(3) * eps) ** 2

ranges_ij = from_matrix(x_ranges, x_ranges, keep)

x_ = x / sigma  # N.B.: x is a **sorted** list of points
x_i, x_j = LazyTensor(x_[:, None, :]), LazyTensor(x_[None, :, :])
K_xx = (-((x_i - x_j) ** 2).sum(2) / 2).exp()  # Symbolic (N, N) Gaussian kernel matrix

K_xx.ranges = ranges_ij  # block-sparsity pattern
print(K_xx)

In your case, the filter would be: keep = D > 0

Cheers, Harry

SamPruden commented 3 years ago

You could potentially apply a filter to remove any zero values. An example I saw is as follows:

Ah, thanks! I had seen that previously, but I interpreted block sparsity to mean that the keep mask operates at the per-block level, rather than per value. I'll have to take another look.

SamPruden commented 3 years ago

I just took a look at making this work and couldn't easily see how to use that filter mechanism with individual values, but it turns out not to matter anyway because of this:

As of today, KeOps does not support backpropagation through the Min_Reduction reduction. Adding this feature to LazyTensors is on the cards for future releases... But until then, you may want to consider extracting the relevant integer indices with a '.argmin()', '.argmax()' or '.argKmin()' reduction before using PyTorch advanced indexing to create a fully-differentiable tensor containing the relevant 'minimal' values.

I suppose that I'm doing the right thing for now then.

jeanfeydy commented 3 years ago

Hi @SamPruden, @harry1576,

Thanks for your interest in the library! Indeed, I believe that the best option is still to use an argKmin(k = 2) reduction to extract the index of the nearest neighbour, and then compute the squared distances explicitly as sqdists = ((points - points[indices]) ** 2).sum(-1) in PyTorch. This shouldn't be inefficient, since the intermediate result points - points[indices] is required for the backward pass anyway.
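
Since your inputs are batched [batch, n, 3] tensors, note that LazyTensors also support batch dimensions. If I recall the convention correctly, the reduction axis is then shifted by the number of batch dimensions, so a rough (untested) sketch of the same idea in batch mode would be:

import torch
from pykeops.torch import LazyTensor

x = torch.randn(8, 1000, 3, requires_grad=True)  # (batch, n, 3)

x_i = LazyTensor(x[:, :, None, :])  # (batch, n, 1, 3)
x_j = LazyTensor(x[:, None, :, :])  # (batch, 1, n, 3)
D_ij = ((x_i - x_j) ** 2).sum(-1)   # symbolic (batch, n, n) squared distances

# Reduce over the "j" axis; with one leading batch dimension this should be dim=2.
# k = 2: column 0 is the point itself, column 1 its true nearest neighbour.
indices = D_ij.argKmin(2, dim=2)[:, :, 1]  # (batch, n) integer indices

# Gather the neighbours and compute differentiable squared distances in PyTorch:
neighbours = torch.gather(x, 1, indices[:, :, None].expand(-1, -1, x.shape[-1]))
sqdists = ((x - neighbours) ** 2).sum(-1)  # (batch, n)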

For reference, a way more complex (but slightly faster in the forward pass) implementation could be:

import torch
from pykeops.torch import LazyTensor

class NNDistances(torch.autograd.Function):
    """See https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html for details on this syntax."""

    @staticmethod
    def forward(ctx, x):
        # Encoding as KeOps LazyTensors:
        x_i = LazyTensor(x[:, None, :])
        x_j = LazyTensor(x[None, :, :])

        # Matrix of squared distances:
        D_ij = ((x_i - x_j) ** 2).sum(-1)

        # Distance to self and the nearest neighbors:
        distances, indices = D_ij.Kmin_argKmin(2, dim=1)

        # Keep the second column of each: the first one corresponds to the
        # distance from each point to itself (always 0) and its own index.
        distances = distances[:, 1]
        indices = indices[:, 1]

        ctx.save_for_backward(x, indices)
        return distances

    @staticmethod
    def backward(ctx, grad_output):
        x, indices = ctx.saved_tensors

        # Gradient with respect to x_i of ||x_i - x_{n(i)}||**2, scaled by the incoming gradient:
        g = 2 * grad_output[:, None] * (x - x[indices])

        # Each point also receives a "-g" contribution from every point that selected
        # it as nearest neighbour; these are accumulated with a scatter-add:
        return g - torch.zeros_like(g).scatter_add_(0, indices[:, None].repeat(1, g.shape[-1]), g)

nndistances = NNDistances.apply

N, D = 100, 3
points = torch.randn(N, D).cuda()
points.requires_grad = True

# Finite-difference check of the custom gradient:
torch.autograd.gradcheck(nndistances, points, eps=1e-3)

I wouldn't recommend using it though, since it's not very readable and maybe slower in the backward pass.

In any case: after the release of our new compilation engine in v1.6, we will soon be back to adding new features and addressing compatibility issues with PyTorch/NumPy. Hopefully, these workarounds will become obsolete in the near future :-)

Best regards, Jean

SamPruden commented 3 years ago

Thanks for the comprehensive response @jeanfeydy! I'm glad to know that I'm not doing anything too silly with my existing code.