pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Groupby-map is slow with out of order indices #9220

Open mrocklin opened 4 months ago

mrocklin commented 4 months ago

What is your issue?

I think that this is a longstanding problem. Sorry if I missed an existing github issue.

I was looking at a Dask-array-backed Xarray workload with @phofl, and we were both concerned about the performance we were seeing with groupby-aggregations called with out-of-order indices. Here is a minimal example:

import xarray as xr
import dask.array as da
import numpy as np
import pandas as pd

lat = np.linspace(-89.5, 89.5, 100)
lon = np.linspace(-179.375, 179.375, 100)
time = pd.date_range(
    start="1990-01-01", end="2000-12-31", freq="D",
)

arr = (
    xr.DataArray(
        da.random.random((100, 100, len(time)), chunks=(100, 100, 365)),
        dims=["lat", "lon", "time"],
        coords={"lat": lat, "lon": lon, "time": time},
        name="arr"
    )
    .to_dataset()
)

arr["arr"].data
[Screenshot: repr of arr["arr"].data]
def f(x):
    return x

result = arr.groupby("time.dayofyear").map(f)
result["arr"].data
[Screenshot: repr of result["arr"].data]

Performance here is bad in a few ways:

We think that what is happening here looks like this:

  1. Slice the underlying array with a very out-of-order index array to arrange groups to be close to each other
  2. Iterate through each group and apply the function
  3. Slice the underlying array with the inverse permutation to put everything back in the right place
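The three steps above can be sketched in plain NumPy for an in-memory array. This is a minimal illustration, assuming integer group keys aligned with the last axis; `groupby_map_sorted` is a hypothetical helper, not an xarray function.

```python
import numpy as np

def groupby_map_sorted(values, keys, func):
    """Hypothetical helper: apply ``func`` to each group along the last axis."""
    keys = np.asarray(keys)
    # Step 1: a stable argsort makes each group contiguous (order within a group is kept).
    order = np.argsort(keys, kind="stable")
    shuffled = values[..., order]
    # Group boundaries are the positions where the sorted key changes.
    boundaries = np.flatnonzero(np.diff(keys[order])) + 1
    # Step 2: apply the function group by group, then stitch the pieces together.
    pieces = np.split(shuffled, boundaries, axis=-1)
    applied = np.concatenate([func(piece) for piece in pieces], axis=-1)
    # Step 3: the inverse permutation moves every element back to its original position.
    inverse = np.empty_like(order)
    inverse[order] = np.arange(order.size)
    return applied[..., inverse]

def center(x):
    return x - x.mean(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
values = rng.random((2, 12))
keys = np.arange(12) % 3  # interleaved groups, like dayofyear repeating every year
centered = groupby_map_sorted(values, keys, center)
```

With a chunked array, steps 1 and 3 are the out-of-order takes, and step 2 is the Python loop over groups.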

For steps (1) and (3) above performance is bad in a way that we can reduce to a dask array performance issue. Here is a small reproducer for that:

x = da.random.random((100, 100, 10000))
x
[Screenshot: repr of x]
idx = np.random.randint(0, x.shape[2], x.shape[2])
x[:, :, idx]
[Screenshot: repr of x[:, :, idx]]

We think that we can make this better on our end, and can take that away as homework.

However for step (2) we think that this probably has to be a change in xarray. Ideally xarray would call something like map_blocks, rather than iterate through each group. This would be a special-case for dask-array. Is this ok?

Also, we think that this has a lot of impact throughout xarray, but are not sure. Is this also the code path taken in sum/max/etc..? (assuming that flox is not around). Mostly we're curious how much we all should prioritize this.

Asks

Some questions:

mrocklin commented 4 months ago

For Dask Array people, here is a tiny test:

import dask.array as da
import numpy as np

def test_out_of_order_slicing():
    x = da.random.random((10, 20), chunks=(10, 10))
    idx = np.random.randint(0, x.shape[1], x.shape[1])

    y = x[:, idx]
    assert y.npartitions < x.npartitions * 2

My sense is that we want to embrace an n^2 solution here when that reduces the task count. The benefit is that n is really only in one dimension, so the blowup will likely not be as large as we might otherwise expect. In the example at the top of this post it would blow up tasks by a factor of 10, but not by a factor of 400, which would be a considerable win. In the future we could also invoke the p2p machinery to do this better (but I'd be inclined to solve this with tasks short-term myself).
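To make the trade-off concrete, the sketch below counts tasks under two simplified models. The first assumes a naive take emits one task per run of consecutive indices that fall in the same input chunk; the second assumes one task per (input chunk, output chunk) pair that exchanges data, i.e. the "n^2" approach. Both models and the helper names are assumptions for illustration, not a description of dask's actual slicing code.

```python
import numpy as np

def chunk_ids(positions, chunk_size):
    # Input chunk of each position, assuming uniform chunk sizes.
    return np.asarray(positions) // chunk_size

def tasks_naive(idx, chunk_size):
    """Model: one task per run of consecutive indices hitting the same input chunk."""
    ids = chunk_ids(idx, chunk_size)
    if ids.size == 0:
        return 0
    return int(1 + np.count_nonzero(np.diff(ids)))

def tasks_pairwise(idx, in_chunk, out_chunk):
    """Model: one task per (input chunk, output chunk) pair with at least one element."""
    src = chunk_ids(idx, in_chunk)
    dst = np.arange(len(idx)) // out_chunk
    return len(set(zip(src.tolist(), dst.tolist())))

rng = np.random.default_rng(0)
n, c = 10_000, 1_000          # 10 chunks of 1000 along the indexed axis
idx = rng.integers(0, n, n)   # random, out-of-order indexer
naive = tasks_naive(idx, c)          # on the order of n for a random indexer
pairwise = tasks_pairwise(idx, c, c) # bounded by (n // c) ** 2 = 100
```

Under these models the pairwise count is bounded by the square of the number of chunks along one axis, while the naive count grows with the length of the indexer.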

fjetter commented 4 months ago

My sense is that we want to embrace an n^2 solution here when that reduces the task count. The benefit here is that here n is really only in one dimension, and so will likely not be as large a blowup as we might otherwise expect.

Reducing the number of tasks is easy if we just merge adjacent slices. However, in that scenario the limiting factor is neither the task count nor the network, but memory.

If we truly care/need this kind of random access take, the best move would probably be to have a p2p version of this as well.

dcherian commented 4 months ago

we think that this has a lot of impact throughout xarray, but are not sure

Yes. groupby-map is quite popular (sadly).

Is this also the code path taken in sum/max/etc..? (assuming that flox is not around).

Yes. This is a high-impact problem to solve.

Ideally xarray would call something like map_blocks, rather than iterate through each group.

This assumes that you can shove a whole group into one block, which is not a good assumption.

I think the core problem is that we need Xarray to give dask/cubed/whatever all the information (indices that sort by group) in one call rather than a for-loop over groups.

Which suggests that the missing primitive is a shuffle_by_key or shuffle(array: Array, indices: Sequence[Sequence[int]], axis: int) -> Sequence[Array] that yields one dask array per inner sequence of indices (note: not necessarily one block).
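The intended semantics are easy to state eagerly with NumPy. The sketch below is a hypothetical in-memory reference for `shuffle(array, indices, axis)`, not an existing dask or xarray function; a chunked implementation would return lazy arrays and could plan the data movement for all groups at once, since every group's indices arrive in a single call.

```python
from collections.abc import Sequence

import numpy as np

def shuffle(array: np.ndarray, indices: Sequence[Sequence[int]], axis: int) -> list[np.ndarray]:
    """Hypothetical eager reference: one array per inner sequence of ``indices``."""
    flat = np.concatenate([np.asarray(i, dtype=np.intp) for i in indices])
    # A single reordering for all groups, instead of one take per group.
    gathered = np.take(array, flat, axis=axis)
    # Split the gathered array back into one piece per group.
    offsets = np.cumsum([len(i) for i in indices])[:-1]
    return np.split(gathered, offsets, axis=axis)

x = np.arange(12).reshape(2, 6)
groups = shuffle(x, [[0, 3], [1, 4], [2, 5]], axis=1)
```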

That way a user is free to rechunk to a single block if it makes sense, and dask/whatever is free to optimize the indexing in one go. This shuffle would also be useful for flox's method="cohorts", which loops over groups-of-groups (https://github.com/xarray-contrib/flox/blob/22257d2ba8de666a94fc8954451905e9ebd361bb/flox/core.py#L1704-L1719).

For an example of the indices, see:

arr.groupby("time.dayofyear")._group_indices
[[0, 365, 730, 1096, 1461, 1826, 2191, 2557, 2922, 3287, 3652],
 [1, 366, 731, 1097, 1462, 1827, 2192, 2558, 2923, 3288, 3653],
 [2, 367, 732, 1098, 1463, 1828, 2193, 2559, 2924, 3289, 3654],
...
]
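The same index lists can be reproduced outside xarray. This sketch uses pandas' `GroupBy.indices` (a dict mapping each group key to the positions of its members); since `_group_indices` is a private xarray attribute, treat this only as an approximation of what it holds.

```python
import pandas as pd

time = pd.date_range(start="1990-01-01", end="2000-12-31", freq="D")
keys = pd.Series(time.dayofyear)

# dict: day-of-year -> array of integer positions along "time"
positions = keys.groupby(keys).indices
group_indices = [positions[k].tolist() for k in sorted(positions)]

print(group_indices[0])  # same as the first row of _group_indices above
```

Each inner list is strided by roughly one year, which is why gathering a single group touches one element from many different time chunks.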

some future we could also invoke the p2p machinery to do this better

It feels like you've already solved a much harder problem at some level... it "just" needs to be hooked up to slicing.

This would be a special-case for dask-array. Is this ok?

This is fine. The iteration is hidden behind an abstraction already.

is anyone around to do this within xarray if we're also improving slicing on the dask array side?

Yes I can help, I've been refactoring all this code recently.

PS: Arkouda does something like this for groupby problems --- IIRC they always sort by key.

dcherian commented 4 months ago

alternatively we need np.take_along_axis(array, axis, chunks=...) . This would help with https://github.com/pydata/xarray/issues/4555 and binary ops with groupby (this is usually indexing with repeated ints).
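To see why groupby binary ops reduce to indexing with repeated integers: a groupby reduction leaves one value per group, and combining it with the original data means expanding those values back to every element with a take whose indexer repeats each group code once per member. A NumPy sketch, assuming groups are labelled by integer codes 0..G-1 (the `codes` array is made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.random((3, 8))                   # last axis is the grouped dimension
codes = np.array([0, 1, 2, 3, 0, 1, 2, 3])  # integer group code of every element
n_groups = int(codes.max()) + 1

# Groupby-reduce: one mean per group along the last axis -> shape (3, n_groups).
means = np.stack(
    [data[..., codes == g].mean(axis=-1) for g in range(n_groups)], axis=-1
)

# Groupby binary op: expand the group means back to every element.
# ``codes`` is an indexer with repeated ints, one entry per original element.
expanded = np.take(means, codes, axis=-1)   # shape (3, 8)
anomaly = data - expanded
```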

TomNicholas commented 4 months ago

missing primitive

alternatively we need np.take_along_axis(array, axis, chunks=...)

Once array libraries define such a primitive, we now have a natural place for xarray to inteface with it - the ChunkManagerEntryPoint (see https://docs.xarray.dev/en/stable/internals/chunked-arrays.html).

dcherian commented 4 months ago

Here's an array-only implementation

from xarray.core.groupby import _inverse_permutation_indices

def center(x: np.ndarray):
    return x - x.mean(axis=-1, keepdims=True)

def identity(x):
    return x

func = identity

dask_array = arr["arr"].data
numpy_array = dask_array.compute()
indices = arr.groupby("time.dayofyear")._group_indices
group_sizes = [len(i) for i in indices]

# We could have Xarray do this shuffling when setting up 
# the groupby problem only for chunked array types
shuffled = np.take(
    dask_array,
    indices=np.concatenate(indices),
    axis=-1,
    # TODO: this is the new argument
    # chunks_hint = group_sizes,
)

# reset _group_indices for the shuffled array
summed = np.cumsum(group_sizes).tolist()
slices = [slice(None, summed[0])] + [
    slice(summed[i], summed[i + 1]) for i in range(len(summed) - 1)
]

# This is the apply-combine step
applied = np.concatenate(
    [func(shuffled[..., slicer]) for slicer in slices], 
    axis=-1,
)

# Now shuffle back, in case this was a groupby-transform, not a groupby-reduce
reordering = _inverse_permutation_indices(indices, N=dask_array.shape[-1])
result = np.take(
    applied,
    indices=reordering,
    axis=-1,
    # TODO: this is the new argument
    # chunks_hint=dask_array.chunks,
)
np.testing.assert_equal(numpy_array, result)

The chunks_hint should tell dask to always start a new chunk at those boundaries. The actual chunk sizes for the output array can be treated as an implementation detail. That way you have at most one group per block in the output.
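One way to read the proposed `chunks_hint` (a hypothetical argument, not an existing `np.take` or dask parameter): treat every group boundary as a mandatory chunk boundary, and split a group further only when it exceeds some target size. A sketch of that chunk computation, where `max_chunk` is an assumed parameter:

```python
def chunks_from_group_sizes(group_sizes, max_chunk):
    """Hypothetical helper: chunk sizes along the shuffled axis such that every
    group boundary is also a chunk boundary; large groups are split further."""
    chunks = []
    for size in group_sizes:
        full, rest = divmod(size, max_chunk)
        chunks.extend([max_chunk] * full)
        if rest:
            chunks.append(rest)
    return tuple(chunks)

chunks = chunks_from_group_sizes([11, 11, 3], max_chunk=5)
```

Because no block spans two groups, each output block holds at most one group, as described above. The cost is one or more blocks per group, which is the graph-size concern raised for high-cardinality groupers.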

What's not clear to me is whether just optimizing the take with the added chunks_hint info is good enough for dask (note: I made a mistake in mentioning take_along_axis earlier). Or do we also need an API for the apply-combine step? If so, a natural API for groupby-apply on the shuffled array would be to mimic ufunc.reduceat.
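For reference, NumPy's `ufunc.reduceat` reduces contiguous slices that start at given offsets, which matches the layout of the shuffled array (groups contiguous along the last axis). A small example with made-up group sizes:

```python
import numpy as np

shuffled = np.arange(10.0)       # groups already contiguous, sizes 3, 3, 4
group_sizes = [3, 3, 4]

# Start offset of each group: 0, 3, 6
offsets = np.concatenate([[0], np.cumsum(group_sizes)[:-1]])
sums = np.add.reduceat(shuffled, offsets, axis=-1)
means = sums / np.asarray(group_sizes)
```

One caveat: `reduceat` does not return an identity for empty groups (when an offset is not smaller than the next one it returns the single element at that offset), so zero-size groups would need separate handling.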

phofl commented 4 months ago

I think what you would want is basically the shuffle (i.e. take) and then call map_blocks with the transformer / reducer. This is what DataFrames in Dask do as well.

I am a bit hesitant about

chunks_hint = group_sizes

I'd rather have control over this in dask by default; creating a block per group might blow up the graph if we are unlucky with a high-cardinality grouper. You can split the groups in map_blocks again if needed, hiding the cardinality of the grouper from the system.

dcherian commented 4 months ago

then call map_blocks with the transformer / reducer.

OK now I am finally understanding what you mean here. The UDF is expecting to receive only one group, we can still satisfy that interface by inserting a for-loop-over-groups in the map-blocks call. 👍

chunks_hint = group_sizes

For one, without an extra arg, dask doesn't know where the group boundaries are. Another option is to pass:

# Insert a chunk boundary at one or more of these values...
chunks_hint = np.cumsum(group_sizes)

creating a block per group might blow up the graph if we are unlucky with a high cardinality grouper.

A low-cardinality grouper is also important to consider, particularly with small chunk sizes. For example:

arr = xr.DataArray(
    # time chunks of size-1 are common!
    da.random.random((100, 100, len(time)), chunks=(100, 100, 1)),
    dims=["lat", "lon", "time"],
    coords={"lat": lat, "lon": lon, "time": time},
    name="arr",
).to_dataset()
with xr.set_options(use_flox=False):
    arr.groupby("time.year").mean().compute()

^ Today, this is a tree-reduction reduce because we dispatch to dask.array.mean. The shuffle/map_blocks solution would be a regression.
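To see why a tree reduction suits a low-cardinality grouper: a groupby-mean decomposes into per-chunk partial sums and counts per group, which are tiny and can be merged pairwise. A NumPy sketch of that decomposition, assuming integer group codes and treating each list element as one chunk (a simplified model, not dask's or xarray's actual code path):

```python
import numpy as np

def partial_sums(block, codes, n_groups):
    """Map step: per-group (sum, count) for one chunk along the last axis."""
    sums = np.zeros(block.shape[:-1] + (n_groups,))
    counts = np.zeros(n_groups)
    for g in range(n_groups):
        mask = codes == g
        sums[..., g] = block[..., mask].sum(axis=-1)
        counts[g] = mask.sum()
    return sums, counts

def combine(a, b):
    """Combine step: partial results simply add up."""
    return a[0] + b[0], a[1] + b[1]

def tree_reduce(parts):
    # Pairwise merging: the depth grows like log2(number of chunks).
    while len(parts) > 1:
        merged = [combine(parts[i], parts[i + 1]) for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:
            merged.append(parts[-1])
        parts = merged
    return parts[0]

rng = np.random.default_rng(0)
data = rng.random((4, 12))
codes = np.arange(12) % 2    # two groups, e.g. a low-cardinality grouper
chunk = 1                    # size-1 chunks along the grouped axis
parts = [partial_sums(data[..., i:i + chunk], codes[i:i + chunk], 2)
         for i in range(0, 12, chunk)]
sums, counts = tree_reduce(parts)
means = sums / counts
```

The intermediates have one entry per group, so with few groups they stay small regardless of chunk size. With many groups they grow, which is where a shuffle can start to pay off.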

phofl commented 4 months ago

^ Today, this is a tree-reduction reduce because we dispatch to dask.array.mean. The shuffle/map_blocks solution would be a regression.

I definitely agree that the shuffle/map_blocks solution shouldn't be the default for reducers. A tree reduction is totally fine if you have a low-cardinality grouper and will most likely outperform any shuffle implementation we can come up with. We have to determine a threshold above which the shuffle/map_blocks approach makes sense for reducers.

For one, without an extra arg, dask doesn't know where the group boundaries are. Another option is to pass:

Yeah we need some marker where the group boundaries are, I totally agree with you here. What I would want to avoid is that the group boundaries necessarily define the chunks and then you would do:

OK now I am finally understanding what you mean here. The UDF is expecting to receive only one group, we can still satisfy that interface by inserting a for-loop-over-groups in the map-blocks call. 👍

For context: in DataFrame land this works as follows: df.groupby(...).transform(...) is translated into the following:

shuffled = df.shuffle(...)
result = shuffled.map_partitions(lambda x: x.groupby(...).transform(...))

We push the actual groupby operation into the map_partitions call. For arrays you would "re-do" the groupby in the map_blocks call to ensure that the UDF gets one group at a time, but we don't have to create a task per group with this approach, rather one per chunk which should keep the graph size smaller and still parallelises very well.
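The array analogue can be sketched without dask: once the shuffle guarantees that no group straddles two blocks, each block can re-derive its groups from its own slice of the labels and call the UDF once per group. Here a list of NumPy slices stands in for the chunks, and `apply_per_group` is a hypothetical stand-in for the function one would pass to `map_blocks`; it assumes a transform (output has the same shape as input).

```python
import numpy as np

def apply_per_group(block, block_codes, func):
    """Runs inside one block: loop only over the groups present in this block."""
    out = np.empty_like(block)
    for g in np.unique(block_codes):
        mask = block_codes == g
        out[..., mask] = func(block[..., mask])
    return out

def center(x):
    return x - x.mean(axis=-1, keepdims=True)

# Shuffled data: groups are contiguous and never cross a block boundary.
codes = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
data = np.arange(18.0).reshape(2, 9)
block_bounds = [(0, 5), (5, 9)]  # block 0 holds groups 0 and 1, block 1 holds group 2

# One "task" per block (2 here), not one per group (3 here).
blocks = [apply_per_group(data[..., a:b], codes[a:b], center) for a, b in block_bounds]
result = np.concatenate(blocks, axis=-1)
```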

dcherian commented 3 months ago

After thinking it over during the weekend, I think it'd be nice to have at least two APIs.

  1. Array.shuffle(indices: list[list[int]], axis: int) that does guarantee an integer number of groups N >= 1 in the output chunks (analogous to dask.dataframe). shuffle should rechunk the other axes automatically to make this feasible. This would be immediately useful in GroupBy.quantile and GroupBy.median. I believe this is implemented in https://github.com/dask/dask/pull/11262
  2. The more general np.take that does not bother about groups, and simply extracts/reorders the elements in the most efficient manner possible. Here I think it is preferable to preserve the chunking along the other axes as much as possible. We might consider a chunks_hint argument, that suggests optional chunksizes (or chunk boundaries) as proposed in https://github.com/pydata/xarray/issues/9220#issuecomment-2256689592. This is more general, and I think Xarray would call take followed by slicing to index out each group for the general GroupBy.reduce.
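A minimal sketch of the chunking rule in item 1, assuming a target chunk size and a greedy packing strategy (both assumptions, not necessarily what dask/dask#11262 does): consecutive groups are packed into one chunk until adding the next would exceed the target, so every chunk holds a whole number of groups. A group larger than the target gets a chunk of its own, which is the case where the other axes would need rechunking.

```python
def pack_groups(group_sizes, target):
    """Hypothetical helper: chunk sizes along the shuffled axis that never split a group."""
    chunks, current = [], 0
    for size in group_sizes:
        # Close the current chunk if this group would push it past the target.
        if current and current + size > target:
            chunks.append(current)
            current = 0
        current += size
    if current:
        chunks.append(current)
    return tuple(chunks)

chunks = pack_groups([11] * 5 + [3], target=30)
```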

cc @phofl

phofl commented 3 months ago

That sounds good to me.

So the current shuffle implementation doesn't do any rechunking. Assuming you have a group that is very large, you would expect us to change the length of the chunks across other dimensions, correct?

Re take: I have a PR https://github.com/dask/dask/pull/11267 that hooks the shuffle approach into take. This means that we preserve the chunk size along the axis to which we apply the take indexer, but it's a lot more efficient than an out-of-order take indexer (this is what I benchmarked here). This is something that would be helpful if we expose this toplevel, I assume?

dcherian commented 3 months ago

assuming you have a group that is very large you would expect us to change the length of the chunks across other dimensions correct?

Yes you'd absolutely have to.

This is something that would be helpful if we expose this toplevel I assume?

Isn't it already exposed as take?

I have a pr https://github.com/dask/dask/pull/11267

How does this do with repeated indices? https://github.com/pydata/xarray/issues/2745

phofl commented 3 months ago

Yes you'd absolutely have to.

This isn't yet implemented, I'll add this as a follow up.

Isn't it already exposed as take?

Good point, this should end up in the same place (I'll double-check that to be sure though).

How does this do with repeated indices? https://github.com/pydata/xarray/issues/2745

It places them in different chunks if you have too many of them. It looks at the average input chunk size along the dimension you are indexing and then uses this as the target chunk size for the output. It would preserve the (100, 100) chunks in the example in the other issue.
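A sketch of the output-chunking rule described here, under the stated assumption that the target output chunk size is the average input chunk size along the indexed axis. The helper is hypothetical and only computes output chunk sizes, not the data movement.

```python
import math

def take_output_chunks(input_chunks, n_indices):
    """Hypothetical helper: split ``n_indices`` output elements into chunks of
    about the average input chunk size along the indexed axis."""
    target = max(1, math.ceil(sum(input_chunks) / len(input_chunks)))
    full, rest = divmod(n_indices, target)
    return (target,) * full + ((rest,) if rest else ())

# 1000 (possibly repeated) indices into an axis chunked as (100, 100, 100)
chunks = take_output_chunks((100, 100, 100), 1000)
```

With many repeated indices the output can be much longer than the input, but the repeats spread over several chunks of roughly the input size instead of inflating a single one.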

[Screenshot, 2024-08-05]
dcherian commented 3 months ago

It places them in different chunks if you have too many of them.

Perfect! We'll need this for groupby-binary-ops (ds.groupby(...) - ds.groupby(...).mean()), which has this total hack at the moment: https://github.com/pydata/xarray/blob/1ac19c4231e3e649392add503ae5a13d8e968aef/xarray/core/groupby.py#L655-L668

dcherian commented 3 months ago

Some updates after a chat with @phofl

First off, amazing progress on reducing memory usage in dask 2024.08.0!

TODO:

  1. We'll need a new GroupBy.shuffle() to allow users to manually shuffle the data so that all members of a group are guaranteed to be in one chunk. We may have more than one group in a single chunk.
    • xref #9320.
    • Some followup work is needed in dask to rechunk the other axes if the chunks end up being too big.
    • ^ Once that's done, we can immediately upgrade GroupBy.median and GroupBy.quantile to make use of it.
  2. We should probably add a new GroupBy.transform that calls shuffle and then calls xarray.map_blocks with a wrapper function that applies another GroupBy with the UDF. This is similar to what dask.dataframe.GroupBy.map does.
    • 9706

  3. We need to make sure GroupBy.map and GroupBy.reduce use the new & improved dask.array.take algorithm.
    • Currently we are using Array.vindex which does its own thing.
    • @phofl will check if we can re-route Array.vindex through take.
  4. The next GroupBy challenge is that Xarray reshapes to 1D when grouping by a nD variable.
  5. We'll need benchmark problems.
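The reshape mentioned in item 4 can be illustrated in NumPy: when the group labels vary over two dimensions, the grouped dimensions are flattened into one so the 1-D machinery applies, and the result is reshaped back afterwards. A sketch assuming the grouped dimensions are the trailing ones (the label array is made up):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.random((5, 4, 6))              # (time, y, x); group over (y, x)
labels = np.arange(24).reshape(4, 6) % 3  # 2-D labels, e.g. a region mask

# Flatten the grouped dims into one trailing dim, as for a 1-D grouper.
flat = data.reshape(data.shape[0], -1)    # (time, y*x), C order
flat_labels = labels.ravel()              # same C order as ``flat``

def center(x):
    return x - x.mean(axis=-1, keepdims=True)

out = np.empty_like(flat)
for g in np.unique(flat_labels):
    mask = flat_labels == g
    out[:, mask] = center(flat[:, mask])

result = out.reshape(data.shape)          # back to (time, y, x)
```

For a chunked array, merging dimensions like this can itself require rechunking, which is one reason the nD case is harder than the 1-D one.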