pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

vectorized groupby binary ops #5804

Closed: dcherian closed this 2 years ago

dcherian commented 3 years ago

By switching to numpy_groupies we are vectorizing our groupby reductions. I think we can do the same for groupby's binary ops.
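For readers unfamiliar with the idea, here is a minimal NumPy sketch of a vectorized grouped reduction: a per-group mean computed with `np.bincount` in one pass instead of a Python loop over groups. This is not numpy_groupies' API; `group_mean` is a hypothetical helper, and the integer group codes are assumed to come from `np.unique(..., return_inverse=True)`.

```python
import numpy as np

def group_mean(values, labels):
    """Mean of 1-D `values` for each unique label, without looping over groups.

    Hypothetical helper for illustration only.
    """
    # Map each label to an integer group code 0..ngroups-1.
    uniques, codes = np.unique(labels, return_inverse=True)
    # Per-group sums and counts, each in a single vectorized call.
    sums = np.bincount(codes, weights=values, minlength=len(uniques))
    counts = np.bincount(codes, minlength=len(uniques))
    return uniques, sums / counts

labels = np.array(["a", "b", "a", "c", "b", "a"])
values = np.array([1.0, 2.0, 3.0, 4.0, 6.0, 5.0])
uniques, means = group_mean(values, labels)
print(uniques, means)  # ['a' 'b' 'c'] [3. 4. 4.]
```

numpy_groupies generalizes this pattern to many other reductions.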

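Here's an example array:

```python
import numpy as np
import xarray as xr

# IPython magic; needs the memory_profiler package installed
%load_ext memory_profiler

N = 4 * 2000
# N x N random values; the "labels" coordinate along x splits the rows
# into 8 contiguous groups ("a" through "h") of N // 8 rows each.
da = xr.DataArray(
    np.random.random((N, N)),
    dims=("x", "y"),
    coords={"labels": ("x", np.repeat(["a", "b", "c", "d", "e", "f", "g", "h"], repeats=N // 8))},
)
```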

Consider this "anomaly" calculation, where the anomaly is defined relative to the group mean:

```python
def anom_current(da):
    grouped = da.groupby("labels")
    mean = grouped.mean()  # per-group mean along x; dims ("labels", "y")
    anom = grouped - mean  # binary op applied group by group
    return anom
```

With this approach, we loop over each group and apply the binary operation: https://github.com/pydata/xarray/blob/a1635d324753588e353e4e747f6058936fa8cf1e/xarray/core/computation.py#L502-L525

This saves some memory, but becomes slow for a large number of groups.
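To make the looping concrete, here is a NumPy sketch of what a per-group binary op amounts to conceptually: select each group's rows, subtract that group's mean, and write the result back. It is a simplified illustration under the assumption that the mean is taken along the grouped axis (axis 0), not xarray's actual code path (the link above shows that); `anom_loop` is a hypothetical name.

```python
import numpy as np

def anom_loop(data, labels):
    """Subtract each group's mean (along axis 0) from that group's rows,
    handling one group per Python-level iteration."""
    out = np.empty_like(data)
    for label in np.unique(labels):
        mask = labels == label            # rows belonging to this group
        group = data[mask]                # copy of just this group's rows
        out[mask] = group - group.mean(axis=0)
    return out

rng = np.random.default_rng(0)
data = rng.random((8, 3))
labels = np.repeat(["a", "b"], 4)
anom = anom_loop(data, labels)
```

The cost grows with the number of iterations, which is why many small groups make this approach slow.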

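We could instead do:

```python
def anom_vectorized(da):
    mean = da.groupby("labels").mean()
    # Broadcast each group's mean back onto every row of that group,
    # giving an array with the same shape as da
    mean_expanded = mean.sel(labels=da.labels)
    anom = da - mean_expanded  # one vectorized subtraction
    return anom
```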

This is faster, but it constructs an extra array as large as the original one (I think this is an OK tradeoff).
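To put a number on that extra array for the example above: `da` is an N x N array with N = 8000, assumed float64 (which is what `np.random.random` returns), so `mean_expanded` costs roughly this much additional memory:

```python
import numpy as np

N = 4 * 2000                             # same N as the example array
itemsize = np.dtype("float64").itemsize  # 8 bytes per element
extra_bytes = N * N * itemsize           # one more N x N float64 array
print(extra_bytes)                       # 512000000
print(extra_bytes / 2**20)               # 488.28125 (MiB)
```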

```python
%timeit anom_current(da)
# 1.4 s ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit anom_vectorized(da)
# 937 ms ± 5.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

(I haven't experimented with dask yet, so the following is just a theory.)

I think the real benefit comes with dask. Depending on where the group boundaries fall relative to the chunk boundaries, the per-group approach could end up splitting existing chunks into many tiny chunks. With the vectorized approach we can do better.
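A small pure-Python sketch of the chunk-splitting concern, assuming contiguous groups selected with contiguous slices: slicing a chunked axis keeps whichever existing chunk boundaries fall inside the slice, so each group's selection inherits fragments of the original chunks. `slice_chunks` is a hypothetical helper, not a dask function, and the chunk sizes are made up for illustration.

```python
def slice_chunks(chunks, start, stop):
    """Chunk sizes of the half-open range [start, stop) along an axis
    whose chunking is `chunks` (a tuple of chunk lengths)."""
    out = []
    offset = 0
    for size in chunks:
        lo, hi = max(start, offset), min(stop, offset + size)
        if lo < hi:  # this chunk overlaps the selected range
            out.append(hi - lo)
        offset += size
    return tuple(out)

# Axis of length 8000 in chunks of 1500 (last chunk 500),
# split into 8 contiguous groups of 1000 rows each.
chunks = (1500,) * 5 + (500,)
group_chunks = [slice_chunks(chunks, g * 1000, (g + 1) * 1000) for g in range(8)]
print(group_chunks)
# [(1000,), (500, 500), (1000,), (1000,), (500, 500), (1000,), (1000,), (500, 500)]
```

Concatenating these per-group pieces would leave the axis with 11 chunks instead of the original 6, several only 500 long; the exact layout a real groupby produces depends on how the groups are selected.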

Ideally we would reindex the "mean" dask array with a NumPy array of repeated ints (the group code of each position along the grouped dimension), chosen so that the chunking of mean_expanded exactly matches the chunking of da along the grouped dimension.
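In NumPy terms, the "array of repeated ints" is the integer group code of each row, and indexing the reduced means with it expands them back to the original length in a single take. A minimal sketch, assuming the labels are factorized with `np.unique(..., return_inverse=True)`; it does not address the dask-specific requirement that the output chunks match `da`.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 16
data = rng.random((N, 4))
labels = np.repeat(["a", "b", "c", "d"], N // 4)

# Factorize the labels: codes[i] is the integer group of row i.
# For these contiguous labels it is a run of repeated ints: 0 0 0 0 1 1 1 1 ...
uniques, codes = np.unique(labels, return_inverse=True)
ngroups = len(uniques)

# Grouped mean along axis 0 without a Python loop over groups.
sums = np.zeros((ngroups, data.shape[1]))
np.add.at(sums, codes, data)
means = sums / np.bincount(codes, minlength=ngroups)[:, None]

# "Reindex" the reduced means with the repeated ints: one row per original row.
mean_expanded = np.take(means, codes, axis=0)
anom = data - mean_expanded  # single vectorized binary op
```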

~In practice, dask.array.take doesn't allow specifying "output chunks" so we'd end up chunking "mean_expanded" based on dask's automatic heuristics, and then rechunking again for the binary operation.~

Thoughts?

cc @rabernat

shoyer commented 3 years ago

I agree, I think this would be a much cleaner way to do these sorts of operations!