scikit-learn / scikit-learn

scikit-learn: machine learning in Python
https://scikit-learn.org
BSD 3-Clause "New" or "Revised" License

BallTree query match time is O(n) not O(log(n)) #19066

Open simsim314 opened 3 years ago

simsim314 commented 3 years ago

I've run a performance analysis on nearest-neighbor matching with BallTree (same with KDTree), and the query time is linear in the number of elements, when it should be O(log(n)).

Here are the benchmark results:

| num_elements | match_time (sec) |
| --- | --- |
| 10000 | 0.0910 |
| 20000 | 0.1819 |
| 40000 | 0.3669 |
| 80000 | 0.7528 |

Here is my code:


```py
from sklearn.neighbors import BallTree
import numpy as np
import time


def tree_perf(tree_size):
    X = np.random.rand(tree_size, 512)
    Y1 = np.random.rand(1, 512)
    Y2 = np.random.rand(10, 512)

    # time the tree construction
    ts = time.time()
    kdt = BallTree(X, leaf_size=30, metric='euclidean')
    load_tree = time.time() - ts

    num_nn = 1
    # time a single query
    ts = time.time()
    vs = kdt.query(Y1, k=num_nn, return_distance=True)
    match1 = time.time() - ts

    # time a batch of 10 queries
    ts = time.time()
    vs = kdt.query(Y2, k=num_nn, return_distance=True)
    match10 = time.time() - ts
    print(tree_size, load_tree, match1, match10)


print("num_elements", "load_tree", "match_1", "match_10")

for i in range(100):
    tree_size = 10000 + i * 10000
    tree_perf(tree_size)
```
TomDLT commented 3 years ago

From the documentation:

  • Brute force query time grows as O(DN)
  • Ball tree query time grows as approximately O(DlogN)
  • KD tree query time changes with D in a way that is difficult to precisely characterise. For small D (less than 20 or so) the cost is approximately O(DlogN), and the KD tree query can be very efficient. For larger D, the cost increases to nearly O(DN), and the overhead due to the tree structure can lead to queries which are slower than brute force.

Playing a bit with it, it seems that the nearly-O(DN) behaviour described for large D also holds for the ball tree, which seems a bit weird to me.

[Figure_1: D = 5]
[Figure_2: D = 50]

For comparison, annoy leads to the expected O(log N): [Figure_3]

I don't know if this is expected behavior (then we should update the doc), or if it is a bug (then we should fix it). Here is the script I used for reference:

```py
import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsTransformer
import annoy
from sklearn.base import BaseEstimator, TransformerMixin
from scipy.sparse import csr_matrix


class AnnoyTransformer(TransformerMixin, BaseEstimator):
    """Wrapper for using annoy.AnnoyIndex as sklearn's KNeighborsTransformer"""

    def __init__(self, n_neighbors=5, metric='euclidean', n_trees=10,
                 search_k=-1):
        self.n_neighbors = n_neighbors
        self.n_trees = n_trees
        self.search_k = search_k
        self.metric = metric

    def fit(self, X):
        self.n_samples_fit_ = X.shape[0]
        metric = self.metric if self.metric != 'sqeuclidean' else 'euclidean'
        self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric)
        for i, x in enumerate(X):
            self.annoy_.add_item(i, x.tolist())
        self.annoy_.build(self.n_trees)
        return self

    def transform(self, X):
        return self._transform(X)

    def fit_transform(self, X, y=None):
        return self.fit(X)._transform(X=None)

    def _transform(self, X):
        """As `transform`, but handles X is None for faster `fit_transform`."""
        n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0]

        # For compatibility reasons, as each sample is considered as its own
        # neighbor, one extra neighbor will be computed.
        n_neighbors = self.n_neighbors + 1

        indices = np.empty((n_samples_transform, n_neighbors), dtype=int)
        distances = np.empty((n_samples_transform, n_neighbors))

        if X is None:
            for i in range(self.annoy_.get_n_items()):
                ind, dist = self.annoy_.get_nns_by_item(
                    i, n_neighbors, self.search_k, include_distances=True)
                indices[i], distances[i] = ind, dist
        else:
            for i, x in enumerate(X):
                indices[i], distances[i] = self.annoy_.get_nns_by_vector(
                    x.tolist(), n_neighbors, self.search_k,
                    include_distances=True)

        if self.metric == 'sqeuclidean':
            distances **= 2

        indptr = np.arange(0, n_samples_transform * n_neighbors + 1,
                           n_neighbors)
        kneighbors_graph = csr_matrix(
            (distances.ravel(), indices.ravel(), indptr),
            shape=(n_samples_transform, self.n_samples_fit_))
        return kneighbors_graph


n_neighbors = 5
n_samples_query = 1
n_features = 5

for model in [
        KNeighborsTransformer(n_neighbors=n_neighbors, algorithm='ball_tree'),
        KNeighborsTransformer(n_neighbors=n_neighbors, algorithm='kd_tree'),
        KNeighborsTransformer(n_neighbors=n_neighbors, algorithm='brute'),
        AnnoyTransformer(n_neighbors=n_neighbors),
]:
    def bench(n_samples_fit):
        X = np.random.rand(n_samples_fit, n_features)
        Y = np.random.rand(n_samples_query, n_features)

        start = time.time()
        kdt = model.fit(X)
        fit_time = time.time() - start

        start = time.time()
        _ = kdt.transform(Y)
        query_time = time.time() - start

        return n_samples_fit, fit_time, query_time

    results = []
    for n_samples_fit in np.int_(np.logspace(2, 5, 10)):
        results.append(bench(n_samples_fit))
    results = np.array(results)

    plt.loglog(results[:, 0], results[:, 1], 'o-', label='fit')
    plt.loglog(results[:, 0], results[:, 2], 'o-', label='query')
    plt.loglog(results[:, 0], results[:, 0] / 1e5, '--', label='linear')
    plt.legend()
    plt.ylabel('Time (sec)')
    plt.xlabel('n_samples_fit')
    plt.title(str(model))
    plt.show()
```
simsim314 commented 3 years ago

Many thanks! Yes, I would suggest adding this fact to the documentation. Also, the annoy alternative is not mentioned as a viable option (and in many cases it would be preferred). As an improvement suggestion: add the annoy algorithm to scikit-learn. To avoid "magic parameters" like num_trees, one could pass only a "probability of mistake" parameter and compute num_trees from it. I guess that for practical usage annoy is many times more useful.

TomDLT commented 3 years ago

There is no plan to include approximate nearest neighbors in scikit-learn, I only added the wrapper around annoy as a comparison. I agree we could add a reference to annoy (and nmslib, faiss, ...) in the documentation.

The O(n) query time in our BallTree and KDTree is still a bit weird though.

jnothman commented 3 years ago

The O(n) query time in our BallTree and KDTree is still a bit weird though.

I agree. Let's label this as a bug until someone investigates. I would have expected these were benchmarked when implemented, but this all deserves a little looking into.

jnothman commented 3 years ago

is this a duplicate of #7687, or a different case?

simsim314 commented 3 years ago

No. It's similar but different. That issue is about tree load time, while my report is about query time. His issue also only reproduces on a single specific data case; mine reproduces on random data, i.e. most cases.

thomasjpfan commented 3 years ago

During a query, every leaf will compute an rdist:

https://github.com/scikit-learn/scikit-learn/blob/b94332434d0117e3d86407560a206d1c7bee1c81/sklearn/neighbors/_binary_tree.pxi#L1824-L1829

This rdist scales with a factor of n_features * n_samples_in_node.

Edit: Added the number of samples in the node
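
Schematically, the per-leaf work can be sketched as follows (a simplified pure-Python illustration of the linked Cython for the k = 1 case; `query_leaf` and its argument names are made up for illustration, not the real internals):

```py
import numpy as np

def query_leaf(node_points, query_point, best_rdist, best_idx):
    # Illustrative only: for every leaf visited, the squared distance
    # (rdist) to *each* point stored in that leaf is computed, so the
    # per-leaf cost is O(n_samples_in_node * n_features).
    for i, p in enumerate(node_points):
        rdist = np.sum((p - query_point) ** 2)  # O(n_features) per point
        if rdist < best_rdist:                  # keep the current best (k = 1)
            best_rdist, best_idx = rdist, i
    return best_rdist, best_idx
```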

simsim314 commented 3 years ago

Thomas, that's not relevant. We measured with a constant number of features d, and we don't see log-scaling in the number of points N, even when N is several orders of magnitude larger than d. We used 5, 50, and 500 features with tens to hundreds of thousands of points; according to the documentation the query should be logarithmic in N and linear in d, yet we see performance that is linear in N. Check Tom's graphs, they are clearer than my verbal report, but my report also shows query time linear in N with 512 features.

thomasjpfan commented 3 years ago

When I remove the self.rdist calculation, I am able to get the desired behavior (this is not the right solution, it is only used for debugging):

With n_features = 50: [Figure_1]

Without the computation I referenced above, the query is O(n_samples * log(n_features)). The additional computation loops through the samples in the node and computes the distance between each sample and the query point. (There is also a cost for placing elements on the heap.)

Edit: I think I am off with my complexity analysis.

simsim314 commented 3 years ago

I don't see what you mean by self.rdist. Can you please post your code?

As far as I understand, tree query algorithms usually run in O(log(num_samples)) as a function of num_samples; this is why people use trees. The dimensionality adds a linear factor, but as long as the dimension is not high nobody cares about it. We still never saw logarithmic behaviour, and your graph, which is also a function of num_samples, looks linear. If you increase num_features to 500, or add another order of magnitude of num_samples, you will see the linearity more clearly. 1 ms suggests you have good hardware; try to see what happens when a query takes on the order of seconds.


gkaranikas commented 3 years ago

I don't think the linear cost of the BallTree query is a bug.

I think we shouldn't expect logarithmic behaviour when D = 512, because search algorithms that use distance are affected by the curse of dimensionality. In high dimensions, Euclidean (or L^p) distance is not a good metric for discriminating between data, which can be shown with concentration inequalities. Empirically, this paper [1] examined the performance of nearest neighbor search algorithms (including BallTree-like algorithms, cf. Sec. 3.3.2) in high dimensions. This comment from page 5 seems to fit what @simsim314 and @TomDLT observed:

Conventional data- and space-partitioning structures are out-performed by a sequential scan already at dimensionality of around 10 or higher.

In light of this, I think the correct course is modifying the documentation to note that the O(D log N) behaviour of the BallTree query does not hold for large D.

Sources: [1] Weber, Schek & Blott (1998). A quantitative analysis and performance study for similarity-search methods in high-dimensional spaces. VLDB. https://www.cs.utexas.edu/~grauman/courses/spring2007/395T/papers/weberetal1998.pdf

simsim314 commented 3 years ago

@gkaranikas What's the point of BallTree then? It is stated in the documentation that the KD tree has problems with the curse of dimensionality, and that this is why you need the ball tree. From the documentation:

"To address the inefficiencies of KD Trees in higher dimensions, the ball tree data structure was developed. Where KD trees partition data along Cartesian axes, ball trees partition data in a series of nesting hyper-spheres. This makes tree construction more costly than that of the KD tree, but results in a data structure which can be very efficient on highly structured data, even in very high dimensions."

And we start to observe linearity in N for D as small as 50. I'm almost certain it's just a bug; somewhere N operations are being performed. Even a single copy of the input somewhere would cause this.

gkaranikas commented 3 years ago

What's the point of balltree then?

@simsim314 I think it's supposed to be better than KD trees when the dimension is "high" but not very high. For very high dimensions, I don't know if an exact nearest neighbors algorithm can escape the curse of dimensionality and maintain logarithmic performance. I am aware that the documentation you quoted says "very high dimensions", but that might be misleading.

On another note, that quote does say "on highly structured data" and this thread has so far been limited to random uniformly distributed data. The next paragraph in the documentation reiterates:

it can out-perform a KD-tree in high dimensions, though the actual performance is highly dependent on the structure of the training data

I'm still not convinced it's a bug. When you think about how the ball tree works, it makes sense. For example, suppose we have a query point q and a left and a right node, N1 and N2. The nodes determine intervals (l1, u1) and (l2, u2) such that the distance |q-x| must lie in the corresponding interval for all x in that node. When the dimension is very high, there is very likely some overlap between (l1, u1) and (l2, u2), due to the concentration-of-distances phenomenon, which means both N1 and N2 have to be searched.
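
To make the pruning argument concrete, here is a minimal sketch (a hypothetical helper, not scikit-learn's actual code) of the bound test a ball tree relies on: a child node with centroid c and radius r can only be skipped if |q - c| - r already exceeds the best distance found so far.

```py
import numpy as np

def can_prune(query, centroid, radius, best_dist):
    # Lower bound, via the triangle inequality, on the distance from `query`
    # to any point inside the ball (centroid, radius). If even this bound
    # exceeds the current best distance, the whole node can be skipped.
    lower_bound = max(0.0, np.linalg.norm(query - centroid) - radius)
    return lower_bound > best_dist
```

In high dimensions the distances |q - c| to both children concentrate around similar values while the radii stay comparatively large, so this test rarely succeeds and both subtrees end up being searched.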

simsim314 commented 3 years ago

@gkaranikas if "very high", the regime where you can't escape the curse of dimensionality, already means 50, then I think that should indeed be mentioned.

If you are certain there is no bug in the code, and nowhere an O(N) copy or anything of the sort, then yes, only the documentation should be modified.

Regarding "on highly structured data": this means that even on structured data, where the usual KD tree can fail for some reason, the ball tree will not. Random data should be the simplest case, not the most complex one.

As far as I understand, the ball tree divides the space in a way that cuts the candidate set roughly in half at each step. This means the probability that |q-x| falls in the ambiguous interval is about 0.5, which is not so hard to achieve in higher dimensions as well, as long as the query points belong to a small enough ball (the triangle inequality works in higher dimensions too). I don't see how more dimensions change this. The "gray area" should behave as in the usual KD-tree case: the probability that two random points are at distance exactly d +/- eps should be pretty low, and every other case should fall either inside the ball or outside of it; you can then apply the same logic again to both point sets, those inside the ball and those outside. The only case where this fails, and you have to descend into both nodes, is when the distance between the center of the ball and q lands in the gray area around the radius, i.e. q is very close to the edge of the ball. That should happen pretty rarely, at least for random points from the same distribution, which is why we query many points and average the performance.

simsim314 commented 3 years ago

I was thinking about a way to check whether this is caused by the curse of dimensionality or by a bug that copies N elements. So I took D = 5, which I assume is low enough. I matched 100K query points, and the matching time is still nowhere near linear in log(N). It does look sub-linear in N, though (something like O(N / log(N))).

[Plot: Match 100K (sec) vs Log2 Num Points]


```py
from sklearn.neighbors import BallTree
import numpy as np
import time


def tree_perf(tree_size):
    X = np.random.rand(tree_size, 5)
    Y1 = np.random.rand(1, 5)
    Y2 = np.random.rand(100000, 5)

    # time the tree construction
    ts = time.time()
    kdt = BallTree(X, leaf_size=30, metric='euclidean')
    load_tree = time.time() - ts

    num_nn = 1
    # single-query warm-up
    ts = time.time()
    vs = kdt.query(Y1, k=num_nn, return_distance=True)
    match1 = time.time() - ts

    # time matching 100K query points
    ts = time.time()
    vs = kdt.query(Y2, k=num_nn, return_distance=True)
    match100k = time.time() - ts
    print(tree_size, match100k)


print("num_elements", "match_100k")

tree_size = 128
for i in range(100):
    tree_size *= 2
    tree_perf(tree_size)
```
| Log2 Num Points | Match 100K (sec) |
| --- | --- |
| 8 | 0.4 |
| 9 | 0.6 |
| 10 | 0.7 |
| 11 | 1.0 |
| 12 | 1.3 |
| 13 | 2.1 |
| 14 | 2.7 |
| 15 | 3.0 |
| 16 | 4.6 |
| 17 | 8.2 |
| 18 | 12.8 |
| 19 | 16.3 |
| 20 | 17.0 |
| 21 | 22.8 |
| 22 | 30.7 |
| 23 | 43.7 |
| 24 | 62.9 |
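
A quick way to estimate the scaling exponent from this table is to fit a power law t ≈ c·N^a to the timings (a small sketch using the numbers as reported above):

```py
import numpy as np

# (log2 N, match time in seconds) copied from the table above
data = np.array([
    (8, 0.4), (9, 0.6), (10, 0.7), (11, 1.0), (12, 1.3), (13, 2.1),
    (14, 2.7), (15, 3.0), (16, 4.6), (17, 8.2), (18, 12.8), (19, 16.3),
    (20, 17.0), (21, 22.8), (22, 30.7), (23, 43.7), (24, 62.9),
])
log_n = data[:, 0] * np.log(2)          # natural log of N
log_t = np.log(data[:, 1])
slope, _ = np.polyfit(log_n, log_t, 1)  # exponent a in t ~ N**a
print(f"estimated exponent: {slope:.2f}")
# ~0.46 for these numbers: sub-linear in N, but far from logarithmic
```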
gkaranikas commented 3 years ago

@simsim314 you said:

The only case it's not working and you have to go on both nodes is when the distance between the center of the ball and q is in the gray area of radius i.e. q very close to the edge of the ball, and this should happen pretty rarely at least to random points from the same distribution, this is why we query many points and average the performance.

I'm afraid this is not true. In fact, it's the opposite. If your data is uniformly distributed then most points are close to the boundary of the ball in high dimensions, i.e. it is extremely likely that q is in the "gray area of radius".

Here's the reasoning. Consider a D dimensional ball of radius 1. The proportion of its volume within a distance of say 0.01 from its surface is 1 - 0.99^D. So that gives the following proportions:

  • D = 5 → 5%
  • D = 50 → 40%
  • D = 500 → 99%

This is a manifestation of the curse of dimensionality.
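
For reference, the proportions above can be reproduced in a couple of lines:

```py
# Fraction of a unit D-ball's volume lying within 0.01 of its surface
for D in (5, 50, 500):
    print(D, round(1 - 0.99 ** D, 3))   # 0.049, 0.395, 0.993
```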

I agree that running tests to determine if there is a copying bug is a good idea. However, there is no evidence of a bug and we know the curse of dimensionality can cause this. So I would apply Occam's razor.

simsim314 commented 3 years ago

@gkaranikas I've shown the same issue for D = 5, i.e. there is now good evidence for a bug. The performance is not linear in log(N) for D = 5.

ogrisel commented 2 years ago

The performance is not linear to Log(N) for D=5.

Coming back to this, the first plot of https://github.com/scikit-learn/scikit-learn/issues/19066#issuecomment-751521686 looks good, no?

Feel free to repeat similar benchmarks with different data distributions and compare to the KD tree, to see whether the ball-tree query-time performance is problematic or not.
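
For anyone picking this up, a minimal comparison along those lines could look like the sketch below (the sizes and the uniform distribution are arbitrary choices meant to be varied):

```py
import time

import numpy as np
from sklearn.neighbors import BallTree, KDTree

rng = np.random.RandomState(0)
n_features = 5

for n_samples in (10_000, 100_000, 1_000_000):
    # Swap rng.rand for a clustered / structured distribution to test
    # whether the data structure changes the picture.
    X = rng.rand(n_samples, n_features)
    Y = rng.rand(1_000, n_features)
    for Tree in (BallTree, KDTree):
        tree = Tree(X, leaf_size=30)
        tic = time.time()
        tree.query(Y, k=1)
        print(Tree.__name__, n_samples, f"{time.time() - tic:.3f}s")
```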