DeMoriarty / TorchPQ

Approximate nearest neighbor search with product quantization on GPU in pytorch and cuda
MIT License

About SM Size #14

Closed SEC4SR closed 2 years ago

SEC4SR commented 2 years ago

Hi, thanks very much for sharing this project. I have been looking for a package supporting batch kmeans for a long time, and I'm very glad to find that TorchPQ supports it (MultiKMeans). Many thanks again.

But I have a question regarding the argument sm_size used when initializing MultiKMeans. I know it is the CUDA shared memory size, but I am not familiar with CUDA programming and cannot figure out what the default value 48 * 256 * 4 means (the comment in the code does not mention this argument), even after searching on the internet. Could you briefly explain it here? Also, I guess increasing this value can speed up the computation, am I right? Thanks for your time.

DeMoriarty commented 2 years ago

Hi, thanks for using TorchPQ! sm_size is the maximum shared memory per thread block in bytes. On most older GPU architectures (Kepler, Pascal, ...) this value is 48 KB, which is where the default 48 * 256 * 4 = 49152 bytes comes from. Here you can find the specifications of the different architectures: Version_features_and_specifications

in ComputeCentroidsCuda, you can see the following line

dk * (de + 1) * 4 <= sm_size

dk * (de + 1) * 4 is the actual amount of shared memory (in bytes) the CUDA kernel needs, and it should not exceed the device limit sm_size. dk and de are hyperparameters of the kernel that can be tuned; the default value of dk is min(4096, n_clusters), and de = 1. You can experiment with different values for dk and de, and set sm_size to the maximum shared memory per block of your GPU architecture. However, I can't guarantee it will bring any significant speedup.
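For illustration, here is a minimal sketch of that check using the default values mentioned above (the numbers are only an example, not code from the library):

# illustration only: the shared-memory budget check described above,
# using the default values from this thread
n_clusters = 8192             # hypothetical number of clusters
dk = min(4096, n_clusters)    # default dk
de = 1                        # default de
sm_size = 48 * 256 * 4        # 49152 bytes = 48 KiB, the default limit

required = dk * (de + 1) * 4  # bytes of shared memory the kernel would need
assert required <= sm_size    # dk * (de + 1) * 4 must not exceed sm_size
print(required, sm_size)      # 32768 49152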

SEC4SR commented 2 years ago

Thanks for your explanation and clarification! Very quick response.

SEC4SR commented 2 years ago

Thanks for your code again.

Problem

I notice that currently, MultiKMeans (more specifically, MultiKMeans.centroids) does not support auto-differentiation, see below:

import torch
from torchpq.clustering import MultiKMeans

B = 128
N = 600
D = 80
r = 0.2
K = int(N * r)

x = torch.randn(B, N, D, device='cuda')
x.requires_grad = True

km = MultiKMeans(n_clusters=K, init_mode='kmeans++')
cluster_ids = km.fit(x.transpose(1, 2).contiguous())
y = km.centroids.transpose(1, 2)
print(cluster_ids, y.requires_grad)
y.backward(torch.ones_like(y))
print(x.grad)

The output is:

(screenshot of the output, showing that y.requires_grad is False)

Solution

Since the gradient of the kmeans centroids w.r.t. the input is simple (it depends only on cluster_ids), I wrote a wrapper around MultiKMeans using torch.autograd.Function to obtain the gradients:

import torch
from torchpq.clustering import MultiKMeans

class Kmeans_Function(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X, n_clusters, distance='euclidean', init_mode='kmeans++'):
        mul_km = MultiKMeans(n_clusters=n_clusters, distance=distance, init_mode=init_mode)
        cluster_ids = mul_km.fit(X.transpose(1, 2).contiguous())
        ctx.save_for_backward(torch.tensor(X.shape), cluster_ids)
        ctx.constant = n_clusters
        return mul_km.centroids.transpose(1, 2)

    @staticmethod
    def backward(ctx, grad_output): # grad_output: (B, K, D)
        X_shape, cluster_ids = ctx.saved_tensors
        n_clusters = ctx.constant
        grad_X = torch.zeros(torch.Size(X_shape), device=grad_output.device) # (B, N, D)
        B, N, D = X_shape
        for i in range(B):
            cluster_id = cluster_ids[i]
            for j in range(n_clusters):
                ids = torch.where(cluster_id == j)[0]
                if ids.shape[0] > 0:
                    # each member of cluster j gets an equal share of the centroid's gradient
                    grad_X[i, ids, :] = grad_output[i, j, :] / ids.shape[0]
        return grad_X, None, None, None

def Kmeans(X, n_clusters, distance='euclidean', init_mode='kmeans++'):
    return Kmeans_Function.apply(X, n_clusters, distance, init_mode)

This solution successfully gives me the correct gradient, see below:

import torch
import time
# `Kmeans` is the wrapper function defined above

B = 128
N = 600
D = 80
r = 0.2
K = int(N * r)

x = torch.randn(B, N, D, device='cuda')
x.requires_grad = True

t1 = time.time()
y = Kmeans(x, K)
t2 = time.time()
print(y.requires_grad)
y.backward(torch.ones_like(y))
t3 = time.time()
# print(x.grad)
print(t2-t1,t3-t2)

And the output:

(screenshot of the output: y.requires_grad is True, and the timings show the backward pass taking roughly 5x longer than the forward pass)

But the problem is that the backward pass takes too long, nearly 5x longer than the forward pass. The reason is that the backward pass of Kmeans_Function contains two for loops (I cannot figure out how to avoid them).

Wanted Feature

Based on this, I wonder if you plan to support auto-differentiation at the CUDA level? The gradient computation for kmeans is simple (just 1 scaled by the number of vectors assigned to the cluster), and I think supporting auto-differentiation at the CUDA level would make the backward pass more efficient than the naive two loops in my solution. I can also help add this feature.
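To spell out that claim with a toy example (illustration only, independent of TorchPQ): with a fixed assignment, each centroid is the mean of its members, so the gradient that flows back to each member is the centroid's gradient divided by the cluster size.

import torch

# 5 points in 3 dimensions, with a fixed assignment to 2 clusters
x = torch.randn(5, 3, requires_grad=True)
ids = torch.tensor([0, 0, 1, 1, 1])       # 2 points in cluster 0, 3 points in cluster 1

c0 = x[ids == 0].mean(dim=0)              # centroid of cluster 0
c1 = x[ids == 1].mean(dim=0)              # centroid of cluster 1
torch.stack([c0, c1]).backward(torch.ones(2, 3))

print(x.grad)  # rows of 1/2 for cluster 0 members, rows of 1/3 for cluster 1 members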

Thanks for your time.

DeMoriarty commented 2 years ago

Hi, currently there's no short-term plan to support differentiable KMeans / KNN. But I believe it would definitely be useful for a lot of people, and I will consider putting it on my todo list once I'm less occupied with other projects.

The double for loop in the kmeans backward pass can be eliminated with plain PyTorch, at the cost of extra memory usage. Here's how the inner loop can be removed:

class Kmeans_Function_2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, n_clusters, distance='euclidean', init_mode='kmeans++', centroids=None):
        mul_km = MultiKMeans(n_clusters=n_clusters, distance=distance, init_mode=init_mode)
        if centroids is None:
          cluster_ids = mul_km.fit(X.transpose(1, 2).contiguous())
          centroids = mul_km.centroids.transpose(1, 2)
        else:
          # if `centroids` is provided, we use it to initialize the kmeans, instead of re-training.
          mul_km.register_buffer("centroids", centroids.transpose(1, 2))
          cluster_ids = mul_km.predict(X.transpose(1, 2).contiguous())
        ctx.save_for_backward(torch.tensor(X.shape), cluster_ids)
        ctx.constant = n_clusters
        return mul_km.centroids.transpose(1, 2)

    @staticmethod
    def backward(ctx, grad_output): # grad_output: (B, K, D)
        X_shape, cluster_ids = ctx.saved_tensors
        n_clusters = ctx.constant
        grad_X = torch.zeros(torch.Size(X_shape), device=grad_output.device) # (B, N, D)
        B, N, D = X_shape
        for i in range(B):
            cluster_id = cluster_ids[i] # (N,)
            expanded_cluster_id = cluster_id[None].repeat(n_clusters, 1) # (K, N)
            mask = expanded_cluster_id == torch.arange(n_clusters, device=grad_output.device)[:, None] # (K, N)
            clusters, counts = torch.unique(cluster_id, return_counts=True) 
            cluster_sizes = (mask[clusters, :] * counts[:, None]).sum(dim=0)
            grad_X[i, :, :] = 1 / cluster_sizes[:, None].float() * grad_output[i, cluster_id, :]

        return grad_X, None, None, None, None

def Kmeans2(X, n_clusters, distance='euclidean', init_mode='kmeans++', centroids=None):
    return Kmeans_Function_2.apply(X, n_clusters, distance, init_mode, centroids)

and here's how we can get rid of both for loops:

class Kmeans_Function_3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, n_clusters, distance='euclidean', init_mode='kmeans++', centroids=None):
        mul_km = MultiKMeans(n_clusters=n_clusters, distance=distance, init_mode=init_mode)
        if centroids is None:
          cluster_ids = mul_km.fit(X.transpose(1, 2).contiguous())
          centroids = mul_km.centroids.transpose(1, 2)
        else:
          mul_km.register_buffer("centroids", centroids.transpose(1, 2))
          cluster_ids = mul_km.predict(X.transpose(1, 2).contiguous())
        ctx.save_for_backward(torch.tensor(X.shape), cluster_ids)
        ctx.constant = n_clusters
        return mul_km.centroids.transpose(1, 2)

    @staticmethod
    def backward(ctx, grad_output): # grad_output: (B, K, D)
        X_shape, cluster_ids = ctx.saved_tensors # cluster_ids: (B, N)
        n_clusters = ctx.constant
        B, N, D = X_shape
        expanded_cluster_id = cluster_ids[:, None].repeat(1, n_clusters, 1) # (B, K, N)
        mask = expanded_cluster_id == torch.arange(n_clusters, device=grad_output.device)[None, :, None] # (B, K, N)
        counts = mask.sum(dim=-1) # (B, K)
        cluster_sizes = (mask * counts[:, :, None]).sum(dim=1) # (B, N)
        grad_X = 1 / cluster_sizes[:, :, None].float() * grad_output.gather(dim=1, index=cluster_ids[:, :, None].expand(-1, -1, D))

        return grad_X, None, None, None, None

def Kmeans3(X, n_clusters, distance='euclidean', init_mode='kmeans++', centroids=None):
    return Kmeans_Function_3.apply(X, n_clusters, distance, init_mode, centroids)

And here's the code to time the different methods as well as verify their correctness:

import torch
import time

B = 128
N = 600
D = 80
r = 0.2
K = int(N * r)
iters = 10
use_previous_centroids = False

x = torch.randn(B, N, D, device='cuda')
x.requires_grad = True

def time_kmeans(kmeans_fn, x, iters=1, centroids=None, correct_grad=None):
  if centroids is not None:
    centroids = centroids.detach()

  # the first run will be slower, we need to warmup the gpu
  y = kmeans_fn(x, K)

  # gpu and cpu run asynchronously, so we need to synchronize them in order to get correct timings.
  torch.cuda.synchronize()
  # time.perf_counter() is similar to time.time(), but has higher resolution. 
  start = time.perf_counter()

  # get average runtime of forward pass over many iterations
  for i in range(iters):
    x.grad = None  # to prevent gradient accumulation
    if centroids is None:
      y = kmeans_fn(x, K)  # v1 (`Kmeans`) does not take a `centroids` argument
    else:
      y = kmeans_fn(x, K, centroids=centroids)

  torch.cuda.synchronize()
  time_forward = time.perf_counter() - start

  torch.cuda.synchronize()
  start = time.perf_counter()

  # get average runtime of backward pass over many iterations
  for i in range(iters):
    y.backward(torch.ones_like(y), retain_graph=True)

  torch.cuda.synchronize()
  time_backward = time.perf_counter() - start
  grad = x.grad.clone()
  print("time spent:", time_forward / iters, time_backward / iters)
  if correct_grad is not None:
    err = (grad - correct_grad).pow(2).mean()
    print("mse", err)

  return y

print("v1: 2 loops")
y = time_kmeans(Kmeans, x, iters=iters)  # `Kmeans` is the original 2-loop wrapper defined above

if use_previous_centroids:
  # because the centroids of MultiKMeans are initialized differently each time Kmeans_Function is called,
  # the gradients will also be different. For consistency, we reuse the same centroids in all 3 versions.
  centroids = y
else:
  centroids = None

correct_grad = x.grad.clone()
print("\nv2: 1 loop")
y2 = time_kmeans(Kmeans2, x, iters=iters, centroids=centroids, correct_grad=correct_grad)
print("\nv3: no loop")
y3 = time_kmeans(Kmeans3, x, iters=iters, centroids=centroids, correct_grad=correct_grad)

The backward pass of the 3rd version should be much faster than the other two, but it's also a lot more memory hungry: it has a space complexity of O(B * N * K).
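For reference, here is a sketch of one way the same backward computation could avoid materializing the (B, K, N) mask, using scatter_add_ and gather, which keeps the extra memory at O(B * N). This is only an illustration on toy tensors; the names and shapes follow the conventions used above.

import torch

B, N, D, K = 4, 600, 80, 120                       # toy sizes, shapes as above
cluster_ids = torch.randint(0, K, (B, N), device='cuda')
grad_output = torch.randn(B, K, D, device='cuda')  # gradient w.r.t. the centroids, (B, K, D)

# per-batch cluster sizes: counts[b, k] = number of points assigned to cluster k
counts = torch.zeros(B, K, device=grad_output.device)
counts.scatter_add_(1, cluster_ids, torch.ones_like(cluster_ids, dtype=counts.dtype))

cluster_sizes = counts.gather(1, cluster_ids)      # (B, N): size of each point's own cluster
grad_X = grad_output.gather(1, cluster_ids[:, :, None].expand(-1, -1, D)) / cluster_sizes[:, :, None]
print(grad_X.shape)  # torch.Size([4, 600, 80])

The gather over grad_output is the same as in the 3rd version; only the computation of cluster_sizes changes.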

SEC4SR commented 2 years ago

Hi, sorry, I just noticed your reply. Thank you very much for your useful suggestions on removing the for loops; they help me a lot.

SEC4SR commented 2 years ago

I will close this issue. Thanks again for your code and kind help.