getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

How to convert a LazyTensor into a dense torch.Tensor before reduction? #88

Closed srikanthmalla closed 4 years ago

srikanthmalla commented 4 years ago

Hi @jeanfeydy , I am new to this library. I am running your example of clustering (https://www.kernel-operations.io/keops/_auto_tutorials/kmeans/plot_kmeans_torch.html).

def KMeans(x, K=10, Niter=10, verbose=True):
    ...
    for i in range(Niter):
        ...
        D_ij = ((x_i - c_j) ** 2).sum(-1)  # (Npoints, Nclusters) symbolic matrix (lazy tensor)
        cl = D_ij.argmin(dim=1).long().view(-1)  # Points -> Nearest cluster (pytorch tensor)
        ...
    return cl, c

Is there a way to access the D_ij result as a PyTorch tensor (before the reduction with argmin or min) and return it alongside cl? Please let me know.

Thank you, Srikanth

jeanfeydy commented 4 years ago

Hi @srikanthmalla ,

You're welcome! As detailed on our front page, KeOps is all about supporting symbolic arrays: matrices that are best encoded using a symbolic (but arbitrarily complex) formula and small-ish data arrays. Just like sparse matrices are ideally suited to deal with stiffness matrices or graph convolutions, KeOps LazyTensors are optimized to work with kernels and generalized distance matrices.

As detailed in our roadmap, we intend to provide an explicit .dense() or .array() conversion routine in future releases. It could be used to work easily with dense sub-matrices, as requested e.g. by the Nyström quadrature method.

Nevertheless, casting a KeOps LazyTensor such as D_ij into a dense torch.Tensor is usually not advised. Most often, we reach optimal performance when math operations (e.g. non-linearities) are applied directly on the KeOps LazyTensor, prior to the reduction call that outputs a genuine torch.Tensor. This is true for both memory and speed: see our benchmarks and introduction to GPU programming for some insight on the question.
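
For small problems, a dense matrix can still be built with plain PyTorch broadcasting, without going through a LazyTensor at all. The following is a minimal sketch, assuming that the (Npoints, Nclusters, D) intermediate fits in memory; the helper name dense_sq_dists is hypothetical and is not part of KeOps.

```python
import torch

def dense_sq_dists(x, c):
    """Dense (Npoints, Nclusters) matrix of squared Euclidean distances."""
    # Broadcasting materializes an (Npoints, Nclusters, D) buffer before the sum:
    # this is what restricts the approach to small inputs.
    return ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1)

torch.manual_seed(0)
x = torch.randn(1000, 2)   # toy point cloud
c = x[:10].clone()         # toy centroids: the first 10 points
D = dense_sq_dists(x, c)   # genuine torch.Tensor of shape (1000, 10)
cl = D.argmin(dim=1)       # index of the nearest centroid for every point
```

The intermediate buffer grows as Npoints × Nclusters × D, which is the memory cost that a symbolic reduction is designed to avoid.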

In your specific setting, we may be able to work out an efficient implementation using the right mix of KeOps and PyTorch operations. May I ask why you intended to convert D_ij into a torch.Tensor prior to a reduction call?

Best regards, Jean

srikanthmalla commented 4 years ago

Hi @jeanfeydy , Thank you for the prompt reply. I am trying to re-implement deep cluster-v2 with keops, to use with my project (which uses pytorch-geometric):

deep cluster-v2 original implementation: https://github.com/facebookresearch/swav/blob/d4970c83791c8bd9928ec4efcc25d4fd788c405f/main_deepclusterv2.py#L343
paper: https://arxiv.org/pdf/1807.05520.pdf

  1. I currently convert D_ij to a dense tensor by multiplying it with an identity matrix. I do this at the end of the function, when returning the torch tensors.
ones = torch.eye(K).to(args.device)
scores = (D_ij @ ones).float()  # this operation takes a lot of time (about 8 s); without it, the whole function takes 0.08 s
... # normalize scores if they are based on euclidean
return cl, c, scores
  2. Ideally, I would also like to change the distance metric to a dot product (as in the paper). But I get empty clusters when using the dot product: according to the paper, the fix is to assign a perturbed copy of a non-empty cluster's centroid to each empty cluster.
if args.dist_metric == "euclidean":
    D_ij = ((x_i - c_j) ** 2).sum(-1)  # (Npoints, Nclusters) symbolic matrix of squared distances
elif args.dist_metric == "cosine":
    D_ij = x_n @ c.t()  # dot product by mat-mul; x_n and c are normalized

Any thoughts or suggestions will be very helpful.

Thank you, Srikanth

jeanfeydy commented 4 years ago

Hi @srikanthmalla ,

You're welcome! Some answers to your questions:

  1. In the link above, I only see a (surprisingly complex and heavy) implementation of the K-means algorithm. I must be missing something: could you please tell me at which line you need to compute the full matrix of distances between the centroids and the samples?

  2. KeOps allows you to work with dot products easily: the key here is to use the dot product operator (x_i | y_j). A typical implementation of the K-means algorithm with the cosine similarity "metric" reads:

import torch
from matplotlib import pyplot as plt

from pykeops.torch import LazyTensor

def KMeans(x, K=10, Niter=10):
    N, D = x.shape  # Number of samples, dimension of the ambient space

    # K-means loop:
    # - x  is the point cloud,
    # - cl is the vector of class labels
    # - c  is the cloud of normalized cluster centroids
    c = x[:K, :].clone()  # Simplistic initialization: use the first K points

    # Normalize the centroids, since we work with the cosine similarity "metric":
    c = torch.nn.functional.normalize(c, dim=1, p=2)

    # Symbolic representation:
    x_i = LazyTensor(x[:, None, :])  # (Npoints, 1, D)
    c_j = LazyTensor(c[None, :, :])  # (1, Nclusters, D)

    for i in range(Niter):

        # E step: assign points to a cluster that maximizes correlation --------
        D_ij = (x_i | c_j)  # (Npoints, Nclusters) symbolic Gram matrix
        cl = D_ij.argmax(dim=1).long().view(-1)  # Points -> Nearest cluster

        # M step: update the centroids to the normalized cluster average: ------
        # Average centroid per cluster:
        c.zero_()
        c.scatter_add_(0, cl[:,None].repeat(1, D), x)

        # Normalize the centroids, in place:
        c[:] = torch.nn.functional.normalize(c, dim=1, p=2)

    return cl, c

You can use it as follows:

N, D, K = 10000, 2, 50
x = .4 * torch.randn(N, D).cuda() + .3

cl, c = KMeans(x, K)
plt.figure(figsize=(8,8))
plt.scatter(x[:, 0].cpu(), x[:, 1].cpu(), c=cl.cpu(), s= 30000 / len(x), cmap="tab10")
plt.scatter(c[:, 0].cpu(), c[:, 1].cpu(), c='black', s=50, alpha=.8)
plt.axis([-1.2,1.2,-1.2,1.2]) ; plt.tight_layout() ; plt.show()

[Figure: kmeans_cosine]
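
For checking results on small inputs without KeOps, the same loop can be written with a dense Gram matrix in plain PyTorch. This is only a sketch under that assumption: kmeans_cosine_dense is a hypothetical name, and the function builds the full (N, K) matrix that the LazyTensor version avoids.

```python
import torch
from torch.nn.functional import normalize

def kmeans_cosine_dense(x, K=10, Niter=10):
    """Dense reference for small inputs: builds the full (N, K) Gram matrix."""
    c = normalize(x[:K, :].clone(), dim=1, p=2)  # first K points as initial centroids
    for _ in range(Niter):
        S = x @ c.t()                 # (N, K) dense matrix of dot products
        cl = S.argmax(dim=1)          # E step: most correlated centroid
        c = torch.zeros_like(c).index_add_(0, cl, x)  # M step: per-cluster sums
        c = normalize(c, dim=1, p=2)  # project the sums back onto the unit sphere
    return cl, c

torch.manual_seed(0)
x = 0.4 * torch.randn(2000, 2) + 0.3  # toy point cloud on the CPU
cl, c = kmeans_cosine_dense(x, K=5)
```

Here index_add_ plays the same role as the scatter_add_ call above: it sums the points of each cluster into the matching centroid row.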

If needed, you can detect empty clusters with the bincount routine:

# Number of points per cluster:
Ncl = torch.bincount(cl, minlength=K)
empty = (Ncl == 0)

Using standard PyTorch indexing, you will then be able to assign fresh values to the cluster-less centroids c[empty,:] according to your favorite heuristic.
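
As one possible heuristic, each empty centroid can be replaced by a slightly perturbed copy of a randomly chosen non-empty one, which is the strategy mentioned earlier in this thread. The sketch below makes several assumptions: the helper name reseed_empty_centroids, the noise scale eps and the uniform choice of donor clusters are all illustrative, not taken from KeOps or from the paper.

```python
import torch
from torch.nn.functional import normalize

def reseed_empty_centroids(c, cl, eps=1e-2):
    """Replace each empty centroid by a perturbed copy of a random non-empty one."""
    K, D = c.shape
    counts = torch.bincount(cl, minlength=K)   # number of points per cluster
    empty = counts == 0
    n_empty = int(empty.sum())
    if n_empty == 0 or n_empty == K:
        return c, empty                        # nothing to fix, or no donor available
    donors = torch.nonzero(~empty).view(-1)    # indices of the non-empty clusters
    pick = donors[torch.randint(len(donors), (n_empty,), device=c.device)]
    noise = eps * torch.randn(n_empty, D, device=c.device, dtype=c.dtype)
    c = c.clone()
    c[empty] = normalize(c[pick] + noise, dim=1, p=2)  # stay on the unit sphere
    return c, empty

torch.manual_seed(0)
c = normalize(torch.randn(4, 3), dim=1, p=2)
cl = torch.tensor([0, 0, 2, 2, 2])             # clusters 1 and 3 receive no points
c_new, empty = reseed_empty_centroids(c, cl)
```

The final normalization only makes sense for the cosine variant; with squared Euclidean distances, the perturbed copy could be used as is.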

What do you think?

Best regards, Jean

srikanthmalla commented 4 years ago

Hi @jeanfeydy , Thanks so much for the detailed explanation.

  1. I think D_ij = (x_i | c_j) is what I'm looking for to compute the cosine similarity.

  2. You are right: according to the paper, the scores to penalize should be the output of the network. So I will not need the dense D_ij array for penalization (which I initially planned, wrongly). What I have in mind now is to rank, within each cluster, the top n points closest to a latent feature in order of proximity, using that similarity, during inference (for a regression task). I will do this in a separate function, but the question of computing the dense array quickly (like other KeOps operations) still arises. Please let me know if that is possible, or if something is unclear.

Thank you very much, Srikanth

jeanfeydy commented 4 years ago

Good!

For your second point: the .argKmin(K=...) reduction (as well as Kmin_argKmin, etc.) is what you are looking for. It is the counterpart of PyTorch's .topk(..., largest=False) method for KeOps LazyTensors, and is usually much faster: its usage is explained in this first tutorial, this 2D example and this K-NN classification script on MNIST. Just like the other KeOps operations, K-min reduction is both time- and memory-efficient: you never have to compute a full matrix of distances and can process large datasets on-the-fly. Is it suitable for you?
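
To make the comparison with .topk() concrete, here is a dense PyTorch reference for the indices that a K-min reduction is meant to return. It is only a sketch for checking results on small inputs: dense_arg_kmin is a hypothetical helper, not a KeOps function, and it builds the full distance matrix that .argKmin avoids.

```python
import torch

def dense_arg_kmin(x, y, k):
    """Indices of the k nearest y_j for every x_i, as an (N, k) LongTensor."""
    D_ij = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # dense (N, M) matrix
    # largest=False selects the k smallest entries of each row,
    # returned in order of increasing distance.
    return D_ij.topk(k, dim=1, largest=False).indices

torch.manual_seed(0)
x = torch.randn(100, 3)   # query points
y = torch.randn(500, 3)   # reference points
idx = dense_arg_kmin(x, y, k=5)
```

On small inputs, the output of this function can be compared against the indices returned by the symbolic reduction.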

Best regards, Jean

srikanthmalla commented 4 years ago

Hi @jeanfeydy ,

Got it! Thanks so much again :)

Best Regards, Srikanth

jeanfeydy commented 4 years ago

Hi @srikanthmalla ,

You're very welcome: I have just upgraded the K-means tutorial to take your remarks into account (commit 31ef415d9da1d77a19d9ce440d4ea68891bced60). If you have any other questions about the library / syntax, feel free to open a new issue :-)

Best regards, Jean

srikanthmalla commented 4 years ago

Hi @jeanfeydy , Do you also have an example of hierarchical clustering with your library? Please let me know.

Best Regards, Srikanth