pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

`groupby().max()` with Flox: `dtype` keyword not working. #9398

Closed bertcoerver closed 4 days ago

bertcoerver commented 1 month ago

What happened?

I'm applying a reduction to some data, but the dtype keyword documented here doesn't seem to work unless I also specify a fill_value.

I'm not sure whether this is a bug; if it isn't, it would be nice to note this behaviour in the documentation.

What did you expect to happen?

I expect the output to have the dtype specified by the dtype keyword.

Minimal Complete Verifiable Example

Given a dataset with int16 data:

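import numpy as np
import xarray as xr

# 3x3 int16 variable; "idx" assigns each row along "x" to a group.
ds = xr.Dataset(
    {"test": (["x", "y"], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"))},
    coords={"idx": ("x", [1, 2, 1])},
)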

The following casts the "test" variable to float64 (as expected):

ds.groupby("idx").max(method = "blockwise", engine = "flox")

However, doing this:

ds.groupby("idx").max(method = "blockwise", engine = "flox", dtype = "int16")

also gives float64 as the output dtype.

Only once I also specify a fill_value does the output follow the specified dtype, e.g.:

x = ds.groupby("idx").max(method = "blockwise", engine = "flox", dtype = "int16", fill_value = 99)

gives x["test"].dtype.name == "int16", or:

x = ds.groupby("idx").max(method = "blockwise", engine = "flox", dtype = "int32", fill_value = 99)

gives x["test"].dtype.name == "int32"

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.5.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: None
LOCALE: (None, 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2
xarray: 2024.2.0
pandas: 2.2.1
numpy: 1.26.4
scipy: 1.12.0
netCDF4: 1.6.5
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.6.3
nc_time_axis: None
iris: None
bottleneck: 1.3.8
dask: 2024.6.2
distributed: 2024.6.2
matplotlib: 3.8.3
cartopy: None
seaborn: None
numbagg: 0.8.1
fsspec: 2024.2.0
cupy: None
pint: None
sparse: None
flox: 0.9.10
numpy_groupies: 0.11.2
setuptools: 70.2.0
pip: 24.1.2
conda: None
pytest: 8.0.2
mypy: None
IPython: 8.22.1
sphinx: 5.3.0

dcherian commented 1 month ago

Yes, the issue is here: https://github.com/pydata/xarray/blob/a04d857a03d1fb04317d636a7f23239cb9034491/xarray/core/groupby.py#L751-L752

Setting np.nan as the fill value there will promote the output to the default float type for your platform. That's why setting dtype alone doesn't work.
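The promotion itself can be seen with plain NumPy, independent of xarray or flox. The sketch below is only an illustration of the dtype rule, not the code path xarray takes; the `present` mask and the variable names are made up for the example.

```python
import numpy as np

# int16 data plus a mask saying which entries actually have a value
# (both are illustrative stand-ins, not xarray internals).
data = np.array([1, 2, 3], dtype="int16")
present = np.array([True, False, True])

# int16 cannot represent NaN, so a NaN fill forces a floating-point result.
filled_nan = np.where(present, data, np.array(np.nan))

# An integer fill value fits in int16, so the dtype is kept.
filled_int = np.where(present, data, np.int16(99))

print(filled_nan.dtype, filled_int.dtype)
```

This matches the behaviour reported above: with an explicit integer fill_value there is nothing that requires a float, so the requested integer dtype can be honoured.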

Using flox directly preserves the dtype without any extra work:

import flox.xarray

flox.xarray.xarray_reduce(ds, "idx", func="max")

so perhaps use that for now.
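For intuition on why a grouped max does not need a NaN fill when every group has at least one member, here is a NumPy-only sketch using the same values and labels as the example dataset. It is an assumption-laden illustration and is not how flox or xarray implement the reduction.

```python
import numpy as np

# Same values and group labels as the dataset in the report.
values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16")
idx = np.array([1, 2, 1])  # group label for each row along "x"

# Map labels to consecutive group codes 0..n_groups-1.
labels, codes = np.unique(idx, return_inverse=True)

# Start from the smallest int16 so any real value replaces it.
out = np.full(
    (labels.size, values.shape[1]),
    np.iinfo(values.dtype).min,
    dtype=values.dtype,
)

# Unbuffered in-place maximum: out[codes[i]] = max(out[codes[i]], values[i]).
np.maximum.at(out, codes, values)

print(labels)     # group labels, one per output row
print(out)        # per-group maxima
print(out.dtype)  # dtype of the input is kept
```

Because no group is empty, the initial sentinel is always overwritten and the result stays int16; a fill value only matters for groups with no observations.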