xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

Issues when running `flox` with chunked pint arrays #163

Open riley-brady opened 2 years ago

riley-brady commented 2 years ago

After upgrading to the latest xarray version and installing flox, I find that chunked pint arrays break with the .resample() method. I'm posting this here instead of in pint_xarray because the traceback suggests the error comes from flox.

I imagine this has to do with the complexity of supporting duck arrays such as the pint Quantities that pint_xarray wraps.

Possible related threads:

import xarray as xr
import pint_xarray
import flox

xr.__version__
>>> '2022.9.0'

pint_xarray.__version__
>>> '0.3'

flox.__version__
>>> '0.5.9'

time_ax = xr.cftime_range('2020-06-01 01:00:00', freq='H', periods=3)
ds = xr.DataArray(range(3), dims='time', coords={'time': time_ax})

# Simple case, no dask or pint
ds.resample(time="D").mean()
>>> <xarray.DataArray (time: 1)>
>>> array([1.])
>>> Coordinates:
>>>   * time     (time) object 2020-06-01 00:00:00

# Dask case
ds_chunked = ds.chunk({'time': 1})
ds_chunked.resample(time="D").mean().compute()
>>> <xarray.DataArray (time: 1)>
>>> array([1.])
>>> Coordinates:
>>>   * time     (time) object 2020-06-01 00:00:00

# Pint case
ds_pint = ds.pint.quantify('kelvin')
ds_pint.resample(time="D").mean()
>>> <xarray.DataArray (time: 1)>
>>> <Quantity([1.], 'kelvin')>
>>> Coordinates:
>>>   * time     (time) object 2020-06-01 00:00:00

# Pint with xarray chunk
ds_pint_chunk = ds_pint.chunk({'time': 1})
ds_pint_chunk.resample(time="D").mean().compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [25], in <cell line: 3>()
      1 # Pint with xarray chunk
      2 ds_pint_chunk = ds_pint.chunk({'time': 1})
----> 3 ds_pint_chunk.resample(time="D").mean().compute()

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataarray.py:1083, in DataArray.compute(self, **kwargs)
   1064 """Manually trigger loading of this array's data from disk or a
   1065 remote source into memory and return a new array. The original is
   1066 left unaltered.
   (...)
   1080 dask.compute
   1081 """
   1082 new = self.copy(deep=False)
-> 1083 return new.load(**kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataarray.py:1057, in DataArray.load(self, **kwargs)
   1039 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1040     """Manually trigger loading of this array's data from disk or a
   1041     remote source into memory and return this array.
   1042 
   (...)
   1055     dask.compute
   1056     """
-> 1057     ds = self._to_temp_dataset().load(**kwargs)
   1058     new = self._from_temp_dataset(ds)
   1059     self._variable = new._variable

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/dataset.py:734, in Dataset.load(self, **kwargs)
    731 import dask.array as da
    733 # evaluate all the dask arrays simultaneously
--> 734 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    736 for k, data in zip(lazy_data, evaluated_data):
    737     self.variables[k].data = data

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    597     keys.append(x.__dask_keys__())
    598     postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
    601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/toolz/functoolz.py:487, in Compose.__call__(self, *args, **kwargs)
    486 def __call__(self, *args, **kwargs):
--> 487     ret = self.first(*args, **kwargs)
    488     for f in self.funcs:
    489         ret = f(ret)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/core.py:689, in chunk_reduce(array, by, func, expected_groups, axis, fill_value, dtype, reindex, engine, kwargs, sort)
    687     result = reduction(group_idx, array, **kwargs)
    688 else:
--> 689     result = generic_aggregate(
    690         group_idx, array, axis=-1, engine=engine, func=reduction, **kwargs
    691     ).astype(dt, copy=False)
    692 if np.any(props.nanmask):
    693     # remove NaN group label which should be last
    694     result = result[..., :-1]

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/aggregations.py:49, in generic_aggregate(group_idx, array, engine, func, axis, size, fill_value, dtype, **kwargs)
     44 else:
     45     raise ValueError(
     46         f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
     47     )
---> 49 return method(
     50     group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
     51 )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/flox/aggregate_flox.py:33, in _np_grouped_op(group_idx, array, op, axis, size, fill_value, dtype, out)
     26     out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
     28 if (len(uniques) == size) and (uniques == np.arange(size)).all():
     29     # The previous version of this if condition
     30     #     ((uniques[1:] - uniques[:-1]) == 1).all():
     31     # does not work when group_idx is [1, 2] for e.g.
     32     # This happens  during binning
---> 33     op.reduceat(array, inv_idx, axis=axis, dtype=dtype, out=out)
     34 else:
     35     out[..., uniques] = op.reduceat(array, inv_idx, axis=axis, dtype=dtype)

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(<ufunc 'add'>, 'reduceat', <Quantity([0], 'kelvin')>, array([0]), axis=-1, dtype=dtype('int64'), out=(array([0]),)): 'Quantity', 'ndarray', 'ndarray'

dcherian commented 2 years ago

OK, that won't work, but it should not have gone down this code path at all in xarray. It looks like I only tested pure pint arrays, not pint + dask: https://github.com/pydata/xarray/blob/50ea159bfd0872635ebf4281e741f3c87f0bef6b/xarray/core/utils.py#L980

It'd be nice to add full pint support here, but it'll take some effort. Are you interested in working on it?

riley-brady commented 1 year ago

@dcherian sorry for the delay here. I could work on this, but unfortunately only on weekends, so it might be a slow process. If you have time (either over chat here or on a Zoom call), I would appreciate some guidance on which parts of the code to target, since I haven't worked closely with the package.

My current workaround is to dequantify, run resample() (or whichever other method triggers this), and then quantify again. It isn't ideal, but it works. The error message is not very clear, though, so I'm not sure this is a sustainable solution for the community as a whole.

dcherian commented 1 year ago

Thanks for offering to help, @riley-brady!

> My current solution is to dequantify, run resample() or whichever other method this is happening on, and then quantify, which isn't ideal but works.

I think this is what we'll have to do, since pint apparently doesn't handle ufunc methods like `reduceat` well.

Let's strip the array's units (if any) right at the beginning and reapply them at the end: https://github.com/xarray-contrib/flox/blob/e3ea0e75a19c867aa6d6858cdf94810c0741a74b/flox/core.py#L1641

  1. We won't be handling units on `by`, but I think that's OK for now? Alternatively, you could again dequantify and then quantify after compute.
  2. To figure out the output units, we'll have to run `getattr(numpy, agg.name)(Quantity(numpy.ones(2, dtype=array.dtype), array.units))`, i.e. run the aggregation on a small problem to determine what the output units are (necessary for e.g. any, all, var, and arg*), and apply them at the end. This approach won't work for "custom aggregations", but we can deal with that later when we need to.