unlimblue / KNN_CUDA

pytorch knn [cuda version]

Processing time is slower than the PyTorch implementation #23

Open zhangtingyu11 opened 2 months ago


I ran a test with a batch size of 1, comparing KNN_CUDA against my own implementation built on the PyTorch API. The code is below:

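import torch
import time
from knn_cuda import KNN


def get_millisecond():
    # Wall-clock timestamp in whole milliseconds.
    t = time.time()
    return int(round(t * 1000))


def find_two_closest_points(A, B):
    # Pairwise Euclidean distances: row i holds the distances from A[i] to every point in B.
    distances = torch.cdist(A, B, p=2)
    # The two smallest distances per row (largest=False selects the nearest).
    # topk also computes the indices; only the distances are returned here.
    closest_distances, closest_indices = torch.topk(distances, 2, dim=-1, largest=False)
    return closest_distances


# transpose_mode=True: point sets are passed as (batch, num_points, dim).
knn = KNN(k=2, transpose_mode=True)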

# Batch size of 1 for KNN_CUDA; plain 2-D tensors for the PyTorch version.
tensor_A_knn_cuda = torch.randn((1, 1000, 128)).cuda()
tensor_B_knn_cuda = torch.randn((1, 1000, 128)).cuda()
tensor_A = torch.randn((1000, 128)).cuda()
tensor_B = torch.randn((1000, 128)).cuda()

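torch.cuda.synchronize()  # finish pending GPU work before starting the clock
start_time = get_millisecond()
for _ in range(10000):
    dist, indx = knn(tensor_A_knn_cuda, tensor_B_knn_cuda)  # distances and indices
torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
end_time = get_millisecond()
print("knn_cuda process time: {} ms".format(end_time - start_time))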

torch.cuda.synchronize()  # finish pending GPU work before starting the clock
start_time = get_millisecond()
for _ in range(10000):
    distance = find_two_closest_points(tensor_A, tensor_B)
torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
end_time = get_millisecond()
print("pytorch process time: {} ms".format(end_time - start_time))
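The PyTorch baseline performs an exact brute-force search. `torch.cdist` builds the full distance matrix between the two point sets, and `torch.topk(..., largest=False)` keeps the k smallest entries in each row. The sketch below does the same computation in plain Python, for illustration only. `knn_brute_force` is a hypothetical name and is not part of KNN_CUDA or PyTorch. Like the `dist, indx = knn(ref, query)` call in the KNN_CUDA README example, it returns both distances and indices, so its output matches the shape of a like-for-like comparison.

```python
import heapq
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def knn_brute_force(queries: Sequence[Point],
                    refs: Sequence[Point],
                    k: int) -> Tuple[List[List[float]], List[List[int]]]:
    """Exact k-nearest neighbours by exhaustive search (hypothetical helper).

    For every query point, compute the Euclidean distance to every reference
    point and keep the k smallest, nearest first. Returns (distances, indices),
    each shaped [len(queries)][k].
    """
    if not 0 < k <= len(refs):
        raise ValueError("k must be in 1..len(refs)")
    all_dists: List[List[float]] = []
    all_idx: List[List[int]] = []
    for q in queries:
        # One row of the pairwise distance matrix (what torch.cdist(p=2) builds).
        row = [math.dist(q, r) for r in refs]
        # Positions of the k smallest entries (what topk(largest=False) selects).
        best = heapq.nsmallest(k, range(len(refs)), key=row.__getitem__)
        all_dists.append([row[j] for j in best])
        all_idx.append(best)
    return all_dists, all_idx


if __name__ == "__main__":
    dists, idx = knn_brute_force([(0.0, 0.0)],
                                 [(0.0, 0.0), (3.0, 4.0), (1.0, 0.0)], k=2)
    print(dists, idx)  # [[0.0, 1.0]] [[0, 2]]
```

Both implementations compute every query-to-reference distance, so the amount of arithmetic is the same whichever one is faster in practice.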

The processing times are shown in the image below.

[Image: printed processing times for the knn_cuda and PyTorch runs]

KNN_CUDA appears to be slower than the PyTorch implementation. Is this an expected result?
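A more robust timing loop usually follows three rules. First, it warms up so that one-off costs, such as lazy CUDA initialization, are excluded. Second, it synchronizes both before starting the clock and after stopping it. Third, it repeats the measurement and reports a robust statistic. The sketch below shows that pattern using only the standard library. `benchmark` and its parameters are hypothetical names. When timing CUDA tensors, pass `torch.cuda.synchronize` as the `sync` argument. It defaults to a no-op so the sketch runs anywhere.

```python
import statistics
import time
from typing import Callable, List


def benchmark(fn: Callable[[], object],
              sync: Callable[[], None] = lambda: None,
              warmup: int = 10,
              iters: int = 100,
              repeats: int = 5) -> float:
    """Return the median time per call of fn(), in milliseconds (hypothetical helper).

    `sync` should block until all queued device work has finished
    (for CUDA tensors, torch.cuda.synchronize); by default it does nothing.
    """
    # Warm-up: absorb one-off costs such as lazy initialization or caching.
    for _ in range(warmup):
        fn()
    sync()

    per_call_ms: List[float] = []
    for _ in range(repeats):
        sync()                         # drain pending work before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        sync()                         # wait for the last launched work to finish
        elapsed = time.perf_counter() - start
        per_call_ms.append(elapsed * 1000.0 / iters)
    # The median is less sensitive than the mean to a single noisy repeat.
    return statistics.median(per_call_ms)


if __name__ == "__main__":
    ms = benchmark(lambda: sum(range(10_000)))
    print(f"{ms:.4f} ms per call")
```

You could call `benchmark(lambda: knn(tensor_A_knn_cuda, tensor_B_knn_cuda), sync=torch.cuda.synchronize)` and `benchmark(lambda: find_two_closest_points(tensor_A, tensor_B), sync=torch.cuda.synchronize)` to get per-call medians for both implementations under the same conditions. PyTorch also provides `torch.utils.benchmark.Timer`, which is designed to handle warm-up and CUDA synchronization itself.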