pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Different behavior of groupby with / without flox #9279

Closed: max-sixty closed this issue 1 month ago

max-sixty commented 1 month ago

What happened?

(Following up from https://github.com/pydata/xarray/issues/8263)

Without flox:


[ins] In [1]: da =  xr.tutorial.open_dataset("air_temperature")['air']
         ...:
         ...: da.drop_vars('lat').groupby('lat').sum()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 3
      1 da =  xr.tutorial.open_dataset("air_temperature")['air']
----> 3 da.drop_vars('lat').groupby('lat').sum()

File ~/workspace/xarray/xarray/core/_aggregations.py:6176, in DataArrayGroupByAggregations.sum(self, dim, skipna, min_count, keep_attrs, **kwargs)
   6166     return self._flox_reduce(
   6167         func="sum",
   6168         dim=dim,
   (...)
   6173         **kwargs,
   6174     )
   6175 else:
-> 6176     return self._reduce_without_squeeze_warn(
   6177         duck_array_ops.sum,
   6178         dim=dim,
   6179         skipna=skipna,
   6180         min_count=min_count,
   6181         keep_attrs=keep_attrs,
   6182         **kwargs,
   6183     )

File ~/workspace/xarray/xarray/core/groupby.py:1417, in DataArrayGroupByBase._reduce_without_squeeze_warn(self, func, dim, axis, keep_attrs, keepdims, shortcut, **kwargs)
   1415 with warnings.catch_warnings():
   1416     warnings.filterwarnings("ignore", message="The `squeeze` kwarg")
-> 1417     check_reduce_dims(dim, self.dims)
   1419 return self._map_maybe_warn(reduce_array, shortcut=shortcut, warn_squeeze=False)

File ~/workspace/xarray/xarray/core/groupby.py:66, in check_reduce_dims(reduce_dims, dimensions)
     64     reduce_dims = [reduce_dims]
     65 if any(dim not in dimensions for dim in reduce_dims):
---> 66     raise ValueError(
     67         f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' "
     68         f"to reduce over all dimensions or one or more of {dimensions!r}."
     69         f" Try passing .groupby(..., squeeze=False)"
     70     )

ValueError: cannot reduce over dimensions ['lat']. expected either '...' to reduce over all dimensions or one or more of ('time', 'lon'). Try passing .groupby(..., squeeze=False)

But with flox:

[ins] In [1]: da =  xr.tutorial.open_dataset("air_temperature")['air']
         ...:
         ...: da.drop_vars('lat').groupby('lat').sum()
Out[1]:
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB
array([[[241.2 , 242.5 , 243.5 , ..., 232.8 , 235.5 , 238.6 ],
        [243.8 , 244.5 , 244.7 , ..., 232.8 , 235.3 , 239.3 ],
        [250.  , 249.8 , 248.89, ..., 233.2 , 236.39, 241.7 ],
        ...,
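The ValueError in the first traceback is raised by check_reduce_dims in xarray/core/groupby.py. The standalone sketch below is reconstructed from the lines visible in the traceback. The Ellipsis early return and the plain `str` test that wraps a single name in a list are assumptions standing in for parts the traceback does not show. It reproduces the message above: with the default squeeze behavior each group has presumably lost its 'lat' dimension, so the group dims are ('time', 'lon') and a reduction over 'lat' is rejected.

```python
from collections.abc import Hashable, Iterable
from typing import Any


def check_reduce_dims(reduce_dims: Any, dimensions: Iterable[Hashable]) -> None:
    """Standalone sketch of the check that raises in the traceback above.

    Assumptions: `...` (Ellipsis) means "reduce over everything" and skips
    the check, and a single name is wrapped in a list (a plain str test
    stands in for whatever scalar test xarray really uses).
    """
    dimensions = tuple(dimensions)
    if reduce_dims is ...:
        return
    if isinstance(reduce_dims, str):
        reduce_dims = [reduce_dims]
    if any(dim not in dimensions for dim in reduce_dims):
        raise ValueError(
            f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' "
            f"to reduce over all dimensions or one or more of {dimensions!r}."
            f" Try passing .groupby(..., squeeze=False)"
        )


# With squeezing, each 'lat' group no longer has a 'lat' dimension, so the
# per-group dims are ('time', 'lon') and reducing over 'lat' is rejected.
try:
    check_reduce_dims("lat", ("time", "lon"))
except ValueError as err:
    print(err)

# Reducing over a dimension the group still has, or over `...`, passes.
check_reduce_dims("time", ("time", "lon"))
check_reduce_dims(..., ("time", "lon"))
```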

What did you expect to happen?

Identical behavior with and without flox.

Minimal Complete Verifiable Example

As above

MVCE confirmation

Relevant log output

No response

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS
------------------
commit: d0048ef8e5c07b5290e9d2c4c144d23a45ab91f7
python: 3.11.9 (main, Apr 2 2024, 08:25:04) [Clang 15.0.0 (clang-1500.3.9.4)]
python-bits: 64
OS: Darwin
OS-release: 23.5.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: en_US.UTF-8
LANG: None
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.2
libnetcdf: 4.9.3-development

xarray: 2024.6.1.dev50+g7b08a948
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.13.1
netCDF4: 1.6.5
pydap: None
h5netcdf: 1.3.0
h5py: 3.11.0
zarr: 2.18.2
cftime: 1.6.3
nc_time_axis: 1.4.1
iris: None
bottleneck: 1.3.8
dask: 2024.7.0
distributed: 2024.7.0
matplotlib: 3.9.0
cartopy: None
seaborn: 0.13.2
numbagg: 0.8.1
fsspec: 2024.5.0
cupy: None
pint: None
sparse: None
flox: 0.9.8
numpy_groupies: 0.11.1
setuptools: 69.2.0
pip: 24.0
conda: None
pytest: 8.2.1
mypy: 1.8.0
IPython: 8.25.0
sphinx: None

dcherian commented 1 month ago

The error message suggests "Try passing .groupby(..., squeeze=False)", which does work. This will be fixed once we finally make that the default; it's been warning about the change for 7 months now.

dcherian commented 1 month ago

Or we can fix the bug here: https://github.com/pydata/xarray/blob/d0048ef8e5c07b5290e9d2c4c144d23a45ab91f7/xarray/core/groupby.py#L812-L818

We need to raise the error for virtual variables too.

max-sixty commented 1 month ago

OK great!

(Just so I understand: when we change the default, will da.groupby(dim).mean() work, i.e. with no dim passed to mean? Or will we still require a dim to be passed to mean? Currently da.groupby('time').mean() passes but da.groupby('lat').mean() fails, seemingly because lat is a float index...)

dcherian commented 1 month ago

Turns out it has to do with the sortedness of lat. We only check .is_monotonic_increasing, but yes, if I pass squeeze=False it all works.

https://github.com/pydata/xarray/blob/d0048ef8e5c07b5290e9d2c4c144d23a45ab91f7/xarray/core/groupers.py#L122-L126

We could extend it to also check is_monotonic_decreasing.

https://github.com/pydata/xarray/issues/6220 https://github.com/pydata/xarray/pull/7427
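The sortedness difference can be illustrated with pandas alone. The sketch below builds coordinates shaped like the tutorial dataset's (assumed here: lat runs from 75.0 down to 15.0 in 2.5-degree steps, and time is increasing) and compares the current condition, assumed to be roughly `index.is_unique and index.is_monotonic_increasing`, with the suggested variant that also accepts decreasing order. `fast_path_increasing_only` and `fast_path_either_direction` are hypothetical helpers, not xarray API, and the sketch does not reproduce xarray's groupby code paths.

```python
import numpy as np
import pandas as pd

# Coordinates shaped like the tutorial dataset (an assumption for
# illustration): 'lat' is unique but decreasing, 'time' is increasing.
lat = pd.Index(np.linspace(75.0, 15.0, 25), name="lat")
time = pd.date_range("2013-01-01", periods=4, freq="D", name="time")


def fast_path_increasing_only(index: pd.Index) -> bool:
    """Hypothetical helper: the current condition as described in the thread."""
    return index.is_unique and index.is_monotonic_increasing


def fast_path_either_direction(index: pd.Index) -> bool:
    """Hypothetical helper: the suggested variant that also accepts decreasing order."""
    return index.is_unique and (
        index.is_monotonic_increasing or index.is_monotonic_decreasing
    )


for idx in (time, lat):
    print(idx.name, fast_path_increasing_only(idx), fast_path_either_direction(idx))
# time True True
# lat False True
```

Even with the either-direction check, an unsorted coordinate (or one with repeated values) still takes the other path, which is why making squeeze=False the default is the more complete fix.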

max-sixty commented 1 month ago

OK great!

So it sounds like, after we make that change, it won't be necessary to pass any dimensions to the aggregation function: da.groupby('lat').mean() (with nothing passed to mean()) will aggregate over groups along lat without reducing any other dims (lmk if I'm mistaken).

That'll be much better than the current state, where whether calling mean() with no arguments works depends on the dimension's sortedness and on whether flox is installed. Notably, if we only add is_monotonic_decreasing without making squeeze=False the default, it will still depend on whether the coordinate is sorted.

TY!
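The behavior described above, where da.groupby('lat').mean() with no arguments reduces only within each group along lat and leaves time and lon alone, can be sketched with plain NumPy. This is a generic factorize-then-reduce illustration under that assumption, not flox's or xarray's implementation; `group_reduce` is a hypothetical name, and the toy labels mimic `lat % 2` on a descending lat coordinate.

```python
import numpy as np


def group_reduce(data: np.ndarray, labels: np.ndarray, axis: int, func=np.sum):
    """Hypothetical helper: reduce `data` within groups of `labels` along `axis`.

    Only `axis` changes length (to the number of unique labels); every other
    axis is kept, mirroring what is described for da.groupby('lat').mean().
    """
    uniques, codes = np.unique(labels, return_inverse=True)
    # One reduction per group; keepdims=True keeps the grouped axis (length 1)
    # so the pieces can be concatenated back along it.
    pieces = [
        func(np.compress(codes == g, data, axis=axis), axis=axis, keepdims=True)
        for g in range(len(uniques))
    ]
    return uniques, np.concatenate(pieces, axis=axis)


# Toy (time, lat, lon) array; the labels along 'lat' repeat and are unsorted.
rng = np.random.default_rng(0)
data = rng.random((3, 8, 5))
lat = np.array([75.0, 72.5, 70.0, 67.5, 65.0, 62.5, 60.0, 57.5])
labels = lat % 2

uniques, out = group_reduce(data, labels, axis=1, func=np.mean)
print(uniques.tolist())  # [0.0, 0.5, 1.0, 1.5]
print(out.shape)         # (3, 4, 5)
```

Sorting the labels first is not required here: np.unique factorizes them, so the result does not depend on the order of lat.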

dcherian commented 1 month ago

Setting squeeze=False makes it always work.

#7427 removes the dependence on sortedness; that is, once merged, it will dependably error on main. Perhaps we should just delete the squeeze kwarg; it's been a pain for quite a while.

max-sixty commented 1 month ago

OK, very nice: on #7427 this works dependably both with and without flox, for both ordered and unordered coordinates:

import xarray as xr
da = xr.tutorial.open_dataset("air_temperature")["air"]
with xr.set_options(use_flox=True):
    print(da.assign_coords(lat=lambda x: x.lat % 2).groupby('lat', squeeze=False).sum())

max-sixty commented 1 month ago

Closed by #9280, thanks!