xarray-contrib / xoak

xarray extension that provides tree-based indexes used for selecting irregular, n-dimensional data.
https://xoak.readthedocs.io
MIT License

Re-implement scikit-learn's search trees with numba #9

Open benbovy opened 4 years ago

benbovy commented 4 years ago

This could be done at a later stage, if we choose to go down this path.

The implementation approach used in scikit-learn is interesting in several aspects:

I think numba is now mature enough and supported in enough distributions that we can use it as a dependency. I'm not sure whether numba's jitted classes are very mature, though, or whether we could avoid using them here.

The biggest advantage of using numba is just-in-time compilation that allows very flexible metric functions.

Huite commented 4 years ago

Hey, since you mentioned xoak in: https://github.com/NOAA-ORR-ERD/gridded/issues/55

I did some looking around before, and I came across this repository: https://github.com/jackd/numba-neighbors

(With MIT license, so good to go)

Looks like an almost perfect match with what you're proposing here?

It uses a jitclass, but very lightly, which is the right approach in my opinion. If you don't want to pass every array as a separate argument, you could pass the tree data more easily as a namedtuple.
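To illustrate the namedtuple idea, here is a minimal sketch, not taken from numba-neighbors. The `TreeData` fields and the `count_within` function are hypothetical names, and the loop is a plain linear scan rather than a real tree traversal. The import fallback is only there so the snippet also runs where numba is not installed.

```python
from collections import namedtuple

import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func

# Hypothetical container for the tree arrays (field names are assumptions).
TreeData = namedtuple("TreeData", ["points", "indices"])


@njit
def count_within(tree, query, radius):
    # Count the points within `radius` of `query`, visiting them in the
    # order given by `tree.indices` (linear scan, no pruning).
    count = 0
    for i in range(tree.indices.shape[0]):
        j = tree.indices[i]
        dist_sq = 0.0
        for k in range(tree.points.shape[1]):
            diff = tree.points[j, k] - query[k]
            dist_sq += diff * diff
        if dist_sq <= radius * radius:
            count += 1
    return count


tree = TreeData(
    points=np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]),
    indices=np.array([2, 0, 1], dtype=np.int64),
)
n_close = count_within(tree, np.array([0.0, 0.0]), 1.5)
```

The jitted function takes a single `tree` argument and reads the fields by name, so adding another array to the tree later does not change every call signature.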

Some query methods are still missing, but they are not that difficult to implement, although I'm not sure you can dynamically allocate as efficiently. (Numba could use something like C++'s std::vector. Or is a typed List already that? It felt significantly slower to me.)

Also, parallelisation is extremely simple using numba's prange.
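As a rough sketch of what that could look like: each query is independent, so the outer loop over queries can be a `prange`. The function name `nearest_indices` is hypothetical, and the inner search is a brute-force scan standing in for a tree query. The import fallback (with `prange` replaced by `range`) is only there so the snippet runs without numba.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback so the sketch still runs without numba
    prange = range

    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(parallel=True)
def nearest_indices(points, queries):
    # For every query row, find the index of the closest row in `points`.
    n_queries = queries.shape[0]
    result = np.empty(n_queries, dtype=np.int64)
    for q in prange(n_queries):  # iterations are independent of each other
        best_index = -1
        best_dist = np.inf
        for i in range(points.shape[0]):
            dist = 0.0
            for k in range(points.shape[1]):
                diff = points[i, k] - queries[q, k]
                dist += diff * diff
            if dist < best_dist:
                best_dist = dist
                best_index = i
        result[q] = best_index  # each iteration writes its own slot
    return result


points = np.array([[0.0, 0.0], [10.0, 10.0]])
queries = np.array([[1.0, 1.0], [9.0, 8.0], [4.0, 4.0]])
nearest = nearest_indices(points, queries)
```

Because every iteration writes only to its own element of `result`, no synchronisation between threads is needed.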

JIT indeed provides very flexible metric functions. The best way to introduce them seems to be closures, which avoid the function call overhead in numba, I believe: https://numba.pydata.org/numba-doc/latest/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function
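A minimal sketch of the closure pattern, under the assumption that metrics are small jitted functions of two points: a factory captures the metric and returns a jitted search specialised for it. The names `euclidean`, `make_nearest` and `nearest_euclidean` are hypothetical, and the search is a brute-force scan rather than a tree query. The import fallback is only there so the snippet runs without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit
def euclidean(a, b):
    # Example metric: straight-line distance between two points.
    total = 0.0
    for k in range(a.shape[0]):
        diff = a[k] - b[k]
        total += diff * diff
    return np.sqrt(total)


def make_nearest(metric):
    # `metric` is captured by the closure, so the compiled search calls it
    # directly instead of receiving it as a runtime argument.
    @njit
    def nearest(points, query):
        best_index = -1
        best_dist = np.inf
        for i in range(points.shape[0]):
            dist = metric(points[i], query)
            if dist < best_dist:
                best_dist = dist
                best_index = i
        return best_index, best_dist

    return nearest


nearest_euclidean = make_nearest(euclidean)
points = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]])
index, dist = nearest_euclidean(points, np.array([0.9, 1.2]))
```

A user-defined metric would be plugged in the same way, by decorating it with `njit` and passing it to `make_nearest`, at the cost of one extra compilation per metric.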

I've also noticed that performance can benefit significantly from aggressive inlining (although this increases compile cost). Since a tree will generally consist of float32 or float64 coordinates, and int32 or int64 indices, maybe it's a nice idea to compile ahead of time for the built-in metric functions. https://numba.pydata.org/numba-doc/dev/user/pycc.html#compiling-code-ahead-of-time

benbovy commented 4 years ago

Good to know about the numba-neighbors repository and numba tricks @Huite, thanks!