xgcm / xhistogram

Fast, flexible, label-aware histograms for numpy and xarray
https://xhistogram.readthedocs.io
MIT License

Refactor histogram to use blockwise #49

Closed. rabernat closed this PR 3 years ago.

rabernat commented 3 years ago

As discussed in https://github.com/ocean-transport/coiled_collaboration/issues/8, I feel like there must be a more efficient implementation of xhistogram's core logic in terms of handling dask arrays. A central problem with our implementation is the use of reshape on dask arrays. This causes inefficient chunk manipulations (see https://github.com/dask/dask/issues/5544).

We should be able to just apply the bin_count function to every block and then sum over all blocks to get the total bin count. If we go this route, we will no longer use dask.array.histogram at all. This should result in much simpler dask graphs.
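A NumPy-only sketch of why this works, assuming every block shares the same fixed bin edges: histogramming each block separately and summing the per-block counts gives exactly the whole-array histogram, so no reshape is needed. The helper name `blockwise_sum_histogram` is hypothetical, and `np.array_split` blocks stand in for dask chunks.

```python
import numpy as np


def blockwise_sum_histogram(data, edges, nblocks):
    """Hypothetical helper: histogram each block, then sum the block counts."""
    # Split the flattened data into roughly equal blocks (stand-ins for dask chunks).
    blocks = np.array_split(np.ravel(data), nblocks)
    # Every block uses the same edges, so its counts line up bin-for-bin.
    per_block = [np.histogram(block, bins=edges)[0] for block in blocks]
    # Summing over the "block" axis gives the total bin count.
    return np.sum(per_block, axis=0)


rng = np.random.default_rng(0)
data = rng.normal(size=(20, 30))
edges = np.linspace(-3, 3, 13)

total = blockwise_sum_histogram(data, edges, nblocks=7)
reference, _ = np.histogram(data, bins=edges)
```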

I have not actually figured out how to solve the problem, but this PR gets started in that direction. The challenge is determining the arguments to pass to map_blocks.

dcherian commented 3 years ago

:+1: You know the number of bins or groups so you can "tree reduce" it without any reshaping as you say. The current version of dask's bincount does exactly that for the 1d case: https://github.com/dask/dask/blob/98a65849a6ab2197ac9e24838d1effe701a441dc/dask/array/routines.py#L677-L698
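For the 1-D case, a small usage sketch of dask's public `da.bincount`, assuming the data are already non-negative integer bin indices. Passing `minlength` fixes the number of bins up front, so the output shape is known before computing and the per-chunk counts can be summed without any reshape (the internal reduction details may differ between dask versions).

```python
import numpy as np
import dask.array as da

nbins = 5
rng = np.random.default_rng(1)
values = rng.integers(0, nbins, size=1_000)

# Ten chunks of 100 elements each.
x = da.from_array(values, chunks=100)

# minlength fixes the number of bins, so every chunk's counts have the same
# length and can be combined by summation without reshaping.
counts = da.bincount(x, minlength=nbins)

result = counts.compute()
```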

It's what inspired the core of my groupby experiment: https://github.com/dcherian/dask_groupby/blob/373a4dda4d618848b5d02cca551893fee36df092/dask_groupby/core.py#L361-L368 which only reshapes if you know absolutely nothing about the "expected" groups here: https://github.com/dcherian/dask_groupby/blob/373a4dda4d618848b5d02cca551893fee36df092/dask_groupby/core.py#L604-L608

IIUC your proposal is implemented in tensordot as a blockwise-which-adds-dummy-axis followed by sum. https://github.com/dask/dask/blob/98a65849a6ab2197ac9e24838d1effe701a441dc/dask/array/routines.py#L295-L310
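A minimal sketch of the "blockwise that adds a dummy axis, then sum" trick for a 2-D array reduced over axis 1, assuming the values are already integer bin indices in `[0, nbins)`. The helper names `_row_counts` and `histogram_along_axis1` are hypothetical. Because the output keeps every input index, blockwise does no contraction; each block's extent along `j` is collapsed to length 1, and the final `sum` combines the blocks.

```python
import numpy as np
import dask.array as da


def _row_counts(block, nbins):
    """Count integer bin indices in each row of a 2-D block.

    Returns shape (nrows, 1, nbins): the middle length-1 axis is the dummy
    axis standing in for this block's position along the reduced dimension.
    """
    nrows = block.shape[0]
    # Offset each row's bin indices so a single flat bincount handles all rows.
    flat = (np.arange(nrows)[:, None] * nbins + block).ravel()
    counts = np.bincount(flat, minlength=nrows * nbins)
    return counts.reshape(nrows, 1, nbins)


def histogram_along_axis1(x, nbins):
    """Hypothetical helper: histogram over axis 1 of a 2-D dask array of bin indices."""
    per_block = da.blockwise(
        _row_counts, "ijk",   # output: rows, one entry per block along j, bins
        x, "ij",              # input keeps both of its indices, so no contraction
        nbins=nbins,          # forwarded to _row_counts as a keyword argument
        new_axes={"k": nbins},
        adjust_chunks={"j": lambda _: 1},  # each block collapses j to length 1
        meta=np.array((), dtype=np.intp),
    )
    # Summing over the dummy axis combines the per-block counts.
    return per_block.sum(axis=1)


nbins = 4
rng = np.random.default_rng(2)
values = rng.integers(0, nbins, size=(6, 10))
x = da.from_array(values, chunks=(3, 4))

hist = histogram_along_axis1(x, nbins)
expected = np.stack([np.bincount(row, minlength=nbins) for row in values])
```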

I found blockwise a lot easier to handle than map_blocks for this stuff

jrbourbeau commented 3 years ago

cc @gjoseph92

rabernat commented 3 years ago

@dcherian - at this point it is clear that you know a lot more about this than me. Would you have any interest in taking a crack at implementing the approach you just described? I think this PR sets it up nicely so all you have to do is plug in the tree reduce code.

gjoseph92 commented 3 years ago

Okay, I think this is the skeleton for how we could implement it: https://gist.github.com/gjoseph92/1ceaf28937f485012bc1540eb2e8da7d

We just need the actual bincount-ish function to be applied over the chunk pairs. And to handle weights, and the bins being dask arrays. But I think those are "just" a matter of broadcasting and lining up the dimension names to what I have here.
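A sketch of the weights part, assuming integer bin indices and NumPy-side `nbins` (dask-array bins are not handled here). Giving the weights the same index string as the data is what "lining up the dimension names" means for blockwise: matching blocks of `x` and `w` are passed to the kernel together. The helper names are hypothetical.

```python
import numpy as np
import dask.array as da


def _weighted_block_counts(block, wblock, nbins):
    """Weighted counts for one block, with two leading length-1 block axes."""
    counts = np.bincount(block.ravel(), weights=wblock.ravel(), minlength=nbins)
    return counts.reshape(1, 1, nbins)


def weighted_histogram_all_axes(x, w, nbins):
    """Hypothetical helper: weighted histogram of a 2-D dask array over both axes."""
    per_block = da.blockwise(
        _weighted_block_counts, "ijk",
        x, "ij",
        w, "ij",  # same indices as x, so corresponding blocks are paired up
        nbins=nbins,
        new_axes={"k": nbins},
        adjust_chunks={"i": lambda _: 1, "j": lambda _: 1},
        meta=np.array((), dtype=np.float64),
    )
    # Sum over both block axes to get the total weighted count per bin.
    return per_block.sum(axis=(0, 1))


nbins = 3
rng = np.random.default_rng(3)
values = rng.integers(0, nbins, size=(8, 9))
weights = rng.random(size=(8, 9))
x = da.from_array(values, chunks=(4, 3))
w = da.from_array(weights, chunks=(4, 3))

hist = weighted_histogram_all_axes(x, w, nbins)
expected = np.bincount(values.ravel(), weights=weights.ravel(), minlength=nbins)
```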

Now that I'm seeing how much cleaner the graphs can be in the n-dimensional chunking case (I think the examples I'd looked at earlier all had one chunk along the aggregated dimensions), I fully agree with @rabernat that blockwise could improve the graph structure a lot.

dcherian commented 3 years ago

:+1: That's exactly it AFAICT. It's the same trick as tensordot and isin.

rabernat commented 3 years ago

@gjoseph92 that looks amazing. Exactly what I had in mind originally with this PR but didn't have the dask chops to actually implement. Let me know how I can help move this forward.

TomNicholas commented 3 years ago

I'm happy to have a go at implementing this too. I would eventually like to integrate it upstream in xarray, but only once we've got the best dask function for the job working.

One question - what would the numpy equivalent of @gjoseph92's solution be? Is it just to apply bincount once over the whole array without calling blockwise?

dougiesquire commented 3 years ago

> I'm happy to have a go at implementing this too. I would eventually like to integrate it upstream in xarray, but only once we've got the best dask function for the job working.

@TomNicholas, for someone that's not up with the lingo, do you mean move histogram functionality into xarray?

> One question - what would the numpy equivalent of @gjoseph92 's solution be? Is it just to apply bincount once over the whole array without calling blockwise?

For 1D arrays at least, I think numpy histogram actually loops over blocks and cumulatively sums the bincount applied to each block, with a note that this is both faster and more memory efficient for large arrays.
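A simplified sketch of that block loop for equal-width bins: counts are accumulated with `np.bincount` one block at a time, so only one block's temporaries exist at once. This is not NumPy's actual code (NumPy also corrects for floating-point rounding at bin edges); the function name is hypothetical and the default block size of 65536 is assumed to mirror NumPy's internal constant.

```python
import numpy as np


def blocked_uniform_histogram(a, nbins, lo, hi, block=65536):
    """Hypothetical equal-width histogram that accumulates np.bincount over blocks."""
    a = np.ravel(a)
    counts = np.zeros(nbins, dtype=np.intp)
    scale = nbins / (hi - lo)
    for start in range(0, a.size, block):
        chunk = a[start:start + block]
        # Keep only values inside [lo, hi], like np.histogram's range handling.
        chunk = chunk[(chunk >= lo) & (chunk <= hi)]
        idx = ((chunk - lo) * scale).astype(np.intp)
        # The right edge `hi` belongs to the last bin.
        idx[idx == nbins] = nbins - 1
        counts += np.bincount(idx, minlength=nbins)
    return counts


rng = np.random.default_rng(4)
# Bin centres (k + 0.5) avoid edge rounding; 10.0 sits on the right edge,
# while -1.0 and 11.0 fall outside the range.
values = np.concatenate([rng.integers(0, 10, size=200_000) + 0.5, [10.0, -1.0, 11.0]])

blocked = blocked_uniform_histogram(values, nbins=10, lo=0.0, hi=10.0)
reference, _ = np.histogram(values, bins=10, range=(0.0, 10.0))
```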

TomNicholas commented 3 years ago

> do you mean move histogram functionality into xarray

Yes, exactly. It's arguably within the scope of xarray, which already has a .hist method (it just currently reduces over all dimensions) and already has functions for groupby operations and similar kinds of reshaping.

In the long term the advantages would be:

> I think numpy histogram actually loops over blocks

Oh that's interesting, thanks!

dougiesquire commented 3 years ago

> Yes exactly. It's arguably within the scope of xarray

I think this is a terrific idea. Happy to help out however I can

TomNicholas commented 3 years ago

> I think this is a terrific idea. Happy to help out however I can

Great! Well anyone can do the PR(s) to xarray - I would be happy to provide guidance & code reviews over there if you wanted to try it?

The steps as I see it are:

dougiesquire commented 3 years ago

Yup I think it's a good idea to get the dask.blockwise implementation working well here as a first step :thumbsup:

TomNicholas commented 3 years ago

So looking at the code in xhistogram.core, I have some questions:

To implement @gjoseph92's solution (which looks great), we would need to write a new bincount function that matches the signature described in his notebook. The _bincount function we currently have has a different signature (I'm not sure exactly what it is, because it's an internal function without a docstring), so it would need to be changed significantly. Then _bincount calls other functions like _bincount_2d_vectorized, which I think are specific to the reshape-based dask algorithm we currently use. If we are changing all these functions, are we essentially just rewriting the whole of xhistogram.core? I'm not saying we shouldn't do that, I'm just trying to better understand what's required. Can any of this be copied over from other implementations (e.g. dask.array.histogramdd)?

gjoseph92 commented 3 years ago

> If we are changing all these functions are we essentially just rewriting the whole of xhistogram.core

Yes, I think we sort of are. xhistogram's current approach is dask-agnostic (ish, besides _dispatch_bincount)—it uses a bunch of NumPy functions and lets dask parallelize them implicitly. With the blockwise approach, we'd sort of have a shared "inner kernel" (the core bincount operation), then slightly different code paths for mapping that over a NumPy array versus explicitly parallelizing it over a dask array.

The difference is very small: for a NumPy array, apply the core bincount function directly to the array. For a dask array, use blockwise to apply it. (You can think of the NumPy case as just a 1-chunk dask array.) Then in both cases, finish with bincounts.sum(axis=axis).
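A minimal 1-D sketch of that structure, assuming integer bin indices: one shared kernel returns counts with a leading length-1 block axis, the NumPy path calls it once (a single chunk), the dask path maps it over every chunk with blockwise, and both finish with `.sum(axis=0)`. The names `_bincount_kernel` and `histogram_1d` are hypothetical, not xhistogram's API.

```python
import numpy as np
import dask.array as da


def _bincount_kernel(block, nbins):
    """Shared inner kernel: counts for one block, with a leading length-1 block axis."""
    return np.bincount(np.ravel(block), minlength=nbins)[np.newaxis, :]


def histogram_1d(x, nbins):
    """Hypothetical 1-D entry point with a NumPy path and a dask path."""
    if isinstance(x, da.Array):
        # Dask: apply the kernel to every chunk, giving one row of counts per chunk.
        bincounts = da.blockwise(
            _bincount_kernel, "ik",
            x, "i",
            nbins=nbins,
            new_axes={"k": nbins},
            adjust_chunks={"i": lambda _: 1},
            meta=np.array((), dtype=np.intp),
        )
    else:
        # NumPy: treat the whole array as a single chunk.
        bincounts = _bincount_kernel(x, nbins)
    # Both paths finish the same way: sum over the block axis.
    return bincounts.sum(axis=0)


rng = np.random.default_rng(6)
values = rng.integers(0, 6, size=500)
expected = np.bincount(values, minlength=6)

np_result = histogram_1d(values, 6)
dask_result = histogram_1d(da.from_array(values, chunks=64), 6).compute()
```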

So in a way, the bincount "inner kernel" is all of the current xhistogram core, just simplified to not have to worry about dask arrays, and to return a slightly different shape. I think this change might actually make xhistogram a bit easier to read/reason about. Additionally, I think we should consider writing this bincount part in numba or Cython eventually—not just because it's faster, but because being able to just write a for loop would be way easier to read than the current reshape and ravel_multi_index logic, which is hard to wrap your head around.
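To illustrate the readability point, two hypothetical functions computing the same 2-D joint bin counts from integer bin indices: one vectorized with `np.ravel_multi_index` and `np.bincount`, one as a plain loop. Neither is xhistogram's actual `_bincount_2d_vectorized`; the loop is pure Python here (and slow), and the suggestion in the text is that numba or Cython would compile this style.

```python
import numpy as np


def bincount_2d_vectorized(i, j, shape):
    """2-D joint counts via ravel_multi_index: flatten each (i, j) pair to one index."""
    flat = np.ravel_multi_index((i, j), shape)
    return np.bincount(flat, minlength=shape[0] * shape[1]).reshape(shape)


def bincount_2d_loop(i, j, shape):
    """The same counts as an explicit loop, the form numba or Cython could compile."""
    counts = np.zeros(shape, dtype=np.intp)
    for a, b in zip(i, j):
        counts[a, b] += 1
    return counts


rng = np.random.default_rng(7)
i = rng.integers(0, 3, size=1_000)
j = rng.integers(0, 4, size=1_000)

vec = bincount_2d_vectorized(i, j, (3, 4))
loop = bincount_2d_loop(i, j, (3, 4))
```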

Also, it might be worth passing NumPy arrays of sufficient size (or any size) into dask.array.from_array, putting them through the dask logic, and computing the result with scheduler="threads". This would be a quick and easy way to get the loop-over-blocks behavior @dougiesquire was describing, (maybe) a lower peak memory footprint, plus some parallelism.
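A sketch of that idea using existing dask calls, assuming the data fit in memory: wrap the NumPy array with `da.from_array`, histogram it chunk by chunk (here with `da.histogram` and explicit edges, standing in for the new blockwise kernel), and compute on the local thread pool.

```python
import numpy as np
import dask.array as da

rng = np.random.default_rng(5)
values = rng.normal(size=1_000_000)
edges = np.linspace(-4, 4, 33)

# Wrap the in-memory NumPy array as a dask array with ten chunks.
x = da.from_array(values, chunks=100_000)

# With explicit edges, each chunk is histogrammed and the chunk counts are summed.
hist, _ = da.histogram(x, bins=edges)

# Run the graph on the threaded scheduler.
counts = hist.compute(scheduler="threads")
reference, _ = np.histogram(values, bins=edges)
```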

TomNicholas commented 3 years ago

That's very helpful, thanks @gjoseph92.

> I think this change might actually make xhistogram a bit easier to read/reason about.

I agree.

> Additionally, I think we should consider writing this bincount part in numba or Cython eventually

That would be cool, but it would also affect whether or not we could integrate the functionality upstream in xarray, so I suggest that we get a version working with just dask initially so we can see if it's fast "enough"?

> computing the result with scheduler="threads"

That might even be like a "best of both" solution, ish?

dougiesquire commented 3 years ago

Thanks @gjoseph92 for that very helpful description!

@TomNicholas I'm going on leave next week and I'm hoping to work on this a little during that. My plan is to spend some time understanding/implementing the blockwise approach and then to open another PR once things are somewhat functional. But I want to make sure I don't tread on your toes or duplicate your efforts. What's the best way to align our efforts?

TomNicholas commented 3 years ago

> I'm hoping to work on this a little during that

Great!

> But I want to make sure I don't tread on your toes or duplicate your efforts.

@dougiesquire I wouldn't worry about that too much. I think we will both want to understand the resultant code, so time spent understanding it is always well-spent. For development we should just both post whatever we try out - as a gist, a PR, or could even start a new dev branch on this repo.

codecov[bot] commented 3 years ago

Codecov Report

Merging #49 (d1df004) into master (daf41f0) will decrease coverage by 11.65%. The diff coverage is 93.75%.


@@             Coverage Diff             @@
##           master      #49       +/-   ##
===========================================
- Coverage   95.79%   84.14%   -11.66%     
===========================================
  Files           3        2        -1     
  Lines         238      246        +8     
  Branches       69       77        +8     
===========================================
- Hits          228      207       -21     
- Misses          7       34       +27     
- Partials        3        5        +2     
Impacted Files        Coverage Δ
xhistogram/core.py    81.81% <93.75%> (-18.19%) :arrow_down:

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update daf41f0...d1df004.

rabernat commented 3 years ago

In my chat today with @TomNicholas, we realized this PR was already quite close to working. I just needed to understand how to use blockwise (rather than map_blocks). Tests are all green. I have not examined the dask graphs yet, but I think they will be improved.

Regardless of whether this works, Tom raised the point that the current implementation is quite opaque. There are many nested layers of internal functions that make the logic hard to follow. So it might not be a bad idea to just refactor everything anyway.

rabernat commented 3 years ago

@TomNicholas - today on our call you mentioned that this PR doesn't actually work with Dask. 🤣 That was absolutely true of 9f20e95. However, I think I have updated things now so that they do actually work.

Here is an example.

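import numpy as np
import dask.array as dsa
import xhistogram.core as xhc
from IPython.display import display  # built in inside Jupyter; needed when run as a script

csize = 80   # chunk size along the last two axes
nz = 20
ntime = 2    # unused in this example
fac = 4      # number of chunks along each of the last two axes

bins = np.arange(10)

# 20 x 320 x 320 array of ones, chunked only along the last two axes
data = dsa.ones((nz, fac * csize, fac * csize),
                chunks=(nz, csize, csize))
display(data)
# histogram over axes 0 and 1; axis 2 is kept
hist = xhc.histogram(data, bins=[bins], axis=[0, 1])
hist_eager = xhc.histogram(data.compute(), bins=[bins], axis=[0, 1])
display(hist)
display(hist.visualize())  # requires graphviz
np.testing.assert_allclose(hist.compute(), hist_eager)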

[images: output of the example above, including the task graph]

Regardless, the fact that 9f20e95 did not raise any test failures reveals a major gap in our test coverage. You mentioned you were working on expanding test coverage of dask cases. It would be great to do that in a separate PR, merge it, and then rebase this on top of it to see whether the implementation here passes the new tests.

TomNicholas commented 3 years ago

> now they do actually work

@rabernat nice! Thanks for fixing that. Also that task graph looks beautiful! Is it still as neat for a multi-variable histogram?

I was really struggling to work out what was the problem with your PR, and I still don't really understand how you fixed it - IMO that means the code should still be refactored to be clearer, which I will have a go at.

> You mentioned you were working on expanding test coverage of dask cases

Yes it's pretty surprising that broken code still passed the tests... but I've added dask tests over in #57 now.

TomNicholas commented 3 years ago

> Is it still as neat for a multi-variable histogram?

For anyone wondering, yes it is (for the same example but with two data variables):

[image: task graph for the two-variable histogram]

jbusecke commented 3 years ago

> For anyone wondering, yes it is (for same but with two data variables)

Amazing! Is this also including weights?

TomNicholas commented 3 years ago

With weights you get this:

import numpy as np
import dask.array as dsa
import xhistogram.core as xhc
from IPython.display import display

csize = 80
nz = 20
ntime = 2  # unused in this example
fac = 4

bins1 = np.arange(10)
bins2 = np.arange(8)

# two data variables plus weights, all with identical shape and chunks
data1 = dsa.random.random((nz, fac * csize, fac * csize),
                          chunks=(nz, csize, csize))
data2 = dsa.random.random((nz, fac * csize, fac * csize),
                          chunks=(nz, csize, csize))
weights = dsa.random.random((nz, fac * csize, fac * csize),
                            chunks=(nz, csize, csize))

display(data1)
# weighted joint histogram of data1 and data2, reduced over axes 0 and 1
hist = xhc.histogram(data1, data2, bins=[bins1, bins2], weights=weights, axis=[0, 1])
hist_eager = xhc.histogram(data1.compute(), data2.compute(), bins=[bins1, bins2], weights=weights.compute(), axis=[0, 1])
display(hist)
display(hist.visualize())
np.testing.assert_allclose(hist.compute(), hist_eager)

[image: task graph for the weighted two-variable histogram]

jbusecke commented 3 years ago

Just beautiful! I suspect this will help tremendously over at https://github.com/ocean-transport/coiled_collaboration/issues/9.

Do you think this is ready to test under real-world conditions?

TomNicholas commented 3 years ago

> Do you think this is ready to test under real-world conditions?

Probably not yet. You could try it, but chunking has no explicit tests apart from the ones I added in #57, and those do not all pass yet :confused:

TomNicholas commented 3 years ago

@jbusecke scratch that - it looks like it was my tests that were the problem after all - should be green to try this out now!

jbusecke commented 3 years ago

Whoa! This works really well!

Here is the Dask task stream from the synthetic example that used to fail often (https://github.com/ocean-transport/coiled_collaboration/issues/9). This example uses two variables and weights.

[image: Dask task stream for the synthetic example]

I ran this on the google deployment with the 'hacky pip install' described here. Will try it out for realistic CMIP6 data next.

I can't thank you enough for implementing this! I know who I will thank at the next CDSLab meeting!

jbusecke commented 3 years ago

Ok so after some tuning of the chunksize I got some pretty darn satisfying results for the real world example:

[image: results for the real-world example]

This crunched through a single member in under 2 minutes with 20 workers. Absolutely no spilling!

TomNicholas commented 3 years ago

So now that we've successfully chunked over all dims, do we still need any of the "block size" logic? The main benefit of being able to specify block_size is a lower memory footprint, right? But if we can now chunk over any dim, then the dask chunks can always be made smaller for a given problem, and the user can always lower their memory footprint by using smaller chunks instead?

Or do we still want it for the same reasons that numpy.histogram has it?

(There is also no block used anywhere in np.histogramdd)

rabernat commented 3 years ago

> Or do we still want it for the same reasons that numpy.histogram has it?

My understanding was that this is a low-level optimization, aimed at making the core bincount routine faster. Dask is responsible for scaling out, but having a faster low-level routine is always desirable as well. That said, I have not benchmarked anything to verify it actually makes a difference.

rabernat commented 3 years ago

Tom, the only thing that failed after I accepted your suggestion was the linting. So please go ahead and merge this if you think it's ready! Lots of stuff is blocking on this PR.

jbusecke commented 3 years ago

🚀