SamPruden closed this issue 3 years ago
Hi there,
You could potentially apply a filter to remove any zero values. An example I saw is as follows:
    # Squared distances between cluster centroids:
    D = np.sum((x_centroids[:, None, :] - x_centroids[None, :, :]) ** 2, 2)
    # Keep only the pairs of clusters that are close enough to interact:
    keep = D < (4 * sigma + np.sqrt(3) * eps) ** 2
    ranges_ij = from_matrix(x_ranges, x_ranges, keep)

    x_ = x / sigma  # N.B.: x is a **sorted** list of points
    x_i, x_j = LazyTensor(x_[:, None, :]), LazyTensor(x_[None, :, :])
    K_xx = (-((x_i - x_j) ** 2).sum(2) / 2).exp()  # Symbolic (N,N) Gaussian kernel matrix
    K_xx.ranges = ranges_ij  # block-sparsity pattern
    print(K_xx)
In your case, the filter would be: keep = D > 0
Cheers, Harry
> You could potentially apply a filter to remove any zero values. An example I saw is as follows:
Ah, thanks! I had seen that previously, but I interpreted block sparsity to mean that the keep
mask would operate on the per block level, rather than per value. I'll have to take another look.
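To make the block-level versus value-level distinction concrete, here is a small dense PyTorch sketch. It does not call KeOps: the toy points, the `labels` vector and the `point_mask` expansion are illustrative assumptions, meant only to mimic how a cluster-level `keep` matrix would apply to individual points.

```python
import torch

# Hypothetical toy data: 6 sorted 1D points grouped into 3 clusters of 2 points.
x = torch.tensor([[0.0], [0.1], [1.0], [1.1], [2.0], [2.1]])
labels = torch.tensor([0, 0, 1, 1, 2, 2])  # cluster index of each point

# One centroid per cluster, then squared distances between centroids.
centroids = torch.stack([x[labels == c].mean(0) for c in range(3)])  # (3, 1)
D = ((centroids[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)   # (3, 3)

# The mask has one entry per pair of clusters, not per pair of points.
keep = D > 0

# Expand to point level to see which (i, j) pairs would be skipped.
point_mask = keep[labels][:, labels]  # (6, 6)
print(point_mask)
```

In this sketch, `keep = D > 0` drops every diagonal block, so point 0 can no longer see point 1 even though point 1 is its true nearest neighbour. A mask of this kind removes whole cluster pairs rather than just the zero-valued diagonal entries.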
I just took a look at making this work and couldn't easily see how to use that filter mechanism with individual values, but it turns out not to matter anyway because of this message:
> As of today, KeOps does not support backpropagation through the Min_Reduction reduction. Adding this feature to LazyTensors is on the cards for future releases... But until then, you may want to consider extracting the relevant integer indices with a '.argmin()', '.argmax()' or '.argKmin()' reduction before using PyTorch advanced indexing to create a fully-differentiable tensor containing the relevant 'minimal' values.
I suppose that I'm doing the right thing for now then.
Hi @SamPruden, @harry1576,
Thanks for your interest in the library!
Indeed, I believe that the best option is still to use an `argKmin(k = 2)` reduction to extract the index of the nearest neighbour, and then compute the squared distances explicitly as `sqdists = ((points - points[indices]) ** 2).sum(-1)` in PyTorch.
This shouldn't be inefficient, since the intermediate result `points - points[indices]` is required for the backward pass anyway.
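As a minimal sketch of that two-step recipe, the function below finds the neighbour indices without tracking gradients, then recomputes the squared distances differentiably. To keep it runnable without KeOps, a dense distance matrix and `torch.topk` stand in for the LazyTensor and its `argKmin` reduction; the name `nn_sqdists` is hypothetical.

```python
import torch

def nn_sqdists(points: torch.Tensor) -> torch.Tensor:
    """Squared distance from each point to its nearest *other* point.

    points: (N, D) tensor. Returns an (N,) tensor.
    """
    # Index search only: no gradient is needed for this part.
    with torch.no_grad():
        # Dense (N, N) squared distances; stand-in for the KeOps LazyTensor.
        d2 = ((points[:, None, :] - points[None, :, :]) ** 2).sum(-1)
        # Two smallest entries per row: the point itself, then its neighbour.
        # Stand-in for the KeOps argKmin reduction with two neighbours.
        indices = d2.topk(2, dim=1, largest=False).indices[:, 1]
    # Differentiable part, same expression as in the reply above.
    return ((points - points[indices]) ** 2).sum(-1)

points = torch.randn(100, 3, requires_grad=True)
sqdists = nn_sqdists(points)
sqdists.sum().backward()  # gradients flow through the indexing step
```

With KeOps, only the index search inside the `no_grad` block would change; the last line of the function stays in plain PyTorch.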
For reference, a much more complex (but slightly faster in the forward pass) implementation could be:
    import torch
    from pykeops.torch import LazyTensor


    class NNDistances(torch.autograd.Function):
        """See https://pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html for details on this syntax."""

        @staticmethod
        def forward(ctx, x):
            # Encoding as KeOps LazyTensors:
            x_i = LazyTensor(x[:, None, :])
            x_j = LazyTensor(x[None, :, :])
            # Matrix of squared distances:
            D_ij = ((x_i - x_j) ** 2).sum(-1)
            # Distance to self and the nearest neighbors:
            distances, indices = D_ij.Kmin_argKmin(2, dim=1)
            # Keep the second columns:
            distances = distances[:, 1]
            indices = indices[:, 1]
            ctx.save_for_backward(x, indices)
            return distances

        @staticmethod
        def backward(ctx, grad_output):
            x, indices = ctx.saved_tensors
            # Gradient with respect to each point as the "query" x[i]:
            g = 2 * grad_output[:, None] * (x - x[indices])
            # Subtract the contributions received as the "neighbor" x[indices[i]]:
            return g - torch.zeros_like(g).scatter_add_(0, indices[:, None].repeat(1, g.shape[-1]), g)


    nndistances = NNDistances.apply

    N, D = 100, 3
    points = torch.randn(N, D).cuda()
    points.requires_grad = True
    torch.autograd.gradcheck(nndistances, points, eps=1e-3)
I wouldn't recommend using it though, since it's not very readable and maybe slower in the backward pass.
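For readers checking the backward method, the formula it implements can be written out as follows. The notation is an assumption introduced here: $n(i)$ is the index of the nearest neighbour of $x_i$ (treated as locally constant, which holds away from ties), $d_i$ is the returned squared distance and $g_i$ is the incoming gradient `grad_output[i]`.

```latex
d_i = \lVert x_i - x_{n(i)} \rVert^2,
\qquad
\frac{\partial L}{\partial x_k}
  = 2\, g_k \,\bigl(x_k - x_{n(k)}\bigr)
  \;-\; \sum_{i \,:\, n(i) = k} 2\, g_i \,\bigl(x_i - x_{n(i)}\bigr)
```

The first term is the tensor `g` in the code, and the sum over all points $i$ whose nearest neighbour is $x_k$ is what the `scatter_add_` call accumulates.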
In any case: after the release of our new compilation engine in v1.6, we will soon be back to adding new features and addressing compatibility issues with PyTorch/NumPy. Hopefully, these workarounds will become obsolete in the near future :-)
Best regards, Jean
Thanks for the comprehensive response @jeanfeydy! I'm glad to know that I'm not doing anything too silly with my existing code.
I have a collection of 3D points, represented as `[batch, n, 3]` PyTorch tensors. I want to obtain a `[batch, n, squared_dist]` tensor representing the squared distance to each point's nearest neighbour. I will need to be able to autodiff through this.

The standard approach to nearest neighbour problems of this sort is to create a squared distance matrix. However, this has the issue that the diagonal is 0, so a simple use of `min` will always return 0s. I'm looking for some concept of "minimum excluding the diagonal", or "secondmost minimum".

At the moment, I do this by using `argKmin` with `k = 2`, then back in PyTorch land taking the second index for each point, then using those "secondmost minimum" indices to look up the points and calculate squared distances. This is obviously fairly inefficient.

Another approach might be to add large values to the diagonal, but I don't know of an efficient way to do this symbolically.
Is there a good way to achieve this fully in keops?
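For comparison, the "large values on the diagonal" idea looks like this in plain, dense PyTorch. This is only a sketch of the concept: it materialises the full `[batch, n, n]` matrix, so it is not symbolic and does not use KeOps, and the function name `nn_sqdists_masked` is hypothetical. It assumes `n >= 2`.

```python
import torch

def nn_sqdists_masked(points: torch.Tensor) -> torch.Tensor:
    """points: (batch, n, 3) -> (batch, n) squared nearest-neighbour distances."""
    n = points.shape[1]
    # Dense (batch, n, n) matrix of squared distances.
    diff = points[:, :, None, :] - points[:, None, :, :]
    d2 = (diff ** 2).sum(-1)
    # Replace the zero diagonal with +inf so that min ignores "distance to self".
    eye = torch.eye(n, dtype=torch.bool, device=points.device)
    d2 = d2.masked_fill(eye, float("inf"))
    # Minimum over the remaining entries; differentiable in PyTorch.
    return d2.min(dim=2).values

points = torch.randn(4, 50, 3, requires_grad=True)
sqdists = nn_sqdists_masked(points)  # shape (4, 50)
sqdists.sum().backward()
```

Memory use grows quadratically with `n`, which is the cost that the symbolic approach is meant to avoid.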