xarray-contrib / pint-xarray

Interface for using pint with xarray, providing convenience accessors
https://pint-xarray.readthedocs.io/en/latest/
Apache License 2.0

Chunked pint arrays break on `rolling()` #186

Open riley-brady opened 1 year ago

riley-brady commented 1 year ago

Hi folks,

I noticed that running .rolling(...) on a chunked pint array raises an exception that breaks the process:

TypeError: `pad_value` must be composed of integral typed values.

I outline three different cases below for running .rolling() on a pint-aware DataArray.

  1. Calculating the rolling sum on an in-memory pint array.
  2. Calculating the rolling sum on a chunked pint array, using xarray chunking.
    • This works, even without turning off bottleneck. However, this isn't an optimal solution for me, since one cannot query ds.pint.units on an xarray-chunked pint array. I like being able to do that for various QOL checks in a data pipeline.
  3. Calculating the rolling sum on a pint array chunked with ds.pint.chunk(...).
    • This method preserves the units, but leads to the traceback seen above and in full detail below. It also breaks when turning off bottleneck.
import pint_xarray
import xarray as xr
print(xr.__version__)
>>> '2022.6.0'
print(pint_xarray.__version__)
>>> '0.3'

data = xr.DataArray(range(3), dims='time').pint.quantify('kelvin')
print(data)
>>> <xarray.DataArray (time: 3)>
>>> <Quantity([0 1 2], 'kelvin')>

# Case 1: rolling sum with `pint` units. 
# Lose the units as expected, but executes properly.
rs = data.rolling(time=2).sum()
print(rs)
>>> <xarray.DataArray (time: 3)>
>>> array([nan,  1.,  3.])

# Case 2: rolling sum with `xr.chunk()`
# Maintain the units after compute, 
# but `data_xr_chunk.pint.units` returns `None` in the interim
data_xr_chunk = data.chunk({'time': 1})
rs = data_xr_chunk.rolling(time=2).sum().compute()
print(rs)
>>> <xarray.DataArray (time: 3)>
>>> <Quantity([nan  1.  3.], 'kelvin')>

# Case 3: rolling sum with `xr.pint.chunk()`
# Maintains units on chunked array, but raises exception
# (see full traceback below)
data_pint_chunk = data.pint.chunk({"time": 1})
rs = data_pint_chunk.rolling(time=2).sum().compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 rs = data_pint_chunk.rolling(time=2).sum().compute()

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:155, in Rolling._reduce_method.<locals>.method(self, keep_attrs, **kwargs)
    151 def method(self, keep_attrs=None, **kwargs):
    153     keep_attrs = self._get_keep_attrs(keep_attrs)
--> 155     return self._numpy_or_bottleneck_reduce(
    156         array_agg_func,
    157         bottleneck_move_func,
    158         rolling_agg_func,
    159         keep_attrs=keep_attrs,
    160         fillna=fillna,
    161         **kwargs,
    162     )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:589, in DataArrayRolling._numpy_or_bottleneck_reduce(self, array_agg_func, bottleneck_move_func, rolling_agg_func, keep_attrs, fillna, **kwargs)
    586     kwargs.setdefault("skipna", False)
    587     kwargs.setdefault("fillna", fillna)
--> 589 return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:472, in DataArrayRolling.reduce(self, func, keep_attrs, **kwargs)
    470 else:
    471     obj = self.obj
--> 472 windows = self._construct(
    473     obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
    474 )
    476 result = windows.reduce(
    477     func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
    478 )
    480 # Find valid windows based on count.

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:389, in DataArrayRolling._construct(self, obj, window_dim, stride, fill_value, keep_attrs, **window_dim_kwargs)
    384 window_dims = self._mapping_to_list(
    385     window_dim, allow_default=False, allow_allsame=False  # type: ignore[arg-type]  # https://github.com/python/mypy/issues/12506
    386 )
    387 strides = self._mapping_to_list(stride, default=1)
--> 389 window = obj.variable.rolling_window(
    390     self.dim, self.window, window_dims, self.center, fill_value=fill_value
    391 )
    393 attrs = obj.attrs if keep_attrs else {}
    395 result = DataArray(
    396     window,
    397     dims=obj.dims + tuple(window_dims),
   (...)
    400     name=obj.name,
    401 )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:2314, in Variable.rolling_window(self, dim, window, window_dim, center, fill_value)
   2311     else:
   2312         pads[d] = (win - 1, 0)
-> 2314 padded = var.pad(pads, mode="constant", constant_values=fill_value)
   2315 axis = [self.get_axis_num(d) for d in dim]
   2316 new_dims = self.dims + tuple(window_dim)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:1416, in Variable.pad(self, pad_width, mode, stat_length, constant_values, end_values, reflect_type, **pad_width_kwargs)
   1413 if reflect_type is not None:
   1414     pad_option_kwargs["reflect_type"] = reflect_type
-> 1416 array = np.pad(  # type: ignore[call-overload]
   1417     self.data.astype(dtype, copy=False),
   1418     pad_width_by_index,
   1419     mode=mode,
   1420     **pad_option_kwargs,
   1421 )
   1423 return type(self)(self.dims, array)

File <__array_function__ internals>:180, in pad(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/quantity.py:1730, in Quantity.__array_function__(self, func, types, args, kwargs)
   1729 def __array_function__(self, func, types, args, kwargs):
-> 1730     return numpy_wrap("function", func, args, kwargs, types)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:936, in numpy_wrap(func_type, func, args, kwargs, types)
    934 if name not in handled or any(is_upcast_type(t) for t in types):
    935     return NotImplemented
--> 936 return handled[name](*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:660, in _pad(array, pad_width, mode, **kwargs)
    656     if key in kwargs:
    657         kwargs[key] = _recursive_convert(kwargs[key], units)
    659 return units._REGISTRY.Quantity(
--> 660     np.pad(array._magnitude, pad_width, mode=mode, **kwargs), units
    661 )

File <__array_function__ internals>:180, in pad(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/core.py:1762, in Array.__array_function__(self, func, types, args, kwargs)
   1759 if has_keyword(da_func, "like"):
   1760     kwargs["like"] = self
-> 1762 return da_func(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:1229, in pad(array, pad_width, mode, **kwargs)
   1227 elif mode == "constant":
   1228     kwargs.setdefault("constant_values", 0)
-> 1229     return pad_edge(array, pad_width, mode, **kwargs)
   1230 elif mode == "linear_ramp":
   1231     kwargs.setdefault("end_values", 0)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in pad_edge(array, pad_width, mode, **kwargs)
    957 def pad_edge(array, pad_width, mode, **kwargs):
    958     """
    959     Helper function for padding edges.
    960 
    961     Handles the cases where the only the values on the edge are needed.
    962     """
--> 964     kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
    966     result = array
    967     for d in range(array.ndim):

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in <dictcomp>(.0)
    957 def pad_edge(array, pad_width, mode, **kwargs):
    958     """
    959     Helper function for padding edges.
    960 
    961     Handles the cases where the only the values on the edge are needed.
    962     """
--> 964     kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
    966     result = array
    967     for d in range(array.ndim):

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:910, in expand_pad_value(array, pad_value)
    908     pad_value = array.ndim * (tuple(pad_value[0]),)
    909 else:
--> 910     raise TypeError("`pad_value` must be composed of integral typed values.")
    912 return pad_value

TypeError: `pad_value` must be composed of integral typed values.

My interim workaround is to do something like:

units = data.pint.units
data = data.pint.dequantify()
rs = data.rolling(time=2).sum()
rs = rs.pint.quantify(units)
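For reuse across a pipeline, the round-trip above can be wrapped in a small helper. This is a hypothetical convenience function (the name `rolling_sum_with_units` is mine), assuming pint-xarray is installed and `da` is a quantified DataArray:

```python
def rolling_sum_with_units(da, **rolling_kwargs):
    """Rolling sum that strips units, reduces, then reattaches them.

    Sketch of the workaround above: dequantify to plain numpy/dask,
    run the rolling sum, then quantify again with the original units.
    """
    import pint_xarray  # noqa: F401  (registers the .pint accessor)

    units = da.pint.units
    stripped = da.pint.dequantify()
    result = stripped.rolling(**rolling_kwargs).sum()
    return result.pint.quantify(units)
```

e.g. `rolling_sum_with_units(data, time=2)` in place of `data.rolling(time=2).sum()`.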
riley-brady commented 1 year ago

Another side note that came up here -- I'm curious if there's any roadmap plan for recognizing integration of units for methods like rolling().sum().

E.g.,

data = xr.DataArray(range(3), dims='time').pint.quantify('mm/day')
data.pint.units
>>> mm/day
data = data.rolling(time=2).sum()
data.pint.units
>>> mm
keewis commented 1 year ago

thanks for the report, @riley-brady. It seems that xarray operations on pint+dask are not as thoroughly tested as pint and dask on their own. I think this is a bug in pint (or dask, not sure): we enable force_ndarray_like to convert scalars to 0d arrays, which means that the final call to np.pad becomes:

np.pad(magnitude, pad_width, mode="constant", constant_values=np.array(0))

numpy seems to be fine with that, but dask is not.

@jrbourbeau, what do you think? Would it make sense to extend expand_pad_value to unpack 0d arrays (using .item()), or would you rather have the caller (pint, in this case) do that?
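To make the mismatch concrete, here is a numpy-only sketch: numpy accepts a 0-d array as `constant_values`, while dask's `expand_pad_value` only accepts plain scalars or nested sequences. The `unpack_pad_value` helper below is an illustration of the proposed `.item()` fix, not dask's actual code:

```python
import numpy as np

# numpy itself is fine with a 0-d array as `constant_values`:
padded = np.pad(np.arange(3), (1, 0), mode="constant",
                constant_values=np.array(0))
print(padded)  # [0 0 1 2]

# dask's expand_pad_value rejects the 0-d array that pint's
# force_ndarray_like produces. The proposed fix, sketched here,
# unpacks 0-d arrays back to scalars before padding:
def unpack_pad_value(value):
    if isinstance(value, np.ndarray) and value.ndim == 0:
        return value.item()
    return value

print(unpack_pad_value(np.array(0)))  # 0
```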

keewis commented 1 year ago

I'm curious if there's any roadmap plan for recognizing integration of units for methods like rolling().sum()

I'm not sure I follow. Why would rolling().sum() work similarly to integration, when all it does is compute a windowed sum? I'm not sure if this actually counts as integration, but you can multiply the result of the rolling sum by the diff of the time coordinate (which is a bit tricky because time is an indexed coordinate):

data = xr.DataArray(
    range(3), dims="time", coords={"time2": ("time", [1, 2, 3])}
).pint.quantify("mm/day", time2="day")
dt = data.time2.pad(time=(1, 0)).diff(dim="time")
data.rolling(time=2).sum() * dt

and then you would have the correct units (with the same numerical result, because I chose the time coordinate to have increments of 1 day)

riley-brady commented 1 year ago

Thanks for the quick feedback on this issue @keewis.

Also thanks for the demo with .diff(). You're right about the integration assumptions. In my specific use case I am doing a rolling sum of units mm/day with daily time steps, so in this case it should reflect total precip in mm, but that's not a fair assumption for many other cases. I'll give the .diff() method a try.

keewis commented 8 months ago

this has been fixed in dask for quite a while now, but I'll leave this open until we have tests for it (probably after copying the test suite from xarray)