rabernat closed this 3 years ago.
:+1: You know the number of bins or groups so you can "tree reduce" it without any reshaping as you say. The current version of dask's bincount does exactly that for the 1d case: https://github.com/dask/dask/blob/98a65849a6ab2197ac9e24838d1effe701a441dc/dask/array/routines.py#L677-L698
It's what inspired the core of my groupby experiment: https://github.com/dcherian/dask_groupby/blob/373a4dda4d618848b5d02cca551893fee36df092/dask_groupby/core.py#L361-L368 which only reshapes if you know absolutely nothing about the "expected" groups here: https://github.com/dcherian/dask_groupby/blob/373a4dda4d618848b5d02cca551893fee36df092/dask_groupby/core.py#L604-L608
IIUC your proposal is implemented in tensordot as a blockwise-which-adds-dummy-axis followed by sum. https://github.com/dask/dask/blob/98a65849a6ab2197ac9e24838d1effe701a441dc/dask/array/routines.py#L295-L310
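To make that pattern concrete, here is a minimal sketch of the blockwise-with-dummy-axis trick on a plain matrix product. This is illustrative only, not dask's actual `tensordot` code: the per-chunk function keeps a length-1 dummy axis where the contracted dimension was, so results from different chunks line up and can be summed instead of concatenated.

```python
# Sketch of the "blockwise-which-adds-dummy-axis followed by sum" pattern.
# Illustrative only -- not dask's actual tensordot implementation.
import numpy as np
import dask.array as da

def _chunk_dot(a, b):
    # Partial product of one chunk pair; keep a length-1 dummy axis where
    # the contracted dimension was, so per-chunk results can be summed.
    return np.tensordot(a, b, axes=((1,), (0,)))[:, None, :]

x = da.ones((8, 6), chunks=(4, 3))
y = da.ones((6, 5), chunks=(3, 5))

# Output index "j" becomes the dummy axis: one length-1 entry per chunk of
# the contracted dimension, declared via adjust_chunks.
partials = da.blockwise(_chunk_dot, "ijk", x, "ij", y, "jk",
                        dtype=x.dtype, adjust_chunks={"j": 1})
result = partials.sum(axis=1)  # tree-reduce over the dummy axis
```

The same shape trick underlies the histogram proposal: bincount each chunk into a dummy axis, then sum over it.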
I found `blockwise` a lot easier to handle than `map_blocks` for this stuff.
cc @gjoseph92
@dcherian - at this point it is clear that you know a lot more about this than me. Would you have any interest in taking a crack at implementing the approach you just described? I think this PR sets it up nicely so all you have to do is plug in the tree reduce code.
Okay, I think this is the skeleton for how we could implement it: https://gist.github.com/gjoseph92/1ceaf28937f485012bc1540eb2e8da7d

We just need the actual `bincount`-ish function to be applied over the chunk pairs. And to handle weights, and the bins being dask arrays. But I think those are "just" a matter of broadcasting and lining up the dimension names to what I have here.
Now that I'm seeing how much cleaner the graphs can be in the n-dimensional chunking case (I think the examples I'd looked at earlier all had one chunk along the aggregated dimensions), I fully agree with @rabernat that blockwise could improve the graph structure a lot.
:+1: That's exactly it AFAICT. It's the same trick as tensordot and isin.
@gjoseph92 that looks amazing. Exactly what I had in mind originally with this PR but didn't have the dask chops to actually implement. Let me know how I can help move this forward.
I'm happy to have a go at implementing this too. I would eventually like to integrate it upstream in xarray, but only once we've got the best dask function for the job working.
One question - what would the numpy equivalent of @gjoseph92's solution be? Is it just to apply `bincount` once over the whole array without calling `blockwise`?
> I'm happy to have a go at implementing this too. I would eventually like to integrate it upstream in xarray, but only once we've got the best dask function for the job working.
@TomNicholas, for someone that's not up with the lingo, do you mean move histogram functionality into xarray?
> One question - what would the numpy equivalent of @gjoseph92's solution be? Is it just to apply `bincount` once over the whole array without calling `blockwise`?
For 1D arrays at least, I think numpy histogram actually loops over blocks and cumulatively sums the `bincount` applied to each block, with a note that this is both faster and more memory efficient for large arrays.
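That block-looped strategy can be sketched like this. This is a rough illustration of the idea, not numpy's actual implementation; `blocked_histogram` and `block_size` are made-up names:

```python
# Rough sketch (not numpy's actual code) of looping over fixed-size blocks
# and accumulating per-block bincounts, keeping peak memory bounded.
import numpy as np

def blocked_histogram(a, bins, block_size=65536):
    a = np.ravel(a)
    counts = np.zeros(len(bins) - 1, dtype=np.intp)
    for start in range(0, a.size, block_size):
        block = a[start:start + block_size]
        # searchsorted maps values to bin indices for this block only
        idx = np.searchsorted(bins, block, side="right") - 1
        # fold the closed right edge into the last bin
        idx[block == bins[-1]] = len(bins) - 2
        # drop out-of-range values, then accumulate this block's counts
        valid = (idx >= 0) & (idx < len(bins) - 1)
        counts += np.bincount(idx[valid], minlength=len(bins) - 1)
    return counts
```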
> do you mean move histogram functionality into xarray
Yes exactly. It's arguably within the scope of xarray (which already has a `.hist` method, it just currently reduces over all dimensions), and xarray already has functions for `groupby` operations and similar types of reshaping.
In the long term the advantages would be:
> I think numpy histogram actually loops over blocks
Oh that's interesting, thanks!
> Yes exactly. It's arguably within the scope of xarray
I think this is a terrific idea. Happy to help out however I can
> I think this is a terrific idea. Happy to help out however I can
Great! Well anyone can do the PR(s) to xarray - I would be happy to provide guidance & code reviews over there if you wanted to try it?
The steps as I see it are:
Yup I think it's a good idea to get the dask.blockwise implementation working well here as a first step :thumbsup:
So looking at the code in `xhistogram.core`, I have some questions:
To implement @gjoseph92's solution (which looks great) we would need to write a new `bincount` function that matches the signature described in his notebook. The `_bincount` function we currently have has a different signature (though I'm not sure exactly what that signature is, because it's an internal function and has no docstring), so it would need to be significantly changed. `_bincount` then calls other functions like `_bincount_2d_vectorized`, which I think are specific to the `reshape`-based dask algorithm we are currently using. If we are changing all these functions, are we essentially just rewriting the whole of `xhistogram.core`? I'm not saying we shouldn't do that, I'm just trying to better understand what's required. Can any of this be copied over from other implementations (e.g. `dask.array.histogramdd`)?
> If we are changing all these functions are we essentially just rewriting the whole of `xhistogram.core`?
Yes, I think we sort of are. xhistogram's current approach is dask-agnostic (ish, besides `_dispatch_bincount`): it uses a bunch of NumPy functions and lets dask parallelize them implicitly. With the `blockwise` approach, we'd have a shared "inner kernel" (the core bincount operation), then slightly different code paths for mapping that kernel over a NumPy array versus explicitly parallelizing it over a dask array.

The difference is very small: for a NumPy array, apply the core bincount function directly to the array; for a dask array, use `blockwise` to apply it. (You can think of the NumPy case as just a 1-chunk dask array.) Then in both cases, finish with `bincounts.sum(axis=axis)`.
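A minimal sketch of that structure might look like this. All names here are illustrative (this is not the proposed xhistogram code), and it only handles the simplest 1D, unweighted case:

```python
# Sketch: one NumPy "inner kernel" shared by the NumPy and dask code paths,
# with a final sum over the block axis. Illustrative names throughout.
import numpy as np
import dask.array as da

def _block_bincount(a, bins):
    # The shared inner kernel: bincount one block, returning a length-1
    # leading axis so per-chunk results are summed, not concatenated.
    idx = np.digitize(np.ravel(a), bins) - 1
    valid = (idx >= 0) & (idx < len(bins) - 1)
    return np.bincount(idx[valid], minlength=len(bins) - 1)[None, :]

def histogram(a, bins):
    if isinstance(a, da.Array):
        # dask path: apply the kernel per chunk via blockwise
        flat = a.flatten()
        counts = da.blockwise(
            _block_bincount, "ib", flat, "i", bins=bins,
            dtype=np.intp, adjust_chunks={"i": 1},
            new_axes={"b": len(bins) - 1},
        )
    else:
        # NumPy path: same kernel, applied directly (a "1-chunk dask array")
        counts = _block_bincount(a, bins)
    return counts.sum(axis=0)
```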
So in a way, the bincount "inner kernel" is all of the current xhistogram core, just simplified to not have to worry about dask arrays, and to return a slightly different shape. I think this change might actually make xhistogram a bit easier to read/reason about. Additionally, I think we should consider writing this bincount part in numba or Cython eventually, not just because it's faster, but because being able to just write a `for` loop would be way easier to read than the current `reshape` and `ravel_multi_index` logic, which is hard to wrap your head around.
Also, it might be worth passing NumPy arrays of sufficient size (or any size) into `dask.array.from_array`, putting them through the dask logic, and computing the result with `scheduler="threads"`. This would be a quick and easy way to get the loop-over-blocks behavior @dougiesquire was describing, (maybe) a lower peak memory footprint, plus some parallelism.
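That trick can be sketched with dask's own 1D histogram (the array and chunk sizes here are arbitrary choices for illustration):

```python
# Sketch: route a plain NumPy array through dask to get blocked,
# multi-threaded execution. The chunk shape is an arbitrary tuning knob.
import numpy as np
import dask.array as da

arr = np.random.random((1000, 1000))
darr = da.from_array(arr, chunks=(250, 250))
hist, edges = da.histogram(darr, bins=10, range=(0, 1))
result = hist.compute(scheduler="threads")  # blocked + parallel bincounts
```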
That's very helpful, thanks @gjoseph92.
> I think this change might actually make xhistogram a bit easier to read/reason about.
I agree.
> Additionally, I think we should consider writing this bincount part in numba or Cython eventually
That would be cool, but it would also affect whether or not we could integrate the functionality upstream in xarray, so I suggest that we get a version working with just dask initially so we can see if it's fast "enough"?
> computing the result with `scheduler="threads"`
That might even be like a "best of both" solution, ish?
Thanks @gjoseph92 for that very helpful description!
@TomNicholas I'm going on leave next week and I'm hoping to work on this a little during that time. My plan is to spend some time understanding/implementing the `blockwise` approach and then to open another PR once things are somewhat functional. But I want to make sure I don't tread on your toes or duplicate your efforts. What's the best way to align our efforts?
> I'm hoping to work on this a little during that
Great!
> But I want to make sure I don't tread on your toes or duplicate your efforts.
@dougiesquire I wouldn't worry about that too much. I think we will both want to understand the resultant code, so time spent understanding it is always well-spent. For development we should just both post whatever we try out - as a gist, a PR, or could even start a new dev branch on this repo.
Merging #49 (d1df004) into master (daf41f0) will decrease coverage by 11.65%. The diff coverage is 93.75%.
```diff
@@            Coverage Diff             @@
##           master      #49       +/-  ##
===========================================
- Coverage   95.79%   84.14%   -11.66%
===========================================
  Files           3        2        -1
  Lines         238      246        +8
  Branches       69       77        +8
===========================================
- Hits          228      207       -21
- Misses          7       34       +27
- Partials        3        5        +2
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| xhistogram/core.py | 81.81% <93.75%> (-18.19%) | :arrow_down: |
Continue to review full report at Codecov.
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update daf41f0...d1df004. Read the comment docs.
In my chat today with @TomNicholas, we realized this PR was already quite close to working. I just needed to understand how to use blockwise (rather than map_blocks). Tests are all green. I have not examined the dask graphs yet, but I think they will be improved.
Regardless of whether this works, Tom raised the point that the current implementation is quite opaque. There are many nested layers of internal functions that make the logic hard to follow. So it might not be a bad idea to just refactor everything anyway.
@TomNicholas - today on our call you mentioned that this PR doesn't actually work with Dask. 🤣 That was absolutely true of 9f20e95. However, I think I have updated things now so that they do actually work.
Here is an example.
```python
import numpy as np
import dask.array as dsa
import xhistogram.core as xhc

csize = 80
nz = 20
ntime = 2
fac = 4
bins = np.arange(10)
data = dsa.ones((nz, fac * csize, fac * csize),
                chunks=(nz, csize, csize))
display(data)

hist = xhc.histogram(data, bins=[bins], axis=[0, 1])
hist_eager = xhc.histogram(data.compute(), bins=[bins], axis=[0, 1])
display(hist)
display(hist.visualize())
np.testing.assert_allclose(hist.compute(), hist_eager)
```
Regardless, the fact that 9f20e95 did not raise any test failures reveals a major gap in our test coverage. You mentioned you were working on expanding test coverage of dask cases. It would be great to do that in a separate PR, merge it, and then rebase this on top of it to see whether the implementation here passes the new tests.
> now they do actually work
@rabernat nice! Thanks for fixing that. Also that task graph looks beautiful! Is it still as neat for a multi-variable histogram?
I was really struggling to work out what was the problem with your PR, and I still don't really understand how you fixed it - IMO that means the code should still be refactored to be clearer, which I will have a go at.
> You mentioned you were working on expanding test coverage of dask cases
Yes it's pretty surprising that broken code still passed the tests... but I've added dask tests over in #57 now.
> Is it still as neat for a multi-variable histogram?
For anyone wondering, yes it is (for same but with two data variables):
> For anyone wondering, yes it is (for same but with two data variables)
Amazing! Is this also including weights?
With weights you get this:
```python
import numpy as np
import dask.array as dsa
import xhistogram.core as xhc

csize = 80
nz = 20
ntime = 2
fac = 4
bins1 = np.arange(10)
bins2 = np.arange(8)
data1 = dsa.random.random((nz, fac * csize, fac * csize),
                          chunks=(nz, csize, csize))
data2 = dsa.random.random((nz, fac * csize, fac * csize),
                          chunks=(nz, csize, csize))
weights = dsa.random.random((nz, fac * csize, fac * csize),
                            chunks=(nz, csize, csize))
display(data1)

hist = xhc.histogram(data1, data2, bins=[bins1, bins2], weights=weights,
                     axis=[0, 1])
hist_eager = xhc.histogram(data1.compute(), data2.compute(),
                           bins=[bins1, bins2], weights=weights.compute(),
                           axis=[0, 1])
display(hist)
display(hist.visualize())
np.testing.assert_allclose(hist.compute(), hist_eager)
```
Just beautiful! I suspect this will help tremendously over at https://github.com/ocean-transport/coiled_collaboration/issues/9.
Do you think this is ready to test under real-world conditions?
> Do you think this is ready to test under real-world conditions?
You could try it, but probably not, because chunking has no explicit tests, apart from the ones I added in #57, which do not all yet pass :confused:
@jbusecke scratch that - it looks like it was my tests that were the problem after all - should be green to try this out now!
Whoa! This works really well!
Dask stream from a synthetic example, which would fail often (https://github.com/ocean-transport/coiled_collaboration/issues/9). This example uses two variables and weights.
I ran this on the google deployment with the 'hacky pip install' described here. Will try it out for realistic CMIP6 data next.
I can't thank you enough for implementing this! I know who I will thank at the next CDSLab meeting!
Ok so after some tuning of the chunksize I got some pretty darn satisfying results for the real world example:
This crunched through a single member in under 2 minutes with 20 workers. Absolutely no spilling!
So now that we've successfully chunked over all dims, do we still need any of the "block size" logic? The main benefit of being able to specify `block_size` is a lowered memory footprint, right? But if we can now chunk over any dim, then the dask chunks can always be made smaller for a given problem, and the user can always lower their memory footprint by using smaller chunks instead.

Or do we still want it for the same reasons that numpy.histogram has it?

(There is also no block size used anywhere in `np.histogramdd`.)
> Or do we still want it for the same reasons that numpy.histogram has it?
My understanding was that this is a low-level optimization, aimed at making the core bincount routine faster. Dask is responsible for scaling out, but having a faster low-level routine is always desirable as well. That said, I have not benchmarked anything to verify it actually makes a difference.
Tom, the only thing that failed after I accepted your suggestion was the linting. So please go ahead and merge this if you think it's ready! Lots of stuff is blocking on this PR.
🚀
As discussed in https://github.com/ocean-transport/coiled_collaboration/issues/8, I feel like there must be a more efficient implementation of xhistogram's core logic in terms of handling dask arrays. A central problem with our implementation is the use of reshape on dask arrays. This causes inefficient chunk manipulations (see https://github.com/dask/dask/issues/5544).
We should be able to just apply the bin_count function to every block and then sum over all blocks to get the total bin count. If we go this route, we will no longer use dask.array.histogram at all. This should result in much simpler dask graphs.
I have not actually figured out how to solve the problem, but this PR gets started in that direction. The challenge is in determining the arguments to pass to `map_blocks`.