scikit-learn-contrib / scikit-learn-extra

scikit-learn contrib estimators
https://scikit-learn-extra.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Add neighbors algorithm based on NSW graphs #143

Open LeoSvalov opened 2 years ago

LeoSvalov commented 2 years ago

Good afternoon!

I would like to add an algorithm for approximate nearest neighbor search.

The method is based on navigable small world graphs (NSW graphs), which tend to perform better on high-dimensional data [1] than the existing scikit-learn KDTree and BallTree methods, starting from data dimension D > 50.

The API of the algorithm is very similar to the existing alternatives. In addition, NSWGraph can be used in the manner of a k-nearest-neighbors classifier, since it implements the base estimator paradigm (fit/predict).

Possible ways to use the method:

  1. As an object to query k-nearest neighbors:

    import numpy as np
    from sklearn_extra.neighbors import NSWGraph

    rng = np.random.RandomState(10)
    X = rng.random_sample((50, 128))
    nswgraph = NSWGraph()
    nswgraph.build(X)
    X_val = rng.random_sample((5, 128))
    dists, inds = nswgraph.query(X_val, k=3)

  2. As a neighbors estimator that takes the target classes of the data into account:

    from sklearn.datasets import load_iris
    from sklearn_extra.neighbors import NSWGraph

    X, y = load_iris(return_X_y=True)
    estimator = NSWGraph()
    estimator.fit(X, y)
    y_pred = estimator.predict(X)
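
For readers unfamiliar with the method, here is a minimal pure-Python sketch of the general idea in [1]: points are inserted one at a time and linked to their approximate nearest neighbors, and queries run a best-first graph search from a few random entry points. The class name `TinyNSW` and the parameters `n_links` and `n_restarts` are hypothetical and chosen for illustration only; this is not the implementation proposed in this PR.

```python
import heapq
import math
import random


class TinyNSW:
    """Illustrative NSW-style index (hypothetical; not the PR's NSWGraph)."""

    def __init__(self, n_links=5, n_restarts=3, seed=0):
        self.n_links = n_links        # links created per inserted point
        self.n_restarts = n_restarts  # random entry points tried per search
        self.rng = random.Random(seed)
        self.points = []
        self.links = []               # links[i] = set of neighbor indices of point i

    def _search(self, q, k):
        """Best-first search; returns up to k (distance, index) pairs, closest first."""
        if not self.points:
            return []
        visited = set()
        found = []  # every (distance, index) pair evaluated so far
        for _ in range(self.n_restarts):
            entry = self.rng.randrange(len(self.points))
            if entry in visited:
                continue
            visited.add(entry)
            d = math.dist(q, self.points[entry])
            candidates = [(d, entry)]
            found.append((d, entry))
            while candidates:
                d, c = heapq.heappop(candidates)
                best = heapq.nsmallest(k, found)
                # Stop once the closest open candidate is worse than the k-th result.
                if len(best) >= k and d > best[-1][0]:
                    break
                for nb in self.links[c]:
                    if nb not in visited:
                        visited.add(nb)
                        dn = math.dist(q, self.points[nb])
                        heapq.heappush(candidates, (dn, nb))
                        found.append((dn, nb))
        return heapq.nsmallest(k, found)

    def add(self, p):
        """Insert p and link it bidirectionally to its approximate nearest neighbors."""
        neighbors = self._search(p, self.n_links)
        idx = len(self.points)
        self.points.append(p)
        self.links.append(set())
        for _, nb in neighbors:
            self.links[idx].add(nb)
            self.links[nb].add(idx)

    def query(self, q, k=1):
        return self._search(q, k)


rng = random.Random(10)
data = [[rng.random() for _ in range(16)] for _ in range(200)]
index = TinyNSW()
for p in data:
    index.add(p)
result = index.query(data[0], k=3)  # list of (distance, index) pairs
```

Because the search is approximate, the returned neighbors are not guaranteed to be the exact nearest ones. Raising `n_restarts` or `n_links` trades speed for accuracy.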

References

[1] Malkov, Y., Ponomarenko, A., Logvinov, A., & Krylov, V. (2014). Approximate nearest neighbor algorithm based on navigable small world graphs. Information Systems, 45, 61-68.

ogrisel commented 2 years ago

Thanks for the contribution! While I am not yet sure whether maintainers would reach a consensus to accept it into the scikit-learn code base, running some benchmarks would surely help.

If your PR can be shown to be roughly competitive in speed with alternative implementations, it would surely help convince maintainers that it is worth investing their time to review the PR and accept the long-term maintenance burden that comes with a new method.

Ideally the benchmarks could be based on this existing infrastructure:

In particular, I would be interested in a comparison with nswlib's implementation and with an alternative method not based on NSW graphs, such as https://github.com/lmcinnes/pynndescent.
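
As a rough illustration of what such a comparison could measure, the sketch below computes recall@k against an exact brute-force search, together with the average time per query. It is not the benchmark infrastructure referred to above. The helper names (`brute_force_knn`, `recall_at_k`, `evaluate`) and the stand-in approximate method `half_scan` are assumptions made for this example, and a real benchmark would plug in the actual libraries.

```python
import math
import random
import time


def brute_force_knn(data, q, k):
    """Exact k nearest neighbors by scanning every point; returns indices."""
    order = sorted(range(len(data)), key=lambda i: math.dist(q, data[i]))
    return order[:k]


def recall_at_k(approx, exact):
    """Fraction of the true neighbors that the approximate method returned."""
    hits = sum(len(set(a) & set(e)) for a, e in zip(approx, exact))
    return hits / sum(len(e) for e in exact)


def evaluate(query_fn, queries, exact):
    """Return (recall@k, average seconds per query) for a query function."""
    start = time.perf_counter()
    approx = [query_fn(q) for q in queries]
    elapsed = time.perf_counter() - start
    return recall_at_k(approx, exact), elapsed / len(queries)


rng = random.Random(0)
data = [[rng.random() for _ in range(64)] for _ in range(500)]
queries = [[rng.random() for _ in range(64)] for _ in range(20)]
k = 5
exact = [brute_force_knn(data, q, k) for q in queries]

# Stand-in "approximate" method: exact search over a random half of the data.
subset = rng.sample(range(len(data)), len(data) // 2)


def half_scan(q):
    order = sorted(subset, key=lambda i: math.dist(q, data[i]))
    return order[:k]


recall, sec_per_query = evaluate(half_scan, queries, exact)
exact_recall, _ = evaluate(lambda q: brute_force_knn(data, q, k), queries, exact)
print(f"half_scan: recall@{k}={recall:.2f}, {sec_per_query * 1e3:.2f} ms/query")
```

Reporting recall alongside query time matters for approximate methods, since a faster index that misses most true neighbors is not a fair competitor.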

ogrisel commented 2 years ago

I just realised that this is not the scikit-learn/scikit-learn repo but the scikit-learn-extra repo: I arrived at this PR from issue scikit-learn/scikit-learn#23450 in the main scikit-learn issue tracker.

I think it would be great to have an implementation of NSW nearest neighbors in scikit-learn-extra. But before reviewing this PR, I would like to see some performance benchmark results as requested above.