BrandonSmithJ closed this issue 1 year ago.
This seems more like a NumPy issue; it appears with numpy>=1.22.0:
NumPy version: 1.21.6
Function completed in: 0.96 seconds
Function returned after 1.1 seconds
Function completed in: 9.97 seconds
Function returned after 11.4 seconds
NumPy version: 1.22.0
Function completed in: 0.66 seconds
Function returned after 7.7 seconds
Function completed in: 6.94 seconds
Function returned after 84.8 seconds
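One way to check the NumPy hypothesis without scikit-learn is to time how long it takes to free an object array holding many tiny integer arrays, which loosely mimics the shape of the query_radius output. This is a minimal sketch under that assumption; the helper name free_time_of_small_arrays and the sizes are illustrative and not taken from the thread:

```python
import time

import numpy as np


def free_time_of_small_arrays(n):
    """Time how long it takes to free an object array of n tiny int arrays.

    Hypothetical stand-in for query_radius output; no scikit-learn involved.
    """
    out = np.empty(n, dtype=object)
    for i in range(n):
        # Integer-index assignment stores the array object itself.
        out[i] = np.zeros(1, dtype=np.intp)
    start = time.perf_counter()
    del out  # drops the last reference, so every element array is freed here
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"NumPy version: {np.__version__}")
    for n in (10**5, 10**6):
        print(f"n={n:>8}: freed in {free_time_of_small_arrays(n):.2f} seconds")
```

Running it under numpy 1.21.6 and under numpy>=1.22.0 on the same machine should indicate whether the freeing cost alone regressed.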
@BrandonSmithJ Could you provide the full output of sklearn.show_versions()? I don't see such a large overhead with recent NumPy:
In [1]: from sklearn.neighbors import KDTree
   ...: import time
   ...:
   ...: def mwe(n):
   ...:     start = time.time()
   ...:     a = KDTree([[1]]).query_radius([[1]]*int(n), 1)
   ...:     print(f'Function completed in: {time.time()-start:.2f} seconds')
   ...:
   ...: start = time.time()
   ...: mwe(1e6)
   ...: print(f'Function returned after {time.time()-start:.1f} seconds')
   ...:
   ...: start = time.time()
   ...: mwe(1e7)
   ...: print(f'Function returned after {time.time()-start:.1f} seconds')
Function completed in: 0.21 seconds
Function returned after 0.4 seconds
Function completed in: 2.25 seconds
Function returned after 3.7 seconds
System:
python: 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
executable: /Users/glemaitre/mambaforge/envs/dev/bin/python3.10
machine: macOS-13.4.1-arm64-arm-64bit
Python dependencies:
sklearn: 1.4.dev0
pip: 22.3.1
setuptools: 65.5.1
numpy: 1.25.0
scipy: 1.10.1
Cython: 0.29.33
pandas: 2.0.3
matplotlib: 3.7.1
joblib: 1.3.0.dev0
threadpoolctl: 3.1.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /Users/glemaitre/mambaforge/envs/dev/lib/libopenblas.0.dylib
version: 0.3.21
threading_layer: openmp
architecture: VORTEX
num_threads: 8
user_api: openmp
internal_api: openmp
prefix: libomp
filepath: /Users/glemaitre/mambaforge/envs/dev/lib/libomp.dylib
version: None
num_threads: 8
Sure thing, I've updated the original post with the full versions. I can also confirm that the issue disappears with numpy==1.21.6, so I went ahead and opened an issue in NumPy's repository. Thanks for narrowing it down!
It works as expected under WSL, so the issue seems to be Windows-only:
NumPy version: 1.25.1
Function completed in: 0.69 seconds
Function returned after 0.8 seconds
Function completed in: 27.80 seconds
Function returned after 28.9 seconds
So I guess we can close this as not a scikit-learn issue anyway.
The way the output of the nearest-neighbor classes (e.g. KDTree) is structured makes garbage-collecting it take about an order of magnitude longer than generating it. For example:
This results in the output:
Compare this to the output of an equivalent script that uses scipy.spatial.KDTree: although the query itself is slower with SciPy, the garbage-collection time for the same output is negligible. This seems to be because the nested objects are lists rather than arrays.
Also, I'm aware that swapping the build and query data makes the problem go away in this contrived example, but that produces a completely different output representation. On the data I'm actually working with, swapping the two gives a runtime approaching the garbage-collection time shown above, and garbage collection still has a relatively large impact.
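Swapping the build and query sets answers the same radius question from the other side, because distance is symmetric: tree point j is within r of query i exactly when query i is within r of tree point j. A minimal NumPy sketch of converting one representation into the other, assuming ind is a sequence of integer index arrays like the one a radius query returns (the helper name invert_neighborhoods is hypothetical):

```python
import numpy as np


def invert_neighborhoods(ind, n_tree):
    """Turn per-query neighbor lists into per-tree-point neighbor lists.

    ind[i] holds the indices of tree points within radius r of query i.
    Because distance is symmetric, the result matches what swapping the
    build and query data would return, up to ordering within each group.
    """
    ind = list(ind)
    lengths = np.array([len(a) for a in ind], dtype=np.intp)
    if len(ind) == 0:
        tree_idx = np.empty(0, dtype=np.intp)
    else:
        tree_idx = np.concatenate(ind).astype(np.intp, copy=False)
    # Query id for every (query, tree point) pair, aligned with tree_idx.
    query_idx = np.repeat(np.arange(len(ind), dtype=np.intp), lengths)
    # Group pairs by tree point; a stable sort keeps query ids ascending.
    order = np.argsort(tree_idx, kind="stable")
    counts = np.bincount(tree_idx, minlength=n_tree)
    return np.split(query_idx[order], np.cumsum(counts)[:-1])
```

The inverted result is still a sequence of many small arrays, so on its own this would not avoid the deallocation cost discussed above.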