storpipfugl / pykdtree

Fast kd-tree implementation in Python
GNU Lesser General Public License v3.0

Performance with larger k #64

Open ErlerPhilipp opened 2 years ago

ErlerPhilipp commented 2 years ago

Hi,

I'm trying to find a spatial data structure that is faster than the SciPy KDTree. I'm using the following test code with pykdtree==1.3.4 and SciPy==1.7.1:

import numpy as np
from scipy.spatial import KDTree as ScipyKDTree
import time

data_pts = np.random.random(size=(35000, 3)).astype(np.float32)

query_pts = np.random.random(size=(9*6, 3)).astype(np.float32)
num_neighbors = 1000

from pykdtree.kdtree import KDTree
pykd_tree = KDTree(data_pts, leafsize=10)

start_pykd = time.time()
for _ in range(100):
    dist_pykd, idx_pykd = pykd_tree.query(query_pts, k=num_neighbors)
end_pykd = time.time()
print('pykdtree: {}'.format(end_pykd - start_pykd))

scipy_kd_tree = ScipyKDTree(data_pts, leafsize=10)

start_scipy = time.time()
for _ in range(100):
    dist_scipy, idx_scipy = scipy_kd_tree.query(query_pts, k=num_neighbors)
end_scipy = time.time()
print('scipy: {}'.format(end_scipy - start_scipy))

I get these timings in seconds:

k     | pykdtree | scipy
------|----------|--------
1     | 0.009    | 0.0875
10    | 0.0189   | 0.1035
100   | 0.0939   | 0.1889
1000  | 3.269    | 0.952
10000 | >60      | 7.4145

Am I doing anything wrong? Why does pykdtree get so much slower with larger k? Is pykdtree perhaps growing (re-allocating) its output arrays many times?

djhoese commented 2 years ago

Very interesting. What operating system are you using? Version of Python? And how did you install pykdtree?

The reallocation seems like a good guess, but I'd also be curious whether OpenMP is doing something weird like deciding to spawn a ton of extra threads. @storpipfugl would have to comment on whether the algorithm used is expected to perform well with such a high k. In my own work I've only ever used k<10.

ErlerPhilipp commented 2 years ago

Installed with conda, no OpenMP (at least I didn't set up anything myself). I want it single-threaded since I'd use it in worker processes of my ML stuff. Python 3.8.5 on Windows 10 (64-bit); I got similar scaling with Ubuntu in WSL.

djhoese commented 2 years ago

OpenMP should be automatic if it was available when pykdtree was installed; I think conda-forge includes it in all of its builds. You can force single-threaded execution by setting the environment variable OMP_NUM_THREADS=1.
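(For completeness, the variable can also be set from inside the script, assuming it is set before pykdtree and the OpenMP runtime are first loaded; a minimal sketch:)

import os
os.environ["OMP_NUM_THREADS"] = "1"  # must be set before the OpenMP runtime initializes

from pykdtree.kdtree import KDTree  # imported only after the variable is in place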

With your example script I see similar timings (actually worse). In theory we should be able to add some options to the Cython .pyx file and profile the execution. I don't have time to try this myself, though. If that is something you'd be willing to look into, let me know how I can help.
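For reference, profiling of Cython code is usually switched on through compiler directives at build time. The setup below is only a generic sketch under that assumption, with a placeholder file name, not pykdtree's actual build configuration:

# Generic sketch: build a Cython extension with profiling enabled so that
# cProfile can attribute time to functions inside the compiled module.
from setuptools import setup
from Cython.Build import cythonize

setup(
    ext_modules=cythonize(
        "kdtree.pyx",  # placeholder name, not necessarily pykdtree's layout
        compiler_directives={"profile": True},
    ),
)

Running the benchmark under python -m cProfile should then show where the time goes inside the query.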

ErlerPhilipp commented 2 years ago

I just tried it with:

export OMP_NUM_THREADS=1
python test.py

and got the same timings.

I'm no expert in Cython or OpenMP, but the code looks fine to me, with no obvious re-allocation. Maybe it's some low-level OpenMP issue, I don't know. It's probably not worth my time right now. After all, simple brute force might be an option for me.

Anyway, thanks for the very quick confirmation!

ErlerPhilipp commented 2 years ago

Quick update: brute force is not an option for ~35k points.

# Presumed definition of cartesian_dist_1_n: Euclidean distances from one
# query point to every data point.
def cartesian_dist_1_n(pt, pts):
    return np.linalg.norm(pts - pt, axis=1)

start_brute = time.time()
for _ in range(100):
    idx_brute = np.empty(shape=(query_pts.shape[0], num_neighbors), dtype=np.uint32)
    dist_brute = np.empty(shape=(query_pts.shape[0], num_neighbors), dtype=np.float32)
    for qi, query_pt in enumerate(query_pts):
        dist_brute_query = cartesian_dist_1_n(query_pt, data_pts)
        sort_ids = np.argsort(dist_brute_query)  # full sort of all ~35k distances per query
        idx_brute_knn = sort_ids[:num_neighbors]
        idx_brute[qi] = idx_brute_knn
        dist_brute[qi] = dist_brute_query[idx_brute_knn]
end_brute = time.time()
print('brute: {}'.format(end_brute - start_brute))

results in ~13.5 sec for k=1000. I guess I'll stick with SciPy for now.
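As an aside, a partial selection with np.argpartition would avoid fully sorting all ~35k distances per query; whether that makes the brute-force path competitive is untested here. A rough sketch of the inner step:

# Sketch: pick the num_neighbors smallest distances without a full sort.
knn_ids = np.argpartition(dist_brute_query, num_neighbors)[:num_neighbors]
# argpartition returns them in arbitrary order; sort just those k if ordered output is needed.
knn_ids = knn_ids[np.argsort(dist_brute_query[knn_ids])]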

ingowald commented 2 years ago

The reason this gets so slow for larger k is that I currently maintain the "sorted list of currently closest elements" by simply re-sorting that list every time a new element is inserted. For small k that is just as good as using a heap, and much simpler to code and easier to read, which is why I picked it (I hadn't expected anybody to use a k larger than about 10!). I'll switch to a heap implementation; that should fix this.
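In Python terms, a simplified sketch of that bookkeeping (not the actual pykdtree C/Cython code): a max-heap keyed on negated distance keeps the k current best with O(log k) work per candidate instead of re-sorting the whole list each time.

import heapq

def update_k_best(heap, dist, idx, k):
    # Keep the k closest candidates seen so far; heap[0] holds the worst of them.
    if len(heap) < k:
        heapq.heappush(heap, (-dist, idx))
    elif dist < -heap[0][0]:  # strictly closer than the current worst of the k best
        heapq.heapreplace(heap, (-dist, idx))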

ErlerPhilipp commented 2 years ago

Oh nice! That would still be useful for me; I'll try it when it's available. I guess deep learning on point clouds adds some unexpected use cases.