pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0
3.63k stars 1.09k forks source link

Cannot apply `sum` to custom object arrays with `min_count=1` if n_dims <= sum_dims #9755

Open brynpickering opened 1 week ago

brynpickering commented 1 week ago

What happened?

I'm storing Pyomo objects in xarray arrays to allow for vectorised operations with named arrays. E.g., I might have an array of Pyomo variables indexed over the costs dimension and I want to sum them. Before v2024.6.0 this was possible without any issues. Now, if I include the min_count argument in my summation then it fails when trying to apply numpy type promotion to the Pyomo object.

What did you expect to happen?

The xarray.core.array_api_compat.result_type function should fall back to numpy.object_ when arbitrary Python objects are present.

Minimal Complete Verifiable Example

# Minimal reproduction: sum a DataArray holding a Pyomo variable, with and
# without ``min_count``, and report which calls raise TypeError.
import numpy as np  # needed for np.nan below; absent from the original snippet
import pyomo.kernel as pk
import xarray as xr

# One Pyomo variable between two NaN entries along the "costs" dimension.
da = xr.DataArray([np.nan, pk.variable(lb=0, ub=10), np.nan], dims=("costs",))

# Case 1: reduce over the only dimension, no min_count.
try:
    da.sum("costs")
    print("Success without min_count")
except TypeError as e:
    print(f"Failure without min_count: {e}")

# Case 2: same reduction, but with min_count=1.
try:
    da.sum("costs", min_count=1)
    print("Success with min_count")
except TypeError as e:
    print(f"Failure with min_count: {e}")

# Case 3: add an extra "foo" dimension first so that summing over "costs"
# leaves one dimension in the result, then reduce with min_count=1.
try:
    da.expand_dims(foo=["bar"]).sum("costs", min_count=1)
    print("Success with extra dim and min_count")
except TypeError as e:
    print(f"Failure with extra dim and min_count: {e}")

MVCE confirmation

Relevant log output

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/_aggregations.py:3175, in DataArrayAggregations.sum(self, dim, skipna, min_count, keep_attrs, **kwargs)
   3087 def sum(
   3088     self,
   3089     dim: Dims = None,
   (...)
   3094     **kwargs: Any,
   3095 ) -> Self:
   3096     """
   3097     Reduce this DataArray's data by applying ``sum`` along some dimension(s).
   3098 
   (...)
   3173     array(8.)
   3174     """
-> 3175     return self.reduce(
   3176         duck_array_ops.sum,
   3177         dim=dim,
   3178         skipna=skipna,
   3179         min_count=min_count,
   3180         keep_attrs=keep_attrs,
   3181         **kwargs,
   3182     )

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/dataarray.py:3839, in DataArray.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   3795 def reduce(
   3796     self,
   3797     func: Callable[..., Any],
   (...)
   3803     **kwargs: Any,
   3804 ) -> Self:
   3805     """Reduce this array by applying `func` along some dimension(s).
   3806 
   3807     Parameters
   (...)
   3836         summarized data and the indicated dimension(s) removed.
   3837     """
-> 3839     var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   3840     return self._replace_maybe_drop_dims(var)

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/variable.py:1677, in Variable.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   1670 keep_attrs_ = (
   1671     _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
   1672 )
   1674 # Note that the call order for Variable.mean is
   1675 #    Variable.mean -> NamedArray.mean -> Variable.reduce
   1676 #    -> NamedArray.reduce
-> 1677 result = super().reduce(
   1678     func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs
   1679 )
   1681 # return Variable always to support IndexVariable
   1682 return Variable(
   1683     result.dims, result._data, attrs=result._attrs if keep_attrs_ else None
   1684 )

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/namedarray/core.py:916, in NamedArray.reduce(self, func, dim, axis, keepdims, **kwargs)
    912     if isinstance(axis, tuple) and len(axis) == 1:
    913         # unpack axis for the benefit of functions
    914         # like np.argmin which can't handle tuple arguments
    915         axis = axis[0]
--> 916     data = func(self.data, axis=axis, **kwargs)
    917 else:
    918     data = func(self.data, **kwargs)

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/duck_array_ops.py:447, in _create_nan_agg_method.<locals>.f(values, axis, skipna, **kwargs)
    445     with warnings.catch_warnings():
    446         warnings.filterwarnings("ignore", "All-NaN slice encountered")
--> 447         return func(values, axis=axis, **kwargs)
    448 except AttributeError:
    449     if not is_duck_dask_array(values):

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/nanops.py:101, in nansum(a, axis, dtype, out, min_count)
     99 result = sum_where(a, axis=axis, dtype=dtype, where=mask)
    100 if min_count is not None:
--> 101     return _maybe_null_out(result, axis, mask, min_count)
    102 else:
    103     return result

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/nanops.py:34, in _maybe_null_out(result, axis, mask, min_count)
     32 elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
     33     null_mask = mask.size - duck_array_ops.sum(mask)
---> 34     result = where(null_mask < min_count, np.nan, result)
     36 return result

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/duck_array_ops.py:354, in where(condition, x, y)
    352 """Three argument where() with better dtype promotion rules."""
    353 xp = get_array_namespace(condition)
--> 354 return xp.where(condition, *as_shared_dtype([x, y], xp=xp))

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/duck_array_ops.py:258, in as_shared_dtype(scalars_or_arrays, xp)
    252     xp = get_array_namespace(scalars_or_arrays)
    254 # Pass arrays directly instead of dtypes to result_type so scalars
    255 # get handled properly.
    256 # Note that result_type() safely gets the dtype from dask arrays without
    257 # evaluating them.
--> 258 dtype = dtypes.result_type(*scalars_or_arrays, xp=xp)
    260 return [asarray(x, dtype=dtype, xp=xp) for x in scalars_or_arrays]

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/dtypes.py:266, in result_type(xp, *arrays_and_dtypes)
    262 if xp is None:
    263     xp = get_array_namespace(arrays_and_dtypes)
    265 types = {
--> 266     array_api_compat.result_type(preprocess_types(t), xp=xp)
    267     for t in arrays_and_dtypes
    268 }
    269 if any(isinstance(t, np.dtype) for t in types):
    270     # only check if there's numpy dtypes – the array API does not
    271     # define the types we're checking for
    272     for left, right in PROMOTE_TO_OBJECT:

File ~/miniforge3/envs/calliope/lib/python3.12/site-packages/xarray/core/array_api_compat.py:42, in result_type(xp, *arrays_and_dtypes)
     38 def result_type(*arrays_and_dtypes, xp) -> np.dtype:
     39     if xp is np or any(
     40         isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
     41     ):
---> 42         return xp.result_type(*arrays_and_dtypes)
     43     else:
     44         return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)

TypeError: Cannot interpret '<pyomo.core.kernel.variable.variable object at 0x16b1dc7d0>' as a data type

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS ------------------ commit: None python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ] python-bits: 64 OS: Darwin OS-release: 23.2.0 machine: arm64 processor: arm byteorder: little LC_ALL: None LANG: en_US.UTF-8 LOCALE: ('en_US', 'UTF-8') libhdf5: 1.14.3 libnetcdf: 4.9.2 xarray: 2024.10.0 pandas: 2.2.2 numpy: 2.0.2 scipy: 1.14.0 netCDF4: 1.6.5 pydap: None h5netcdf: None h5py: None zarr: 2.18.2 cftime: 1.6.4 nc_time_axis: None iris: None bottleneck: 1.4.0 dask: 2024.8.1 distributed: None matplotlib: None cartopy: None seaborn: None numbagg: None fsspec: 2024.6.1 cupy: None pint: None sparse: 0.15.4 flox: 0.9.8 numpy_groupies: 0.11.1 setuptools: 70.0.0 pip: 24.0 conda: None pytest: 8.2.2 mypy: None IPython: 8.25.0