xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

improving the API for binned groupby #191

Open keewis opened 2 years ago

keewis commented 2 years ago

I've been trying to use flox for multi-dimensional binning and found the API a bit tricky to understand.

For some context, I have two variables (depth(time) and temperature(time)), which I'd like to bin into time_bounds(time, bounds) and depth_bounds(time, bounds).

I can get this to work using

arr = ds.set_coords("depth")["temperature"]
coords = [reference[name] for name in ["depth", "time"]]
vertices = [
    cf_xarray.bounds_to_vertices(reference[name], bounds_dim="bounds")
    for name in ["depth_bounds", "time_bounds"]
]
flox.xarray.xarray_reduce(
    arr,
    *coords,
    expected_groups=vertices,
    isbin=[True] * len(coords),
    func="mean",
)

but in the process of getting this right I frequently hit the "Needs better message" error from https://github.com/xarray-contrib/flox/blob/51fb6e9382a68854d2fec55da2ec4f67b0ed095b/flox/xarray.py#L219, which certainly did not help. Even ignoring that, it was pretty difficult to make sense of the combination of *by, expected_groups, and isbin, and I'm not confident I won't go through the same cycle of trial and error if I retry in a few months.

Instead, I wonder if we could change the call to something like:

bins = [
    flox.Bin(along=name, labels=reference[name], bounds=reference[f"{name}_bounds"])
    for name in ["depth", "time"]
]
flox.xarray.xarray_reduce(arr, *bins, func="mean")

(leaving aside the question of which bounds convention(s) this Bin object should support)

Another option might be to just use an interval index. Something like:

flox.xarray.xarray_reduce(arr, time=pd.IntervalIndex(...), depth=pd.IntervalIndex(...), func="mean")

That would be pretty close to the existing groupby interface. And we could even combine both:

flox.xarray.xarray_reduce(
    arr,
    time=flox.Bin(labels=reference["time"], bounds=reference["time_bounds"]),
    depth=flox.Bin(labels=reference["depth"], bounds=reference["depth_bounds"]),
    func="mean",
)

xref pydata/xarray#6610, where we probably want to adopt whatever signature we figure out here. Also, do tell me if you'd prefer to have this discussion in that issue instead (but figuring this out here might allow for quicker iteration). And maybe I'm trying to get xarray_reduce to do something too similar to groupby?

keewis commented 2 years ago

okay, turns out I seriously misunderstood what *by does... I was assuming I would need to pass a DataArray, but it seems that's only true for isbin=False. For isbin=True it instead takes a variable name. This is really confusing, though.

I guess something like this might be more intuitive?

# by as a mapping, to avoid shadowing variable names
flox.xarray.xarray_reduce(arr, by={"time": flox.Bins(...), "depth": flox.Bins(...)}, func="mean", ...)
# by as a *args
flox.xarray.xarray_reduce(
    arr,
    flox.Bins(variable="time", ...),
    flox.Bins(variable="depth", ...),
    func="mean",
    ...
)

dcherian commented 2 years ago

Thanks for trying it out Justus!

One reason for all this confusion is that I always expected xarray_reduce to be temporary while I ported it to Xarray but that is taking a lot of effort :( So now flox is stuck with some bad API choices.

I frequently hit the Needs better message error from

oops yeah. The issue is that I wanted to enable the really simple line:

xarray_reduce(ds, "x", expected_groups=["a", "b", "c"])

So for multiple groupers expected_groups must be a tuple. The error message should be improved ;). This would be a nice PR.
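
As a rough illustration of what a clearer check could look like: the helper below, `normalize_expected_groups`, is hypothetical and is not flox's actual code; its name and error messages are assumptions. It accepts the bare-list form for a single grouper and requires a tuple with one entry per grouper otherwise.

```python
def normalize_expected_groups(by, expected_groups):
    """Return one expected_groups entry per grouper in `by` (hypothetical helper)."""
    nby = len(by)
    if expected_groups is None:
        # nothing specified: every grouper gets auto-detected groups
        return (None,) * nby
    if isinstance(expected_groups, tuple):
        # a tuple is always read as "one entry per grouper"
        if len(expected_groups) != nby:
            raise ValueError(
                f"expected_groups has {len(expected_groups)} entries "
                f"but there are {nby} variables to group by"
            )
        return expected_groups
    if nby == 1:
        # single grouper: allow the simple form, e.g. ["a", "b", "c"]
        return (expected_groups,)
    raise TypeError(
        f"when grouping by {nby} variables, expected_groups must be a tuple "
        f"with one entry per variable, got {type(expected_groups).__name__}"
    )


# the simple single-grouper call keeps working
single = normalize_expected_groups(("x",), ["a", "b", "c"])
```

With two groupers, passing a list of lists instead of a tuple would raise the `TypeError` above and name the offending type.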

Another option might be to just use an interval index.

So flox.core.groupby_reduce does actually handle IntervalIndex (or any pandas Index). And looking at the code, xarray_reduce should also handle it. Here we convert everything to an Index: https://github.com/xarray-contrib/flox/blob/51fb6e9382a68854d2fec55da2ec4f67b0ed095b/flox/xarray.py#L318

And that function does handle pd.Index objects. I think we should update the typing and docstring. This would be a helpful PR!

Edit: ah, now I see: we'll need to remove the isbin argument for this to work properly.

turns out I seriously misunderstood what *by does... I was assuming I would need to pass a DataArray, but it seems that's only true for isbin=False.

This sounds like a bug but I'm surprised. by can be Hashable or DataArray.

https://github.com/xarray-contrib/flox/blob/51fb6e9382a68854d2fec55da2ec4f67b0ed095b/flox/xarray.py#L232

and the tests: https://github.com/xarray-contrib/flox/blob/51fb6e9382a68854d2fec55da2ec4f67b0ed095b/tests/test_xarray.py#L111-L127

A reproducible example would help.

by={"time": flox.Bins(...), "depth": flox.Bins(...)}

Yeah, I've actually been considering the dict approach. It seems a lot nicer since there's potentially a lot of information (for example cut_kwargs in xarray, which gets passed to pd.cut and can specify whether the closed edge is left or right), and it's hard to keep track of order in 3-4 different kwargs. One point of ugliness would be grouping by a categorical variable where you want flox to autodetect group members. That would then look like

by = {"labels": None}

It also means you can't do

xarray_reduce(ds, "x", func="mean")

An alternative is to use pd.IntervalIndex instead of a new flox.Bins object. It can wrap a lot of info and will allow us to remove isbin (see flox.core.groupby_reduce).
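
To illustrate the point that an IntervalIndex carries the closed side along with the edges, here is a small pandas-only sketch (it does not call flox; the sample values are made up). The same breaks are binned differently depending on `closed`, with no separate keyword needed at the `pd.cut` call:

```python
import numpy as np
import pandas as pd

values = np.array([0.0, 1.0, 2.0])
breaks = [0.0, 1.0, 2.0]

# the closed side is part of the index itself
right_closed = pd.IntervalIndex.from_breaks(breaks, closed="right")  # (0, 1], (1, 2]
left_closed = pd.IntervalIndex.from_breaks(breaks, closed="left")  # [0, 1), [1, 2)

# pd.cut accepts an IntervalIndex as bins; a code of -1 marks a value
# that falls outside every interval
codes_right = pd.cut(values, bins=right_closed).codes
codes_left = pd.cut(values, bins=left_closed).codes
```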

keewis commented 2 years ago

here's the example I tested this with (although I just realized I could have used .cf.bounds_to_vertices instead of writing a new function):

In [1]: import xarray as xr
   ...: import cf_xarray
   ...: import numpy as np
   ...: import flox.xarray
   ...: 
   ...: 
   ...: def add_vertices(ds, bounds_dim="bounds"):
   ...:     new_names = {
   ...:         name: f"{name.removesuffix(bounds_dim).rstrip('_')}_vertices"
   ...:         for name, coord in ds.variables.items()
   ...:         if bounds_dim in coord.dims
   ...:     }
   ...:     new_coords = {
   ...:         new_name: cf_xarray.bounds_to_vertices(ds[name], bounds_dim=bounds_dim)
   ...:         for name, new_name in new_names.items()
   ...:     }
   ...:     return ds.assign_coords(new_coords)
   ...: 
   ...: 
   ...: categories = list("abcefghi")
   ...: 
   ...: coords = (
   ...:     xr.Dataset(coords={"x": np.arange(10), "y": ("y", categories)})
   ...:     .cf.add_bounds(["x"])
   ...:     .pipe(add_vertices)
   ...: )
   ...: coords
Out[1]: 
<xarray.Dataset>
Dimensions:     (x: 10, y: 8, bounds: 2, x_vertices: 11)
Coordinates:
  * x           (x) int64 0 1 2 3 4 5 6 7 8 9
  * y           (y) <U1 'a' 'b' 'c' 'e' 'f' 'g' 'h' 'i'
    x_bounds    (x, bounds) float64 -0.5 0.5 0.5 1.5 1.5 ... 7.5 7.5 8.5 8.5 9.5
  * x_vertices  (x_vertices) float64 -0.5 0.5 1.5 2.5 3.5 ... 6.5 7.5 8.5 9.5
Dimensions without coordinates: bounds
Data variables:
    *empty*

In [2]: data = xr.Dataset(
   ...:     {"a": ("x", np.arange(200))},
   ...:     coords={
   ...:         "x": np.linspace(-0.5, 9.5, 200),
   ...:         "y": ("x", np.random.choice(categories, size=200)),
   ...:     },
   ...: )
   ...: data
Out[2]: 
<xarray.Dataset>
Dimensions:  (x: 200)
Coordinates:
  * x        (x) float64 -0.5 -0.4497 -0.3995 -0.3492 ... 9.349 9.399 9.45 9.5
    y        (x) <U1 'e' 'g' 'b' 'b' 'h' 'a' 'g' ... 'a' 'c' 'c' 'h' 'c' 'b' 'h'
Data variables:
    a        (x) int64 0 1 2 3 4 5 6 7 8 ... 191 192 193 194 195 196 197 198 199

In [3]: flox.xarray.xarray_reduce(
   ...:     data["a"],
   ...:     coords["x"],
   ...:     expected_groups=(coords["x_vertices"],),
   ...:     isbin=[True] * 1,
   ...:     func="mean",
   ...: )
Out[3]: 
<xarray.DataArray 'a' (x_bins: 10)>
array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
Coordinates:
  * x_bins   (x_bins) object (-0.5, 0.5] (0.5, 1.5] ... (7.5, 8.5] (8.5, 9.5]

In [4]: flox.xarray.xarray_reduce(
   ...:     data["a"],
   ...:     "x",
   ...:     expected_groups=(coords["x_vertices"],),
   ...:     isbin=[True] * 1,
   ...:     func="mean",
   ...: )
Out[4]: 
<xarray.DataArray 'a' (x_bins: 10)>
array([ 10. ,  29.5,  49.5,  69.5,  89.5, 109.5, 129.5, 149.5, 169.5,
       189.5])
Coordinates:
  * x_bins   (x_bins) object (-0.5, 0.5] (0.5, 1.5] ... (7.5, 8.5] (8.5, 9.5]

the first call passes the dimension coordinate and does something weird, while the second call succeeds.

keewis commented 2 years ago

So for multiple groupers expected_groups must be a tuple. The error message should be improved ;).

Something else I noticed is that when passing a single-element list containing a DataArray ([x]) to expected_groups, it would somehow convert that to [[x]] and fail later.

It also means you can't do

We could get around that restriction by renaming *by to *variable_names (or something) and having a separate keyword-only argument named by. The only issue with that would be to decide what to do when it receives both the positional args and the dict (raise an error, probably). I'd actually also restrict the positional arguments to just the variable names.

pd.IntervalIndex works, although we need a good way to convert between the different bounds conventions: cf-xarray does provide conversion functions for bounds and vertices, but converting to / from intervals is not as easy (pd.IntervalIndex.from_breaks only helps in one direction). And for IntervalIndex we definitely need the dict interface.

Edit: at the moment, the IntervalIndex coordinates flox adds to the output are not really useful: they can't really be used for selecting / aligning (though that might change with a custom index), and they also can't be written to disk, so I usually replace them with the interval centers and bounds

dcherian commented 2 years ago

We could get around that restriction by renaming by to variable_names

This is interesting but also messy. Presumably with the dict, everything else (expected_groups, isbin) should be None?

pd.IntervalIndex.from_breaks only helps in one direction

the IntervalIndex coordinates flox adds to the output are not really useful:

It does what Xarray does at the moment (see output of groupby_bins). But yes, sticking IntervalIndex in Xarray is something we should fix upstream :). Then we could add an accessor that provides .left, .right, and .mid (these already exist on the Index itself).
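
For reference, a pandas-only sketch of those attributes, showing how interval labels can be turned into plain centers and an (n, 2) bounds array that can be written to disk (the sample breaks are made up):

```python
import numpy as np
import pandas as pd

intervals = pd.IntervalIndex.from_breaks([-0.5, 0.5, 1.5, 2.5])

# interval midpoints, usable as an ordinary numeric coordinate
centers = intervals.mid.to_numpy()

# left/right edges stacked into an (n, 2) bounds array
bounds = np.stack(
    [intervals.left.to_numpy(), intervals.right.to_numpy()], axis=1
)
```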

And for IntervalIndex we definitely need the dict interface.

Can you clarify? In core.groupby_reduce you provide IntervalIndex in expected_groups.

dcherian commented 2 years ago

the first call passes the dimension coordinate and does something weird

Ah I don't think this is a bug. data["x"] and coords["x"] are not the same so there's a reindexing step that happens as part of alignment.
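
A pandas analogue of why that produces all-NaN output, assuming the alignment behaves like label-based reindexing (this sketch does not call xarray or flox): the grouping variable is labelled 0..9, the data lives on 200 points between -0.5 and 9.5, and none of those labels match.

```python
import numpy as np
import pandas as pd

# grouping variable, labelled by the integer coordinate 0..9 (like coords["x"])
by = pd.Series(np.arange(10), index=np.arange(10))

# coordinate of the data (like data["x"])
data_x = np.linspace(-0.5, 9.5, 200)

# reindexing onto the data's labels finds no matching label,
# so every entry becomes NaN
aligned = by.reindex(data_x)
```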

I think I'll check for exact alignment.

keewis commented 2 years ago

Ah I don't think this is a bug.

Yeah, that was a user error. After reading the example from #189, I finally understand that *by should contain the variables / variable names to group over (e.g. coordinates of the data) and not the group labels. Not sure how I got that idea.

So with that, I now think the dict suggestion does not make too much sense, because you can pass non-hashable objects (DataArray).

Which means we're left with the custom object suggestion:

flox.xarray.xarray_reduce(
    data,
    flox.Grouper("x", bins=coords.x_vertices),
    flox.Grouper(data.y, values=["a", "b", "c"]),
)

which basically does the same thing as the combination of *by, expected_groups, and isbin, but the information stays in a single place. I guess with that, expected_groups and isbin could stay as they are, and if we mix both, the expected_groups and isbin entries should be None for any position that uses a Grouper.

But yes, the example helps quite a bit, so the custom object would not improve the API as much as I had thought.

Edit: but I still would like to have a convenient way to convert between different bounds conventions (bounds, vertices, interval index)
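
A minimal sketch of such conversions for plain NumPy arrays, to make the three conventions concrete. The helper names `to_intervals` and `to_vertices` are assumptions for illustration, not existing flox or cf-xarray functions, and the bounds dimension is assumed to be the last axis:

```python
import numpy as np
import pandas as pd


def to_intervals(bins, closed="right"):
    """Convert vertices (n + 1,) or bounds (n, 2) to an IntervalIndex (hypothetical helper)."""
    if isinstance(bins, pd.IntervalIndex):
        return bins
    bins = np.asarray(bins)
    if bins.ndim == 1:
        # vertices: consecutive breaks
        return pd.IntervalIndex.from_breaks(bins, closed=closed)
    if bins.ndim == 2 and bins.shape[1] == 2:
        # bounds: one (lower, upper) pair per bin
        return pd.IntervalIndex.from_arrays(bins[:, 0], bins[:, 1], closed=closed)
    raise ValueError(f"cannot interpret bins of shape {bins.shape}")


def to_vertices(intervals):
    """Convert a contiguous IntervalIndex back to n + 1 vertices (hypothetical helper)."""
    left = intervals.left.to_numpy()
    right = intervals.right.to_numpy()
    if not np.array_equal(left[1:], right[:-1]):
        raise ValueError("intervals are not contiguous, vertices are undefined")
    return np.append(left, right[-1])


vertices = np.array([-0.5, 0.5, 1.5])
bounds = np.array([[-0.5, 0.5], [0.5, 1.5]])
from_vertices = to_intervals(vertices)
from_bounds = to_intervals(bounds)
```

Going from intervals back to vertices only makes sense for contiguous bins, which is why `to_vertices` checks for that first.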

dcherian commented 2 years ago

but I still would like to have a convenient way to convert between different bounds conventions (bounds, vertices, interval index)

I think this is up to xarray/cf-xarray since it is a generally useful thing. Once Xarray can store an IntervalIndex in Xarray objects, I think flox should just do that.

pd.cut does support a labels argument though, that might partially handle what you want.
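
A small pandas-only sketch of that `labels` argument, using bin centers as output labels instead of intervals (the sample values are made up):

```python
import numpy as np
import pandas as pd

vertices = np.array([-0.5, 0.5, 1.5, 2.5])
centers = (vertices[:-1] + vertices[1:]) / 2  # one label per bin

x = np.array([0.2, 1.0, 2.4])

# each value is labelled with the center of the bin it falls into
binned = pd.cut(x, bins=vertices, labels=centers.tolist())
```

The result is a Categorical whose categories are the centers, so the labels are plain numbers rather than Interval objects.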

dcherian commented 2 years ago

So with that, I now think the dict suggestion does not make too much sense, because you can pass non-hashable objects (DataArray).

Oh great point!

Which means we're left with the custom object suggestion:

flox.Grouper sounds good to me actually.

keewis commented 2 years ago

I did some experiments with the grouper object this afternoon. I'd imagine something like this (I'm not particularly fond of the bounds-type detection, but of course we can just require pd.IntervalIndex objects):

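import attrs
import more_itertools
import pandas as pd

import flox.xarray

# note: `to_intervals` (converting bins to a pd.IntervalIndex) is a helper not shown here


@attrs.define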
class Grouper:
    """grouper for use with `flox.xarray.xarray_reduce`

    Parameters
    ----------
    over : hashable or DataArray
        The coordinate to group over. If a hashable, has to be a variable on the
        data. If a `DataArray`, its dimensions have to be a subset of the data's.
    values : list-like, array-like, or DataArray, optional
        The expected group labels. Mutually exclusive with `bins`.
    bins : DataArray or IntervalIndex, optional
        The bins used to group the data. Can either be an `IntervalIndex`, a `DataArray`
        of `n + 1` vertices, or a `DataArray` of `(n, 2)` bounds.

        Mutually exclusive with `values`.
    bounds_dim : hashable, optional
        The bounds dimension if the bins were passed as bounds.
    """

    over = attrs.field()

    bins = attrs.field(kw_only=True, default=None)
    values = attrs.field(kw_only=True, default=None)
    labels = attrs.field(init=False, default=None)

    bounds_dim = attrs.field(kw_only=True, default=None, repr=False)

    def __attrs_post_init__(self):
        if self.bins is not None and self.values is not None:
            raise TypeError("cannot specify both bins and group labels")

        if self.bins is not None:
            self.labels = to_intervals(self.bins, self.bounds_dim)
        elif self.values is not None:
            self.labels = self.values

    @property
    def is_bin(self):
        if self.labels is None:
            return None

        return self.bins is not None

def merge_lists(*lists):
    def merge_elements(elements):
        filtered = [element for element in elements if element is not None]

        return more_itertools.first(filtered, default=None)

    return [merge_elements(elements) for elements in zip(*lists)]

def groupby(obj, *by, **flox_kwargs):
    orig_expected_groups = flox_kwargs.get("expected_groups", [None] * len(by))
    orig_isbin = flox_kwargs.get("isbin", [None] * len(by))

    extracted = ((grouper.over, grouper.labels, grouper.is_bin) for grouper in by)
    by_, expected_groups, isbin = (list(_) for _ in more_itertools.unzip(extracted))
    flox_kwargs["expected_groups"] = tuple(
        merge_lists(orig_expected_groups, expected_groups)
    )
    flox_kwargs["isbin"] = merge_lists(orig_isbin, isbin)

    return flox.xarray.xarray_reduce(obj, *by_, **flox_kwargs)

flox.xarray.xarray_reduce(
    data,
    Grouper(over="x", bins=pd.IntervalIndex.from_breaks(coords["x_vertices"])),
    Grouper(over=data.y, values=["a", "b", "c"]),
    func="mean",
)

dcherian commented 2 years ago

pd.cut uses labels to specify output labels for binning. So we could have values and bins at the same time (though I prefer labels over values and expected_groups).

The nice thing is that we could also support pd.Grouper to do resampling, and other grouping at the same time.

I really like this Grouper idea. On a minor note, I think I prefer by (or key from pd.Grouper) instead of over (also see https://github.com/pydata/xarray/issues/324 )

keewis commented 2 years ago

well, I'm not really attached to the names, so that sounds good to me?