rapidsai / cuml

cuML - RAPIDS Machine Learning Library
https://docs.rapids.ai/api/cuml/stable/
Apache License 2.0

[BUG] KNeighbors throwing error for n_neighbors >= total samples/2 #4708

Closed pj-mathematician closed 2 years ago

pj-mathematician commented 2 years ago

Describe the bug

When calling kneighbors on the estimators provided in cuml.neighbors, it works for n_neighbors less than half the total number of samples and throws an error otherwise.

Steps/Code to reproduce bug

I have tested this with all the estimators currently in cuml.neighbors, and each of them throws the same error.

from cuml.neighbors import NearestNeighbors, KNeighborsClassifier, KNeighborsRegressor
import cudf
d = {
    'id':[1,2,3,4,5,6],
    'latitude':[50,-22,13,37,43,14],
    'longitude':[3,-43,100,27,-4,121],
}
df = cudf.DataFrame(d) # total samples = 6
coo_cols = ["latitude", "longitude"]

N = 3

matcher_nn = NearestNeighbors(n_neighbors=N)
matcher_nn.fit(df[coo_cols])
distances_nn, indices_nn = matcher_nn.kneighbors(df[coo_cols])

matcher_knc = KNeighborsClassifier(n_neighbors=N)
matcher_knc.fit(df[coo_cols], df.index)
distances_knc, indices_knc = matcher_knc.kneighbors(df[coo_cols])  # query the classifier itself

matcher_knr = KNeighborsRegressor(n_neighbors=N)
matcher_knr.fit(df[coo_cols], df.index)
distances_knr, indices_knr = matcher_knr.kneighbors(df[coo_cols])  # query the regressor itself

distances_nn==distances_knc

Each of them throws the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_52/2768622050.py in <module>
     13 matcher_nn = NearestNeighbors(n_neighbors=N)
     14 matcher_nn.fit(df[coo_cols])
---> 15 distances_nn, indices_nn = matcher_nn.kneighbors(df[coo_cols])
     16 
     17 matcher_knc = KNeighborsClassifier(n_neighbors=N)

/opt/conda/lib/python3.7/site-packages/cuml/internals/api_decorators.py in inner_get(*args, **kwargs)
    584 
    585                 # Call the function
--> 586                 ret_val = func(*args, **kwargs)
    587 
    588             return cm.process_return(ret_val)

cuml/neighbors/nearest_neighbors.pyx in cuml.neighbors.nearest_neighbors.NearestNeighbors.kneighbors()

cuml/neighbors/nearest_neighbors.pyx in cuml.neighbors.nearest_neighbors.NearestNeighbors._kneighbors()

cuml/neighbors/nearest_neighbors.pyx in cuml.neighbors.nearest_neighbors.NearestNeighbors._kneighbors_dense()

RuntimeError: exception occured! file=_deps/raft-src/cpp/include/raft/spatial/knn/detail/ball_cover.cuh line=326: number of landmark samples must be >= k
Obtained 64 stack frames
#0 in /opt/conda/lib/python3.7/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft9exception18collect_call_stackEv+0x38) [0x7f688ca4d7b8]
#1 in /opt/conda/lib/python3.7/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft7spatial3knn6detail13rbc_knn_queryIlfjNS2_13EuclideanFuncEEEvRKNS_8handle_tERNS1_14BallCoverIndexIT_T0_T1_EESB_PKSA_SB_PS9_PSA_T2_bf+0x85a) [0x7f688cd0e3ea]
#2 in /opt/conda/lib/python3.7/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft7spatial3knn13rbc_knn_queryIlfjEEvRKNS_8handle_tERNS1_14BallCoverIndexIT_T0_T1_EES9_PKS8_S9_PS7_PS8_bf+0x4a) [0x7f688cd0e5fa]
#3 in /opt/conda/lib/python3.7/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML13rbc_knn_queryERKN4raft8handle_tERNS0_7spatial3knn14BallCoverIndexIlfjEEjPKfjPlPf+0x17) [0x7f688ccdbc37]
#4 in /opt/conda/lib/python3.7/site-packages/cuml/neighbors/nearest_neighbors.cpython-37m-x86_64-linux-gnu.so(+0x465b0) [0x7f68782fc5b0]
#5 in /opt/conda/lib/python3.7/site-packages/cuml/neighbors/nearest_neighbors.cpython-37m-x86_64-linux-gnu.so(+0x2c9c6) [0x7f68782e29c6]
#6 in /opt/conda/lib/python3.7/site-packages/cuml/neighbors/nearest_neighbors.cpython-37m-x86_64-linux-gnu.so(+0x366d7) [0x7f68782ec6d7]
#7 in /opt/conda/bin/python(+0x11f11c) [0x5608d365811c]
#8 in /opt/conda/lib/python3.7/site-packages/cuml/neighbors/nearest_neighbors.cpython-37m-x86_64-linux-gnu.so(+0x2f9e2) [0x7f68782e59e2]
#9 in /opt/conda/bin/python(PyObject_Call+0x61) [0x5608d363c101]
#10 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x2003) [0x5608d36e7db3]
#11 in /opt/conda/bin/python(_PyEval_EvalCodeWithName+0x242) [0x5608d362aea2]
#12 in /opt/conda/bin/python(_PyFunction_FastCallKeywords+0x320) [0x5608d3671370]
#13 in /opt/conda/bin/python(+0x13b0d8) [0x5608d36740d8]
#14 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x4f0a) [0x5608d36eacba]
#15 in /opt/conda/bin/python(_PyEval_EvalCodeWithName+0x242) [0x5608d362aea2]
#16 in /opt/conda/bin/python(PyEval_EvalCodeEx+0x39) [0x5608d362c0b9]
#17 in /opt/conda/bin/python(PyEval_EvalCode+0x1b) [0x5608d370b15b]
#18 in /opt/conda/bin/python(+0x247fa2) [0x5608d3780fa2]
#19 in /opt/conda/bin/python(_PyMethodDef_RawFastCallKeywords+0x68) [0x5608d3672f58]
#20 in /opt/conda/bin/python(+0x13b1e8) [0x5608d36741e8]
#21 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0xaab) [0x5608d36e685b]
#22 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#23 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x1c60) [0x5608d36e7a10]
#24 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#25 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x1c60) [0x5608d36e7a10]
#26 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#27 in /opt/conda/bin/python(_PyMethodDescr_FastCallKeywords+0x374) [0x5608d3673dd4]
#28 in /opt/conda/bin/python(+0x13b1ad) [0x5608d36741ad]
#29 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0xb41) [0x5608d36e68f1]
#30 in /opt/conda/bin/python(_PyFunction_FastCallKeywords+0x184) [0x5608d36711d4]
#31 in /opt/conda/bin/python(+0x13b0d8) [0x5608d36740d8]
#32 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0xaab) [0x5608d36e685b]
#33 in /opt/conda/bin/python(_PyFunction_FastCallKeywords+0x184) [0x5608d36711d4]
#34 in /opt/conda/bin/python(+0x13b0d8) [0x5608d36740d8]
#35 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0xb41) [0x5608d36e68f1]
#36 in /opt/conda/bin/python(_PyEval_EvalCodeWithName+0x242) [0x5608d362aea2]
#37 in /opt/conda/bin/python(_PyFunction_FastCallDict+0x35f) [0x5608d362c41f]
#38 in /opt/conda/bin/python(+0x11f093) [0x5608d3658093]
#39 in /opt/conda/bin/python(PyObject_Call+0x61) [0x5608d363c101]
#40 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x2003) [0x5608d36e7db3]
#41 in /opt/conda/bin/python(_PyEval_EvalCodeWithName+0x242) [0x5608d362aea2]
#42 in /opt/conda/bin/python(_PyFunction_FastCallKeywords+0x320) [0x5608d3671370]
#43 in /opt/conda/bin/python(+0x13b0d8) [0x5608d36740d8]
#44 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x172a) [0x5608d36e74da]
#45 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#46 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x1c60) [0x5608d36e7a10]
#47 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#48 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x1c60) [0x5608d36e7a10]
#49 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#50 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x1c60) [0x5608d36e7a10]
#51 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#52 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x1c60) [0x5608d36e7a10]
#53 in /opt/conda/bin/python(+0x16e553) [0x5608d36a7553]
#54 in /opt/conda/lib/python3.7/lib-dynload/_asyncio.cpython-37m-x86_64-linux-gnu.so(+0x9c88) [0x7f699d3eac88]
#55 in /opt/conda/bin/python(_PyObject_FastCallKeywords+0x47b) [0x5608d367372b]
#56 in /opt/conda/bin/python(+0x1f8d9f) [0x5608d3731d9f]
#57 in /opt/conda/bin/python(PyCFunction_Call+0x17c) [0x5608d363bc1c]
#58 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0x6381) [0x5608d36ec131]
#59 in /opt/conda/bin/python(_PyFunction_FastCallKeywords+0x184) [0x5608d36711d4]
#60 in /opt/conda/bin/python(+0x13b0d8) [0x5608d36740d8]
#61 in /opt/conda/bin/python(_PyEval_EvalFrameDefault+0xb41) [0x5608d36e68f1]
#62 in /opt/conda/bin/python(_PyFunction_FastCallKeywords+0x184) [0x5608d36711d4]
#63 in /opt/conda/bin/python(+0x13b0d8) [0x5608d36740d8]

Setting N < 3 works as expected.
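
The error message points at the actual threshold. As a rough sketch, assuming the random ball cover (rbc) index builds floor(sqrt(n)) landmarks and requires at least k of them (the check named in the error, and consistent with the k >= sqrt(n) remark later in this thread), the cutoff for the 6-sample example can be computed directly. The helper names are hypothetical, not cuML functions.

```python
import math


def rbc_landmarks(n_samples: int) -> int:
    # Assumption: the ball cover index uses floor(sqrt(n_samples)) landmarks.
    return math.isqrt(n_samples)


def rbc_supports_k(n_samples: int, k: int) -> bool:
    # Mirrors the check "number of landmark samples must be >= k".
    return rbc_landmarks(n_samples) >= k


# The issue's data set has 6 samples, i.e. 2 landmarks under this assumption.
for k in range(1, 5):
    print(f"n_neighbors={k}: rbc ok = {rbc_supports_k(6, k)}")
```

Under this assumption the limit grows with sqrt(n) rather than n/2: with 100 samples there would be only 10 landmarks, so n_neighbors=11 would already fail. For n=6 the two thresholds happen to coincide, which would explain the n/2 pattern in the title.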

Expected behavior

Running the same code with the sklearn.neighbors estimators returns the expected results.

from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier, KNeighborsRegressor
import pandas as pd
d = {
    'id':[1,2,3,4,5,6],
    'latitude':[50,-22,13,37,43,14],
    'longitude':[3,-43,100,27,-4,121],
}
df = pd.DataFrame(d) # total samples = 6
coo_cols = ["latitude", "longitude"]

N = 3

matcher_nn = NearestNeighbors(n_neighbors=N)
matcher_nn.fit(df[coo_cols])
distances_nn, indices_nn = matcher_nn.kneighbors(df[coo_cols])

matcher_knc = KNeighborsClassifier(n_neighbors=N)
matcher_knc.fit(df[coo_cols], df.index)
distances_knc, indices_knc = matcher_knc.kneighbors(df[coo_cols])  # query the classifier itself

matcher_knr = KNeighborsRegressor(n_neighbors=N)
matcher_knr.fit(df[coo_cols], df.index)
distances_knr, indices_knr = matcher_knr.kneighbors(df[coo_cols])  # query the regressor itself

distances_nn==distances_knr

returns:

array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]])


cjnolet commented 2 years ago

@pj-mathematician This is indeed a bug in the algorithm selection. The rbc algorithm should not be used when k >= sqrt(n); in that case we should be defaulting to brute force, but it's not currently doing that. In the meantime, you should be able to work around this bug by passing algorithm='brute' when constructing the NearestNeighbors estimator.
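
In cuML the workaround is the one-argument change NearestNeighbors(n_neighbors=N, algorithm='brute'). Brute force sidesteps the landmark check entirely because it compares every query point with every indexed point, so any n_neighbors up to the number of samples is valid. The NumPy sketch below illustrates that idea on the issue's coordinates; it is a CPU stand-in for exposition only, not cuML's GPU implementation, and brute_force_knn is a hypothetical name.

```python
import numpy as np


def brute_force_knn(X, k):
    """Exhaustive k-nearest-neighbor search (illustrative, not cuML's code)."""
    X = np.asarray(X, dtype=np.float64)
    n = X.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must be in [1, {n}], got {k}")
    # Pairwise Euclidean distances between every query and every sample.
    diff = X[:, None, :] - X[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Sort each row; a stable sort keeps the lower index first on ties.
    idx = np.argsort(dist, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(dist, idx, axis=1), idx


# (latitude, longitude) pairs from the issue's example.
coords = [[50, 3], [-22, -43], [13, 100], [37, 27], [43, -4], [14, 121]]
distances, indices = brute_force_knn(coords, k=3)
print(distances)
print(indices)
```

Because each point is its own nearest neighbor at distance 0, the first column of indices is simply 0 to 5, matching what kneighbors returns when the query set equals the fitted set.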

github-actions[bot] commented 2 years ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

beckernick commented 2 years ago

Handling this edge case in algorithm selection is probably a good first issue.

cjnolet commented 2 years ago

@beckernick I could be wrong, but I believe this issue may actually have been fixed in a more recent PR (it now selects brute force in this case, or throws an error if rbc was selected explicitly).
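
The fixed behavior described here (fall back to brute force under the default selection, raise an error when rbc is requested explicitly) can be sketched as a small guard. choose_algorithm is a hypothetical helper, not cuML's API, and it models only the k-versus-landmarks condition under the assumption of floor(sqrt(n)) landmarks; cuML's actual 'auto' rules may also depend on other factors such as the number of features and the metric.

```python
import math


def choose_algorithm(n_samples: int, n_neighbors: int, algorithm: str = "auto") -> str:
    """Hypothetical sketch of the selection rule described in this thread."""
    # Assumption: rbc builds floor(sqrt(n_samples)) landmarks and needs >= k.
    landmarks = math.isqrt(n_samples)
    rbc_ok = n_neighbors <= landmarks
    if algorithm == "auto":
        # Fall back to brute force instead of failing inside rbc.
        return "rbc" if rbc_ok else "brute"
    if algorithm == "rbc" and not rbc_ok:
        # An explicit rbc request is rejected up front with a clear error.
        raise ValueError(
            f"algorithm='rbc' needs n_neighbors <= {landmarks} "
            f"for {n_samples} samples, got {n_neighbors}"
        )
    return algorithm


print(choose_algorithm(6, 3))           # falls back to "brute"
print(choose_algorithm(6, 3, "brute"))  # explicit brute force is always allowed
```

Raising early for an explicit rbc request keeps the user's choice visible instead of silently overriding it, while the default path never reaches the failing landmark check.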

cjnolet commented 2 years ago

Verified this has indeed been fixed here. Closing.