u1234x1234 / pynanoflann

Unofficial python wrapper to the nanoflann k-d tree
BSD 2-Clause "Simplified" License
33 stars 8 forks source link

How can I add the quad_tree nearest method to the benchmark? #9

Closed leechangyo closed 3 years ago

leechangyo commented 3 years ago

Never mind! I solved it, thanks!

u1234x1234 commented 3 years ago

Thank you for your interest!

I've started digging into it, but didn't find a way to get indices from _QuadTree (the distances do match). My code:

import numpy as np
from sklearn import neighbors
from sklearn.neighbors._quad_tree import _QuadTree

class QuadTreeWithSKLEARNInterface:
    """Wrap scikit-learn's private ``_QuadTree`` behind ``fit``/``kneighbors``.

    Only distances are produced: ``kneighbors`` returns the smallest values
    reported by the tree for a single query point, not the neighbour indices.
    """

    def __init__(self, n_neighbors=3):
        """Store the neighbour count; the tree itself is built by ``fit``."""
        self._qt = None
        self._X = None
        self._n_neighbors = n_neighbors

    def fit(self, X, verbose=1):
        """Build the quad tree over the rows of ``X``.

        ``X`` must expose ``shape[1]`` (the point dimensionality); it is kept
        on the instance because ``kneighbors`` hands it back to the tree.
        ``verbose`` is forwarded to ``_QuadTree``.
        """
        n_dim = X.shape[1]
        # The former check ``2 <= n_dim <= n_dim`` compared n_dim with itself,
        # so only the lower bound was ever enforced; spell that out.
        assert n_dim >= 2
        self._qt = _QuadTree(n_dimensions=n_dim, verbose=verbose)
        self._qt.build_tree(X)
        self._X = X

    def kneighbors(self, X_query):
        """Return the ``n_neighbors`` smallest distances for one query point.

        ``X_query`` is passed straight to ``_QuadTree._py_summarize`` together
        with the fitted data and an angle of 0.0.  The result is a 1-D array
        sorted in ascending order; it is shorter than ``n_neighbors`` when the
        tree reports fewer records.

        Raises ``AssertionError`` if ``fit`` has not been called.
        """
        assert self._qt is not None, "call fit() before"

        X = self._X
        n_dim = X.shape[1]
        # NOTE(review): assumes _py_summarize returns (count of valid buffer
        # entries, flat buffer) and that each record is laid out as n_dim
        # deltas, then the distance, then the cell size -- confirm against the
        # installed scikit-learn, this is a private API.
        n_valid, summary = self._qt._py_summarize(X_query, X, 0.0)

        # Keep only the entries the tree actually wrote; anything past
        # n_valid must not take part in the ranking.
        record_len = n_dim + 2
        summary = np.array(summary)[:n_valid].reshape(-1, record_len)
        distances = summary[:, n_dim]
        nearest = distances.argsort()[:self._n_neighbors]

        return distances[nearest]

# Demo: query the quad-tree wrapper and scikit-learn's brute-force search with
# the same random 2-D data and print both results for comparison.
X_index = np.random.uniform(size=(10, 2)).astype(np.float32)
X_query = np.random.uniform(size=(1, 2)).astype(np.float32)

tree = QuadTreeWithSKLEARNInterface()
tree.fit(X_index)

# Hand over the complete query point (row 0, both coordinates).
# X_query[:, 0] would be a length-1 view holding only its first coordinate.
r = tree.kneighbors(X_query[0])
print(r)

tree = neighbors.NearestNeighbors(n_neighbors=3, algorithm="brute")
tree.fit(X_index)
a1, a2 = tree.kneighbors(X_query)
# Square the brute-force distances before printing -- presumably because the
# quad tree reports squared distances; verify when comparing with r.
a1 = a1 * a1

print(a1)
print(a2)

Could you please post your solution?

leechangyo commented 3 years ago

Thank you for your work! In my case I just want to compare the time cost of building the tree and running queries, which is why I simply modified benchmark.py. I really appreciate the work you did for me; it is very helpful. Thanks!

import time
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import pynanoflann
from contexttimer import Timer
from sklearn import neighbors

# Benchmark configuration: problem size, repetitions and element type.
n_index_points = 20000
n_query_points = 1000
n_repititions = 5
data_dim = 3
n_neighbors = 100
index_type = np.float32

# Uniform random points in [0, 100) per coordinate: first the indexed set,
# then the query set.
data = np.random.uniform(
    0, 100, size=(n_index_points, data_dim)
).astype(index_type)
queries = np.random.uniform(
    0, 100, size=(n_query_points, data_dim)
).astype(index_type)

# Algorithms under test, keyed by the label shown in the results.
# The quad tree is driven through the private _QuadTree object directly; the
# earlier attempt to request it via
# NearestNeighbors(algorithm='quad_tree') is disabled.
algs = dict(
    sklearn_brute=neighbors.NearestNeighbors(
        n_neighbors=n_neighbors, algorithm='brute'
    ),
    quad_tree=neighbors._QuadTree(data_dim, verbose=0),
    sklearn_kd_tree=neighbors.NearestNeighbors(
        n_neighbors=n_neighbors, algorithm='kd_tree'
    ),
    pynanoflann=pynanoflann.KDTree(n_neighbors=n_neighbors),
)

# Time index construction and querying for every algorithm, repeated
# n_repititions times; one (name, build seconds, query seconds) row per run.
results = []
for rep in range(n_repititions):
    for alg_name, nn in algs.items():
        is_raw_quad_tree = alg_name == 'quad_tree'

        # The raw quad tree is built with build_tree(); everything else
        # follows the fit() convention.
        build_index = nn.build_tree if is_raw_quad_tree else nn.fit
        with Timer() as index_build_time:
            build_index(data)

        with Timer() as query_time:
            if is_raw_quad_tree:
                # NOTE(review): this branch looks up every *indexed* point
                # (one get_cell call per row of data), whereas the other
                # branch runs kneighbors on queries -- the two query times
                # measure different workloads.
                for point in data:
                    idx = nn.get_cell(point)
            else:
                dist, idx = nn.kneighbors(queries)

        results.append(
            (alg_name, index_build_time.elapsed, query_time.elapsed)
        )

column_names = ['Algorithm', 'Index build time, second', 'Query time, second']
df = pd.DataFrame(results, columns=column_names)
print(df)

# Figure 1: query time per algorithm on a base-4 logarithmic axis.
fig, ax = plt.subplots(figsize=(18, 6))
sns.barplot(data=df, x='Algorithm', y=df.columns[2], ax=ax, ci=None)
try:
    # Current Matplotlib spells the log base as ``base``.
    ax.set_yscale("log", base=4)
except (TypeError, ValueError):
    # Older Matplotlib releases only accept the axis-specific ``basey``.
    ax.set_yscale("log", basey=4)
# Relabel the ticks with four decimals instead of the log-scale default.
ylabels = ['{:.4f}'.format(x) for x in ax.get_yticks()]
ax.set_yticklabels(ylabels)
ax.set_title(f'n_index_points={n_index_points}, n_query_points={n_query_points}, dim={data_dim}')
plt.grid()
plt.savefig('benchmark_query.png')

# Figure 2: index build time; the brute-force rows are filtered out, leaving
# one palette colour per remaining algorithm.
fig, ax = plt.subplots(figsize=(18, 6))
sns.barplot(data=df[df.Algorithm != 'sklearn_brute'], x='Algorithm', y=df.columns[1], ax=ax, palette=['C1', 'C2', 'C3'], ci=None)
ax.set_title(f'n_index_points={n_index_points}, dim={data_dim}')
plt.grid()
plt.savefig('benchmark_index.png')

plt.show()