pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Reduction operations fail with Weighted DataArrayWeighted #9841

Open 3tilley opened 3 days ago

3tilley commented 3 days ago

What happened?

I have a Dataset containing several DataArrays. This set-up is extremely useful to me, as I can filter and perform operations over the whole dataset and all shared dimensions. One of the DataArrays is weighted, and I was hoping this would be handled automatically in groupbys and general reduction operations, but calling mean on the dataset throws the error below.

I'm happy to raise a PR to fix if I can work out how to do it, but I just want to make sure that it's agreed that this isn't correct behaviour.

What did you expect to happen?

I would like DataArrays that are unweighted to return the usual mean, and DataArrayWeighted to return a mean reflecting its weights, as if I'd just called da_weighted.mean(). This would allow me to calculate means in groupbys on the Dataset.

Minimal Complete Verifiable Example

import numpy as np
import xarray as xr

da = xr.DataArray(
    data=[[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]}
)

dw = xr.DataArray(
    data=[0.1, 0.2, 0.3],
    dims=["t"],
    coords={"t": [0, 1, 2]}
)
db = da.copy(deep=True).weighted(dw)  # db is a DataArrayWeighted, not a DataArray
ds = xr.Dataset({"a": da, "b": db, "w": dw})

# This works
print(db.mean())

# Errors
print(ds["b"].mean())

# Errors
print(ds.mean(dim="t"))

# Errors
print(ds.groupby_bins("t", bins=[0, 2.5]).mean())

MVCE confirmation

Relevant log output

python weighted_demo.py
<xarray.DataArray ()> Size: 8B
array(3.)
Traceback (most recent call last):
  File <redact>, line 22, in <module>
    print(ds["b"].mean())
          ^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/_aggregations.py", line 2982, in mean
    return self.reduce(
           ^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/dataarray.py", line 3839, in reduce
    var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/variable.py", line 1677, in reduce
    result = super().reduce(
             ^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/namedarray/core.py", line 918, in reduce
    data = func(self.data, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/duck_array_ops.py", line 680, in mean
    return _mean(array, axis=axis, skipna=skipna, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/duck_array_ops.py", line 447, in f
    return func(values, axis=axis, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/nanops.py", line 124, in nanmean
    return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/nanops.py", line 117, in _nanmean_ddof_object
    data = np.sum(value, axis=axis, dtype=dtype, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 2389, in sum
    return _wrapreduction(
           ^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 86, in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: float() argument must be a string or a real number, not 'DataArrayWeighted'

Anything else we need to know?

There are several other functions that might fall into this category, like std, but I think they could all be handled similarly.

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.12.3 (main, Jul 31 2024, 17:43:48) [GCC 13.2.0]
python-bits: 64
OS: Linux
OS-release: 5.15.153.1-microsoft-standard-WSL2
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: C.UTF-8
LOCALE: ('C', 'UTF-8')
libhdf5: 1.14.4
libnetcdf: None

xarray: 2024.10.0
pandas: 2.2.3
numpy: 2.0.2
scipy: 1.14.1
netCDF4: None
pydap: None
h5netcdf: 1.4.0
h5py: 3.12.1
zarr: None
cftime: None
nc_time_axis: None
iris: None
bottleneck: 1.4.2
dask: None
distributed: None
matplotlib: 3.9.2
cartopy: None
seaborn: 0.13.2
numbagg: None
fsspec: None
cupy: None
pint: None
sparse: 0.15.4
flox: None
numpy_groupies: None
setuptools: 75.3.0
pip: 24.0
conda: None
pytest: 8.3.3
mypy: 1.13.0
IPython: 8.29.0
sphinx: None
welcome[bot] commented 3 days ago

Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!

max-sixty commented 1 day ago

Very much agree with your suggestion.

It's likely worth spiking a quick example of what the code changes would look like. Responding quickly from memory — it's plausible that the current dataset code doesn't expect data variables to be of different types, and that the dataset handles the .mean operation itself rather than delegating it down to each data variable. If that's the case, it might require fairly large changes with lots of `if isweighted` checks, which wouldn't be great.

If it does delegate the .mean to each data variable (or we could change it to do that), then this could work quite nicely. And might also be generalizable to reduction operations on other arrays, such as sparse arrays...
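For illustration, the delegation idea can be sketched at the user level. This is not xarray's internal reduce path, and the helper name `mean_per_variable` is hypothetical — it just shows what "delegate .mean to each data variable" would mean in behavioural terms:

```python
import numpy as np
import xarray as xr


def mean_per_variable(ds: xr.Dataset, **kwargs) -> xr.Dataset:
    """Delegate the reduction to each data variable in turn.

    A variable carrying its own reduction logic (e.g. a weighted
    wrapper) could then apply it here, instead of going through the
    generic Dataset-level reduce path.
    """
    return xr.Dataset(
        {name: var.mean(**kwargs) for name, var in ds.data_vars.items()}
    )


da = xr.DataArray(
    [[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
ds = xr.Dataset({"a": da})

# For plain DataArrays this matches Dataset.mean exactly.
print(mean_per_variable(ds, dim="t"))
```

For datasets of ordinary DataArrays this is equivalent to `ds.mean(dim="t")`; the point is that a per-variable loop gives each variable a chance to supply its own reduction.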

Does that make sense?

dcherian commented 23 hours ago

I would like DataArrays that are unweighted to return the usual mean,

I don't think this is a good idea. Consider the case when the weights Dataset is mistakenly missing a couple of data vars. Then you'll unintentionally get unweighted means and not know about it!

You might consider simply adding a scalar 1 in the weights Dataset for any missing data var.
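One way to realize this today, since `Dataset.weighted` applies a single weights DataArray to every data variable: keep an explicit per-variable weights mapping and fall back to constant weights of 1. The `weights` dict and the comprehension below are a user-level workaround sketch, not existing xarray behaviour:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    [[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
dw = xr.DataArray([0.1, 0.2, 0.3], dims=["t"], coords={"t": [0, 1, 2]})
ds = xr.Dataset({"a": da, "b": da.copy(deep=True)})

# Explicit weights per data variable; anything not listed falls back
# to a constant weight of 1, so its weighted mean equals the ordinary
# mean and no variable is silently left unweighted.
weights = {"b": dw}
ones = xr.ones_like(dw)

result = xr.Dataset(
    {
        name: var.weighted(weights.get(name, ones)).mean(dim="t")
        for name, var in ds.data_vars.items()
    }
)
print(result)
```

Here "a" gets the plain mean (uniform weights) while "b" gets the weighted one, and forgetting a variable shows up as an explicit fallback rather than a silent unweighted result.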

max-sixty commented 19 hours ago

Consider the case when the weights Dataset is mistakenly missing a couple of data vars. Then you'll unintentionally get unweighted means and not know about it!

I'm interpreting this differently — the dataset has some data variables that are weighted and some that are unweighted. There's no `ds.weighted(ds_weights)` call here where a missing data variable in `ds_weights` would silently produce an unweighted variable, is there?

Instead it's db = da.weighted(dw), where db is an array, and that array is assigned to the dataset.

(When I'm confused during a discussion of ours, three times out of four it's me who's missing something, so I'm asking from the perspective of likely being wrong but hopefully nonetheless being helpful.)