Open mrocklin opened 4 months ago
For Dask Array people, here is a tiny test:
import dask.array as da
import numpy as np

def test_out_of_order_slicing():
    x = da.random.random((10, 20), chunks=(10, 10))
    idx = np.random.randint(0, x.shape[1], x.shape[1])
    y = x[:, idx]
    assert y.npartitions < x.npartitions * 2
My sense is that we want to embrace an n^2 solution here when that reduces the task count. The benefit is that n here is really only in one dimension, so the blowup will likely not be as large as we might otherwise expect. In the example at the top of this post it would blow up tasks by a factor of 10, but not by a factor of 400, which would be a considerable win. In some future we could also invoke the p2p machinery to do this better (but I'd be inclined to solve this with tasks short-term myself).
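To make the task-count concern a little more concrete, here is a rough sketch in plain Python. It is not dask's implementation: it assumes a naive strategy that emits one slice per run of adjacent index entries that fall in the same input chunk, and compares that (for a single output chunk) with first bucketing the indices by the input chunk they read from. The helper names `naive_pieces` and `bucketed_pieces` are hypothetical.

```python
import random


def naive_pieces(idx, chunk_size):
    """Count runs of adjacent index entries that fall in the same input chunk."""
    pieces = 0
    previous_chunk = None
    for i in idx:
        chunk = i // chunk_size
        if chunk != previous_chunk:
            pieces += 1
            previous_chunk = chunk
    return pieces


def bucketed_pieces(idx, chunk_size):
    """Count pieces if indices are first grouped by the input chunk they read from."""
    return len({i // chunk_size for i in idx})


random.seed(0)
n, chunk_size = 20, 10  # same layout as the test above: 20 columns, chunks of 10
idx = [random.randrange(n) for _ in range(n)]

print(naive_pieces(idx, chunk_size), bucketed_pieces(idx, chunk_size))
```

With bucketing, the number of pieces per output chunk is bounded by the number of input chunks along the indexed axis, which is the "n in one dimension" bound discussed above.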
My sense is that we want to embrace an n^2 solution here when that reduces the task count. The benefit here is that here n is really only in one dimension, and so will likely not be as large a blowup as we might otherwise expect.
Reducing the number of tasks is easy if we just merge adjacent slices. However, in that scenario neither the number of tasks nor the network is the limiting factor; memory is.
If we truly care/need this kind of random access take, the best move would probably be to have a p2p version of this as well.
we think that this has a lot of impact throughout xarray, but are not sure
Yes. groupby-map is quite popular (sadly).
Is this also the code path taken in sum/max/etc..? (assuming that flox is not around).
Yes. This is a high-impact problem to solve.
Ideally xarray would call something like map_blocks, rather than iterate through each group.
This assumes that you can shove a whole group into one block, which is not a good assumption.
I think the core problem is that we need Xarray to give dask/cubed/whatever all the information (indices that sort by group) in one call rather than a for-loop over groups.
Which suggests that the missing primitive is a shuffle_by_key or shuffle(Array, indices: Sequence[Sequence[int], ...], axis: int) -> Sequence(Array, ...) that yields one dask array per inner sequence of indices (note: not necessarily one block).
That way a user is free to rechunk to a single block if it makes sense and dask/whatever is free to optimize the indexing in one go. This shuffle would also be useful for flox' method="cohorts", which loops over groups of groups (https://github.com/xarray-contrib/flox/blob/22257d2ba8de666a94fc8954451905e9ebd361bb/flox/core.py#L1704-L1719).
For an example of the indices, see arr.groupby("time.dayofyear")._group_indices:
[[0, 365, 730, 1096, 1461, 1826, 2191, 2557, 2922, 3287, 3652],
[1, 366, 731, 1097, 1462, 1827, 2192, 2558, 2923, 3288, 3653],
[2, 367, 732, 1098, 1463, 1828, 2193, 2559, 2924, 3289, 3654],
...
]
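As an illustration of what such indices look like and what a group-wise shuffle would return, here is a minimal sketch on plain Python lists. It is not xarray's or dask's code; `group_indices` and `shuffle_by_groups` are hypothetical names, and lists stand in for arrays.

```python
from collections import defaultdict


def group_indices(keys):
    """Return one list of integer positions per unique key, in first-seen order."""
    positions = defaultdict(list)
    for position, key in enumerate(keys):
        positions[key].append(position)
    return list(positions.values())


def shuffle_by_groups(values, indices):
    """Return one reordered list per inner sequence of indices (one "array" per group)."""
    return [[values[i] for i in inner] for inner in indices]


# Toy "day of year" labels: a 3-day year repeated 3 times.
keys = [1, 2, 3, 1, 2, 3, 1, 2, 3]
values = [10, 20, 30, 11, 21, 31, 12, 22, 32]

indices = group_indices(keys)
groups = shuffle_by_groups(values, indices)
```

Passing all inner sequences in one call is the point of the proposal: the array library sees the full reordering at once instead of one indexing operation per group.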
In some future we could also invoke the p2p machinery to do this better
It feels like you've already solved a much harder problem at some level... it "just" needs to be hooked up to slicing.
This would be a special-case for dask-array. Is this ok?
This is fine. The iteration is hidden behind an abstraction already.
is anyone around to do this within xarray if we're also improving slicing on the dask array side?
Yes I can help, I've been refactoring all this code recently.
PS: Arkouda does something like this for groupby problems --- IIRC they always sort by key.
Alternatively we need np.take_along_axis(array, axis, chunks=...). This would help with https://github.com/pydata/xarray/issues/4555 and binary ops with groupby (this is usually indexing with repeated ints).
missing primitive
alternatively we need np.take_along_axis(array, axis, chunks=...)
Once array libraries define such a primitive, we have a natural place for xarray to interface with it: the ChunkManagerEntryPoint (see https://docs.xarray.dev/en/stable/internals/chunked-arrays.html).
Here's an array-only implementation:
import numpy as np
from xarray.core.groupby import _inverse_permutation_indices

def center(x: np.ndarray):
    return x - x.mean(axis=-1, keepdims=True)

def identity(x):
    return x
func = identity
dask_array = arr["arr"].data
numpy_array = dask_array.compute()
indices = arr.groupby("time.dayofyear")._group_indices
group_sizes = [len(i) for i in indices]
# We could have Xarray do this shuffling when setting up
# the groupby problem only for chunked array types
shuffled = np.take(
    dask_array,
    indices=np.concatenate(indices),
    axis=-1,
    # TODO: this is the new argument
    # chunks_hint = group_sizes,
)
# reset _group_indices for the shuffled array
summed = np.cumsum(group_sizes).tolist()
slices = [slice(None, summed[0])] + [
    slice(summed[i], summed[i + 1]) for i in range(len(summed) - 1)
]
# This is the apply-combine step
applied = np.concatenate(
    [func(shuffled[..., slicer]) for slicer in slices],
    axis=-1,
)
# Now shuffle back, in case this was a groupby-transform, not a groupby-reduce
reordering = _inverse_permutation_indices(indices, N=dask_array.shape[-1])
result = np.take(
    applied,
    indices=reordering,
    axis=-1,
    # TODO: this is the new argument
    # chunks_hint=dask_array.chunks,
)
np.testing.assert_equal(numpy_array, result)
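The "shuffle back" step relies on an inverse permutation. The sketch below shows the idea on plain Python lists; it is a guess at the concept rather than a copy of xarray's private `_inverse_permutation_indices`, the name `inverse_permutation` is hypothetical, and it assumes every position appears exactly once across the groups.

```python
def inverse_permutation(group_indices, n):
    """Map each original position to its position in the concatenated (shuffled) order."""
    inverse = [0] * n
    shuffled_position = 0
    for inner in group_indices:
        for original_position in inner:
            inverse[original_position] = shuffled_position
            shuffled_position += 1
    return inverse


indices = [[0, 3], [1, 4], [2, 5]]
values = ["a", "b", "c", "d", "e", "f"]

# Forward shuffle: concatenate the per-group indices and take in that order.
forward = [i for inner in indices for i in inner]
shuffled = [values[i] for i in forward]

# Backward shuffle: taking with the inverse permutation restores the original order.
reordering = inverse_permutation(indices, n=len(values))
restored = [shuffled[i] for i in reordering]
```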
The chunks_hint should tell dask to always start a new chunk at those boundaries. The actual chunk sizes for the output array can be treated as an implementation detail. That way you have at most one group per block in the output.
What's not clear to me is whether just optimizing the take with the added chunks_hint info is good enough for dask? (Note: I made a mistake in mentioning take_along_axis earlier.)
Or do we also need an API for the apply-combine step? If so, a natural API for groupby-apply on the shuffled array would be to mimic ufunc.reduceat.
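For readers unfamiliar with ufunc.reduceat, here is a small NumPy sketch of how it reduces contiguous segments of an already-shuffled array. The toy data and group sizes are made up for illustration; nothing here is a proposed dask or xarray API.

```python
import numpy as np

group_sizes = [3, 2, 2]
# Values already shuffled so that members of each group are contiguous on the last axis.
shuffled = np.arange(14.0).reshape(2, 7)

# reduceat takes the start offset of each segment, not the end.
starts = np.concatenate([[0], np.cumsum(group_sizes)[:-1]])
sums = np.add.reduceat(shuffled, starts, axis=-1)
means = sums / np.asarray(group_sizes)
```

One caveat: reduceat does not return the ufunc identity for an empty segment, so empty groups would need separate handling.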
I think what you would want is basically the shuffle (i.e. take) and then call map_blocks with the transformer / reducer. This is what DataFrames in Dask do as well.
I am a bit hesitant about chunks_hint = group_sizes. I'd rather have control over this in dask by default; creating a block per group might blow up the graph if we are unlucky with a high-cardinality grouper. You can split the groups in map_blocks again if needed, hiding the cardinality of the grouper from the system.
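One way to picture "control in dask" is a packing step that keeps every group whole but merges several groups into one block up to a target size. The following is only a sketch under that assumption: `pack_groups` and `target_chunk_size` are hypothetical names and the greedy rule is an illustrative choice, not dask's algorithm.

```python
def pack_groups(group_sizes, target_chunk_size):
    """Greedily merge consecutive groups into blocks without splitting any group."""
    blocks = []            # output chunk sizes along the shuffled axis
    groups_per_block = []  # how many whole groups each block holds
    current_size = 0
    current_groups = 0
    for size in group_sizes:
        # Close the current block if adding this group would overshoot the target.
        if current_groups and current_size + size > target_chunk_size:
            blocks.append(current_size)
            groups_per_block.append(current_groups)
            current_size = 0
            current_groups = 0
        current_size += size
        current_groups += 1
    if current_groups:
        blocks.append(current_size)
        groups_per_block.append(current_groups)
    return blocks, groups_per_block


# 366 hypothetical groups of 11 elements each, packed towards blocks of ~100 elements.
blocks, groups_per_block = pack_groups([11] * 366, target_chunk_size=100)
```

Here the graph has one block per nine groups instead of one block per group, and a group larger than the target simply becomes its own block.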
then call map_blocks with the transformer / reducer.
OK now I am finally understanding what you mean here. The UDF is expecting to receive only one group, we can still satisfy that interface by inserting a for-loop-over-groups in the map-blocks call. :+1:
chunks_hint = group_sizes
For one, without an extra arg, dask doesn't know where the group boundaries are. Another option is to pass:
# Insert a chunk boundary at one or more of these values...
chunks_hint = np.cumsum(group_sizes)
creating a block per group might blow up the graph if we are unlucky with a high cardinality grouper.
A low cardinality grouper is also important to consider, particularly with small chunk sizes. For example:
arr = xr.DataArray(
    # time chunks of size-1 are common!
    da.random.random((100, 100, len(time)), chunks=(100, 100, 1)),
    dims=["lat", "lon", "time"],
    coords={"lat": lat, "lon": lon, "time": time},
    name="arr",
).to_dataset()
with xr.set_options(use_flox=False):
    arr.groupby("time.year").mean().compute()
^ Today, this is a tree reduction because we dispatch to dask.array.mean. The shuffle/map_blocks solution would be a regression.
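A simplified picture of why a tree reduction suits a low-cardinality grouper: each chunk produces a small per-key partial result, and partials are combined pairwise without moving the underlying data. This sketch uses plain Python and hypothetical helper names; it is not how xarray or dask implement the reduction.

```python
from collections import defaultdict


def chunk_partial(values, keys):
    """Per-chunk step: accumulate [sum, count] for every key seen in this chunk."""
    partial = defaultdict(lambda: [0.0, 0])
    for value, key in zip(values, keys):
        partial[key][0] += value
        partial[key][1] += 1
    return partial


def combine(left, right):
    """Pairwise combine step of the tree: add sums and counts key by key."""
    merged = defaultdict(lambda: [0.0, 0])
    for partial in (left, right):
        for key, (total, count) in partial.items():
            merged[key][0] += total
            merged[key][1] += count
    return merged


# Four size-1 "time" chunks spanning two years.
chunks = [([1.0], [2000]), ([3.0], [2000]), ([10.0], [2001]), ([20.0], [2001])]
partials = [chunk_partial(values, keys) for values, keys in chunks]

# Reduce pairwise until one partial remains.
while len(partials) > 1:
    paired = [combine(a, b) for a, b in zip(partials[::2], partials[1::2])]
    if len(partials) % 2:
        paired.append(partials[-1])
    partials = paired

means = {key: total / count for key, (total, count) in partials[0].items()}
```

The partials stay small as long as there are few distinct keys, which is why the cardinality of the grouper matters for choosing between this and a shuffle.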
^ Today, this is a tree-reduction reduce because we dispatch to dask.array.mean. The shuffle/map_blocks solution would be a regression.
I definitely agree that the shuffle map_blocks solution shouldn't be the default for reducers. A tree reduction is totally fine if you have a low cardinality grouper and will most likely outperform every shuffle implementation we can come up with. We have to come up with a cardinality threshold above which the shuffle map_blocks approach makes sense for reducers.
For one, without an extra arg, dask doesn't know where the group boundaries are. Another option is to pass:
Yeah we need some marker where the group boundaries are, I totally agree with you here. What I would want to avoid is that the group boundaries necessarily define the chunks and then you would do:
OK now I am finally understanding what you mean here. The UDF is expecting to receive only one group, we can still satisfy that interface by inserting a for-loop-over-groups in the map-blocks call. 👍
For context: in DataFrame land this works as follows: df.groupby(...).transform(...) is translated into the following:
shuffled = df.shuffle(...)
result = shuffled.map_partitions(lambda x: x.groupby(...).transform(...))
We push the actual groupby operation into the map_partitions call. For arrays you would "re-do" the groupby in the map_blocks call to ensure that the UDF gets one group at a time, but we don't have to create a task per group with this approach, rather one per chunk which should keep the graph size smaller and still parallelises very well.
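The per-block "re-do the groupby" step could look roughly like the sketch below, which works on one block represented as plain Python lists. The function `transform_block` is a hypothetical stand-in for whatever would be passed to map_blocks; it assumes the shuffle has already placed every group entirely inside a single block.

```python
from statistics import mean


def transform_block(values, labels, func):
    """Apply func to each group inside one block, writing results back in position order."""
    out = [None] * len(values)
    positions_by_label = {}
    for position, label in enumerate(labels):
        positions_by_label.setdefault(label, []).append(position)
    for positions in positions_by_label.values():
        # The UDF still receives exactly one group at a time.
        group_result = func([values[p] for p in positions])
        for position, result in zip(positions, group_result):
            out[position] = result
    return out


def center(group):
    """Example UDF: subtract the group mean from each member."""
    m = mean(group)
    return [v - m for v in group]


# One block after shuffling: it holds two complete groups, "a" and "b".
block_values = [1.0, 3.0, 10.0, 20.0, 30.0]
block_labels = ["a", "a", "b", "b", "b"]

centered = transform_block(block_values, block_labels, center)
```

The loop over groups lives inside the block function, so the task graph grows with the number of blocks rather than the number of groups.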
After thinking it over during the weekend, I think it'd be nice to have at least two APIs.
1. Array.shuffle(indices: list[list[int]], axis: int) that does guarantee an integer number of groups N >= 1 in the output chunks (analogous to dask.dataframe). shuffle should rechunk the other axes automatically to make this feasible. This would be immediately useful in GroupBy.quantile and GroupBy.median. I believe this is implemented in https://github.com/dask/dask/pull/11262
2. np.take that does not bother about groups, and simply extracts/reorders the elements in the most efficient manner possible. Here I think it is preferable to preserve the chunking along the other axes as much as possible. We might consider a chunks_hint argument that suggests optional chunk sizes (or chunk boundaries) as proposed in https://github.com/pydata/xarray/issues/9220#issuecomment-2256689592. This is more general, and I think Xarray would call take followed by slicing to index out each group for the general GroupBy.reduce.
cc @phofl
That sounds good to me.
So the current shuffle implementation doesn't do any rechunking. Assuming you have a group that is very large, you would expect us to change the length of the chunks across the other dimensions, correct?
re take: I have a PR, https://github.com/dask/dask/pull/11267, that hooks the shuffle approach into take. This means that we preserve the chunk size along the axis to which we apply the take indexer, but it's a lot more efficient than an out-of-order take indexer (this is what I benchmarked here). This is something that would be helpful if we expose this toplevel, I assume?
assuming you have a group that is very large you would expect us to change the length of the chunks across other dimensions correct?
Yes you'd absolutely have to.
This is something that would be helpful if we expose this toplevel I assume?
Isn't it already exposed as take?
I have a pr https://github.com/dask/dask/pull/11267
How does this do with repeated indices? https://github.com/pydata/xarray/issues/2745
Yes you'd absolutely have to.
This isn't yet implemented, I'll add this as a follow up.
Isn't it already exposed as take?
good point, this should end up in the same place (I'll double check that to be sure though)
How does this do with repeated indices? https://github.com/pydata/xarray/issues/2745
It places them in different chunks if you have too many of them. It looks at the average input chunk size along the dimension you are indexing and then uses this as the target chunk size for the output. It would preserve the 100, 100 chunks in the example in the other issue.
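Going only by the description above, the output chunking along the indexed axis could be pictured like this. The helper `output_chunks`, its rounding rule, and the toy numbers are assumptions for illustration and are not taken from dask's code.

```python
import math


def output_chunks(n_indices, input_chunks):
    """Split n_indices output elements into chunks no larger than the mean input chunk."""
    target = max(1, math.ceil(sum(input_chunks) / len(input_chunks)))
    full, remainder = divmod(n_indices, target)
    return [target] * full + ([remainder] if remainder else [])


# An axis of length 20 stored as two chunks of 10, indexed with 55 heavily repeated indices.
chunks = output_chunks(55, input_chunks=[10, 10])
```

Repeated indices therefore add more output chunks of roughly the input size instead of inflating a single chunk.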
It places them in different chunks if you have too many of them.
Perfect! We'll need this for groupby-binary-ops (ds.groupby(...) - ds.groupby(...).mean()), which has this total hack at the moment:
https://github.com/pydata/xarray/blob/1ac19c4231e3e649392add503ae5a13d8e968aef/xarray/core/groupby.py#L655-L668
Some updates after a chat with @phofl
First off, amazing progress on reducing memory usage in dask 2024.08.0!
TODO:
GroupBy.shuffle() to allow users to manually shuffle the data so that all members of a group are guaranteed to be in one chunk. We may have more than one group in a single chunk.
GroupBy.median and GroupBy.quantile to make use of it.
GroupBy.transform that calls shuffle and then calls xarray.map_blocks with a wrapper function that applies another GroupBy with the UDF. This is similar to what dask.dataframe.GroupBy.map does.
GroupBy.map and GroupBy.reduce use the new & improved dask.array.take algorithm.
Array.vindex which does its own thing.
Array.vindex through take.
take algo, this should be a massive improvement.
What is your issue?
I think that this is a longstanding problem. Sorry if I missed an existing github issue.
I was looking at a Dask-array-backed Xarray workload with @phofl and we were both concerned about some performance we were seeing with groupby-aggregations called with out-of-order indices. Here is a minimal example:
Performance here is bad in a few ways:
We think that what is happening here looks like this:
For steps (1) and (3) above performance is bad in a way that we can reduce to a dask array performance issue. Here is a small reproducer for that:
We think that we can make this better on our end, and can take that away as homework.
However for step (2) we think that this probably has to be a change in xarray. Ideally xarray would call something like map_blocks, rather than iterate through each group. This would be a special-case for dask-array. Is this ok?
Also, we think that this has a lot of impact throughout xarray, but are not sure. Is this also the code path taken in sum/max/etc..? (assuming that flox is not around). Mostly we're curious how much we all should prioritize this.
Asks
Some questions: