Closed: leechangyo closed this issue 3 years ago.
Thank you for your interest!
I've started digging into it, but I didn't find a way to get indices from `_QuadTree` (the distances do match). My code:
import numpy as np
from sklearn import neighbors
from sklearn.neighbors._quad_tree import _QuadTree
class QuadTreeWithSKLEARNInterface:
    """Minimal sklearn-style wrapper around scikit-learn's private ``_QuadTree``.

    ``_QuadTree`` has no nearest-neighbour query, so :meth:`kneighbors` is
    emulated with ``_py_summarize`` at ``angle=0``. With a zero angle no
    internal cell qualifies as a summary, so the traversal reports every leaf
    cell together with the squared distance from the query to that leaf's
    barycenter. The summary carries no point indices, so only (squared)
    distances can be returned.

    Parameters
    ----------
    n_neighbors : int, default=3
        Number of smallest distances returned by :meth:`kneighbors`.
    """

    def __init__(self, n_neighbors=3):
        self._qt = None      # the fitted _QuadTree, set by fit()
        self._X = None       # float32, C-contiguous copy of the index points
        self._n_dim = None   # dimensionality of the fitted data (2 or 3)
        self._n_neighbors = n_neighbors

    def fit(self, X, verbose=1):
        """Build the quad tree (2-D) or oct tree (3-D) over the rows of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_dims)
            Index points. ``n_dims`` must be 2 or 3, the only dimensionalities
            ``_QuadTree`` supports. Converted to C-contiguous float32 because
            the Cython tree expects float32 memoryviews.
        verbose : int, default=1
            Forwarded to ``_QuadTree``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 2-D with 2 or 3 columns.
        """
        X = np.ascontiguousarray(X, dtype=np.float32)
        if X.ndim != 2 or not 2 <= X.shape[1] <= 3:
            raise ValueError(
                f"X must have shape (n_samples, 2) or (n_samples, 3), got {X.shape}"
            )
        n_dim = X.shape[1]
        self._qt = _QuadTree(n_dimensions=n_dim, verbose=verbose)
        self._qt.build_tree(X)
        self._X = X
        self._n_dim = n_dim
        return self

    def kneighbors(self, X_query):
        """Return the ``n_neighbors`` smallest squared distances to one query point.

        Parameters
        ----------
        X_query : array-like of shape (n_dims,) or (1, n_dims)
            A single query point with the fitted dimensionality.

        Returns
        -------
        numpy.ndarray of shape (min(n_neighbors, n_leaves),)
            Squared Euclidean distances in ascending order. Point indices
            are not available from ``_QuadTree``.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        ValueError
            If ``X_query`` is not a single point of the fitted dimensionality.
        """
        if self._qt is None:
            raise RuntimeError("call fit() before")
        query = np.ascontiguousarray(X_query, dtype=np.float32).reshape(-1)
        if query.shape[0] != self._n_dim:
            raise ValueError(
                f"X_query must be a single point with {self._n_dim} coordinates, "
                f"got {query.shape[0]}"
            )
        # NOTE(review): sklearn's summarize skips a leaf that coincides with
        # the query point ("self interaction"), so an indexed point equal to
        # the query does not appear in the result.
        n_filled, summary = self._qt._py_summarize(query, self._X, 0.0)
        # Only the first ``n_filled`` floats are written by summarize; anything
        # past that in the returned buffer is not part of the result.
        # Each reported cell occupies n_dim + 2 floats:
        # (query - barycenter) components, squared distance, cell size.
        # NOTE(review): confirm _py_summarize's internal buffer is large enough
        # for (n_dim + 2) floats per leaf when n_dim == 3.
        offset = self._n_dim + 2
        cells = np.asarray(summary[:n_filled]).reshape(-1, offset)
        squared_distances = cells[:, self._n_dim]
        return np.sort(squared_distances)[:self._n_neighbors]


if __name__ == "__main__":
    # Compare the quad-tree emulation with sklearn's exact brute-force search.
    X_index = np.random.uniform(size=(10, 2)).astype(np.float32)
    X_query = np.random.uniform(size=(1, 2)).astype(np.float32)

    tree = QuadTreeWithSKLEARNInterface()
    tree.fit(X_index)
    # Pass the whole query point (shape (2,)). X_query[:, 0] would pass only
    # its first coordinate, a length-1 vector.
    r = tree.kneighbors(X_query[0])
    print(r)

    tree = neighbors.NearestNeighbors(n_neighbors=3, algorithm="brute")
    tree.fit(X_index)
    a1, a2 = tree.kneighbors(X_query)
    # sklearn returns Euclidean distances; square them so they are comparable
    # with the squared distances reported by the quad tree.
    a1 = a1 * a1
    print(a1)
    print(a2)
Could you please post your solution?
Thank you for your work! In my case, I just want to compare the time cost of building the tree and of running queries. That's why I simply modified benchmark.py. I really appreciate the work you did for me; it is very helpful. Thanks!
import time

import numpy as np
import pandas as pd
import pynanoflann
import seaborn as sns
from contexttimer import Timer
from matplotlib import pyplot as plt
from matplotlib import ticker
from sklearn import neighbors
# _QuadTree is private and is not re-exported by ``sklearn.neighbors``;
# import it from its defining module.
from sklearn.neighbors._quad_tree import _QuadTree

# Benchmark index-build and query time of several nearest-neighbour structures.
n_index_points = 20000
n_query_points = 1000
n_repetitions = 5
data_dim = 3
n_neighbors = 100
index_type = np.float32

data = np.random.uniform(0, 100, size=(n_index_points, data_dim)).astype(index_type)
queries = np.random.uniform(0, 100, size=(n_query_points, data_dim)).astype(index_type)

algs = {
    'sklearn_brute': neighbors.NearestNeighbors(n_neighbors=n_neighbors, algorithm='brute'),
    'quad_tree': _QuadTree(data_dim, verbose=0),
    'sklearn_kd_tree': neighbors.NearestNeighbors(n_neighbors=n_neighbors, algorithm='kd_tree'),
    'pynanoflann': pynanoflann.KDTree(n_neighbors=n_neighbors),
}

# _QuadTree offers no k-NN query. ``get_cell`` only locates the leaf holding a
# point that is already in the tree and raises ValueError otherwise, so the
# random ``queries`` cannot be used. Instead, look up the same number of points
# (n_query_points) taken from the index. This times exact-match lookups, not
# k-NN searches, so interpret the quad_tree bar accordingly.
quad_tree_lookups = data[:n_query_points]

results = []
for _ in range(n_repetitions):
    for alg_name, nn in algs.items():
        if alg_name == 'quad_tree':
            with Timer() as index_build_time:
                nn.build_tree(data)
            with Timer() as query_time:
                for point in quad_tree_lookups:
                    nn.get_cell(point)
        else:
            with Timer() as index_build_time:
                nn.fit(data)
            with Timer() as query_time:
                nn.kneighbors(queries)
        results.append((alg_name, index_build_time.elapsed, query_time.elapsed))

df = pd.DataFrame(results, columns=['Algorithm', 'Index build time, second', 'Query time, second'])
print(df)

# Query-time plot (all algorithms), log scale.
fig, ax = plt.subplots(figsize=(18, 6))
sns.barplot(data=df, x='Algorithm', y=df.columns[2], ax=ax, ci=None)
# ``basey`` was deprecated in Matplotlib 3.3 and removed in 3.5; ``base`` replaces it.
ax.set_yscale("log", base=4)
# A formatter labels whatever ticks are finally drawn, instead of freezing
# labels computed from the ticks present before the figure is rendered.
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.4f'))
ax.set_title(f'n_index_points={n_index_points}, n_query_points={n_query_points}, dim={data_dim}')
plt.grid()
plt.savefig('benchmark_query.png')

# Index-build plot; brute force builds no index, so it is excluded.
fig, ax = plt.subplots(figsize=(18, 6))
index_df = df[df.Algorithm != 'sklearn_brute']
# One colour per remaining algorithm, skipping C0 as the original palette did.
palette = [f'C{i}' for i in range(1, index_df.Algorithm.nunique() + 1)]
sns.barplot(data=index_df, x='Algorithm', y=df.columns[1], ax=ax, palette=palette, ci=None)
ax.set_title(f'n_index_points={n_index_points}, dim={data_dim}')
plt.grid()
plt.savefig('benchmark_index.png')
plt.show()
Never mind! I solved it, thanks!