xarray-contrib / xoak

xarray extension that provides tree-based indexes used for selecting irregular, n-dimensional data.
https://xoak.readthedocs.io
MIT License

Refactor accessor + flexible indexes + Dask support #18

Closed. benbovy closed this PR 3 years ago.

benbovy commented 4 years ago

This PR will eventually be quite big (sorry, at this stage I think it'll be more productive overall than splitting things into many PRs).

Two goals:

- flexible indexes
- Dask support

TODO:

Flexible indexes

Registering a new custom index in xoak can be easily done with the help of a small adapter class that must implement the build and query methods, e.g.,

import xoak
from mypackage import MyIndex

@xoak.register_index('my_index')
class MyIndexAdapter(xoak.IndexAdapter):

    def build(self, points):
        # must return an instance of the wrapped index
        return MyIndex(points)

    def query(self, my_index, points):
        # must return a (distances, positions) tuple of numpy arrays
        return my_index.query(points)

Any option to pass to the underlying index constructor should be added as an argument of the adapter class's __init__ method (how query options are handled can be addressed later).
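For illustration, here is a hedged variant of the adapter above with a construction option (MyIndex and its leaf_size option are made-up names):

import xoak
from mypackage import MyIndex

@xoak.register_index('my_index')
class MyIndexAdapter(xoak.IndexAdapter):

    def __init__(self, leaf_size=16):
        # construction option forwarded to the wrapped index at build time
        self.leaf_size = leaf_size

    def build(self, points):
        # must return an instance of the wrapped index
        return MyIndex(points, leaf_size=self.leaf_size)

    def query(self, my_index, points):
        # must return a (distances, positions) tuple of numpy arrays
        return my_index.query(points)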

In the examples above, my_index is registered in xoak's available indexes. It can be selected like this:

Dataset.xoak.set_index(['lat', 'lon'], 'my_index')

It's also possible to directly provide the adapter class (useful when one does not want to register an index adapter):

Dataset.xoak.set_index(['lat', 'lon'], MyIndexAdapter)

xoak.indexes returns a mapping of all available indexes. As an alternative to the decorator above, it can also be used to register index adapters, e.g.,

xoak.indexes.register('my_index', MyIndexAdapter)
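Assuming the returned mapping behaves like a plain dict (the names below are only examples), the registered adapters can be inspected the usual way:

import xoak

list(xoak.indexes)                      # e.g. ['my_index', ...]
adapter_cls = xoak.indexes['my_index']  # look up a registered adapter class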

Dask support

This PR implements dask-enabled index build and query that are independent of the underlying index. Chunked coordinate variables are handled for the index points, the query points, or both.
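As a hedged usage sketch (coordinate names, sizes and the exact .xoak.sel() signature are assumptions, not the PR's verbatim API), both the index coordinates and the query points may be dask-backed:

import dask.array as da
import xarray as xr
import xoak  # registers the .xoak accessor

# chunked index coordinates
ds = xr.Dataset(coords={
    'lat': ('points', da.random.uniform(-90, 90, size=10_000_000, chunks=1_000_000)),
    'lon': ('points', da.random.uniform(-180, 180, size=10_000_000, chunks=1_000_000)),
})
ds.xoak.set_index(['lat', 'lon'], 'my_index')

# chunked query points
ds_query = xr.Dataset(coords={
    'lat': ('points', da.random.uniform(-90, 90, size=100_000, chunks=10_000)),
    'lon': ('points', da.random.uniform(-180, 180, size=100_000, chunks=10_000)),
})
ds_sel = ds.xoak.sel(lat=ds_query.lat, lon=ds_query.lon)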

For chunked index coordinates, a forest of index trees is built (one tree per chunk).
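Conceptually (this is only a sketch of the idea, not xoak's actual code, and it uses scipy's cKDTree as a stand-in for the wrapped index), building the forest amounts to creating one delayed tree per chunk:

import dask
import dask.array as da
from scipy.spatial import cKDTree  # stand-in for any wrapped index

# chunked index coordinates: 10 chunks -> 10 trees
points = da.random.random((1_000_000, 2), chunks=(100_000, 2))

# one lazily-built tree per chunk of the index coordinates
forest = [dask.delayed(cKDTree)(chunk) for chunk in points.to_delayed().ravel()]

# with persist=True, the trees would be kept in the workers' memory
forest = dask.persist(*forest)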

A query is executed in two stages:

Advantages of this approach:

Dataset.xoak.set_index(['lat', 'lon'], MyIndexAdapter, persist=True)

Potential caveats:

Other changes

benbovy commented 4 years ago

I think it's ready for review. @willirath @koldunovn if you want to take a look (I know, it's quite big).

I don't know why the lint tests are failing here (I ran black on those files). Tests are running fine locally.

We still need some more tests for the accessor. That can be done in #10 after merging this (probably a couple of merge conflicts to solve).

I updated the example notebook.

There's one performance issue when index coordinates are chunked (I think it's an xarray issue), but we can go ahead and address this later IMO.

benbovy commented 4 years ago

There's one performance issue when index coordinates are chunked (I think it's an xarray issue)

Opened https://github.com/pydata/xarray/issues/4555

benbovy commented 4 years ago

Let's also track https://github.com/pydata/xarray/issues/2511 so that eventually we won't have to trigger computation of the dask graph in .xoak.sel().

benbovy commented 4 years ago

I've done some tests on a bigger server with the S2 point index, still using randomly located points.

Settings:

Results:

Caveats:

The results above were obtained after I spent a while tweaking the chunks, the dask cluster and the dask configuration (e.g., disabling work stealing). In practice this turns out to be very tricky. Scaling is also heavily limited by memory.

There's one major issue here: during the query, the indexes get replicated on a growing number of dask workers, which causes memory blow-up and a significant drop in performance due to index data serialization (which, in the case of S2, means rebuilding the indexes from scratch!). I suspect that dask's scheduler has a poor idea of the memory footprint of those indexes (persisted in memory), which may explain why its heuristics for assigning tasks to workers fail so badly in this case.

I need to figure out:

- if we can give dask a better idea of the memory footprint of each index
- if / how best we can force dask to submit the query tasks to the workers where the index lives in memory

benbovy commented 4 years ago

If we can give dask a better idea of the memory footprint of each index

Turns out we just need to define the __sizeof__ method for the index wrapper classes. It does yield better performance, but it still doesn't solve all the issues. We could imagine "fooling" dask by returning a very large size and hoping that the indexes won't be replicated too much, but that's not very elegant.
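For illustration only (the wrapper class name and the size estimate are made up), defining __sizeof__ could look like this:

class S2PointIndexWrapper:

    def __init__(self, index, n_points):
        self._index = index
        self._n_points = n_points

    def __sizeof__(self):
        # dask's sizeof() falls back to sys.getsizeof(), which calls this;
        # returning a realistic estimate helps the scheduler account for
        # the index when assigning tasks and managing worker memory
        return object.__sizeof__(self) + self._n_points * 40  # bytes/point, rough guess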

if / how best we can force dask to submit the query tasks to the workers where the index lives in memory

It's pretty straightforward to do this using client.compute(..., workers=...). It doesn't play well with resilience, though (it may freeze the computation if workers die or indexes are deleted for some reason). To improve stability we could use client.replicate() and client.rebalance(). It works quite well, but all those manual tuning steps require experience and attention.
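For reference, a hedged sketch of those manual placement steps using the public dask.distributed API (query_tasks, index_futures and index_workers are placeholder names):

from dask.distributed import Client

client = Client()  # or connect to an existing cluster

# pin the query tasks to the workers that hold the persisted indexes
query_futures = client.compute(query_tasks, workers=index_workers)

# copy the indexes to more workers and even out memory usage, so that the
# pinned computation above is less fragile if a worker dies
client.replicate(index_futures)
client.rebalance()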


Despite those tricks, I still sometimes get workers restarted or indexes deleted for unknown reasons. I don't know, maybe I'm using too much data for the cluster I set up.

With this approach, I'm afraid there's no way to completely avoid large amounts of data being transferred between workers, so scalability is still limited.

On the positive side, we really do get very nice speed-ups with decent amounts of data (millions of points on my laptop)!

benbovy commented 4 years ago

Broadcasting the query point chunks to all workers may greatly help too:

query_points = ... # dask array

query_points = query_points.persist()

client.replicate(query_points)

This could actually be the best thing to do, since in most cases the whole set of query points should easily fit in each worker's memory.

Now this is looking good:

[screenshot attached: 2020-11-05 09:13:38]
benbovy commented 3 years ago

Ok, let's merge this. Tests failed because of black (I suspect a version mismatch), so no harm here; we can take care of it later.