pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Allow grouping by dask variables #2852

Open jmichel-otb opened 5 years ago

jmichel-otb commented 5 years ago

Code Sample, a copy-pastable example if possible

I am using xarray in combination with dask distributed on a cluster, so a minimal code sample demonstrating my problem is not easy to come up with.

Problem description

Here is what I observe:

  1. In my environment, dask distributed is correctly set up with auto-scaling. I can verify this by loading data into xarray and using aggregation functions like mean(). This triggers auto-scaling, and the dask dashboard shows that the processing is spread across slave nodes.

  2. I have the following xarray dataset called geoms_ds:

    
    <xarray.Dataset>
    Dimensions:  (x: 10980, y: 10980)
    Coordinates:
    * y        (y) float64 4.9e+06 4.9e+06 4.9e+06 ... 4.79e+06 4.79e+06 4.79e+06
    * x        (x) float64 3e+05 3e+05 3e+05 ... 4.098e+05 4.098e+05 4.098e+05
    Data variables:
    label    (y, x) uint16 dask.array<shape=(10980, 10980), chunksize=(200, 10980)>

Which I load with the following code:

```python
import xarray as xr

geoms = xr.open_rasterio('test_rasterization_T31TCJ_uint16.tif', chunks={'band': 1, 'x': 10980, 'y': 200})
geoms_squeez = geoms.isel(band=0).squeeze().drop(labels='band')
geoms_ds = geoms_squeez.to_dataset(name='label')
```

This array holds a finite number of integer values denoting groups (or classes if you like). I would like to compute per-group statistics (using additional variables), for instance the mean value of a given variable for each group.

  1. I can do this for a single group using .where(label=xxx).mean('variable'). This behaves as expected, triggering auto-scaling and building a dask task graph.

  2. The problem is that I have a lot of groups (or classes), and looping through all of them and applying where() is not very efficient. From my reading of the xarray documentation, groupby is what I need to compute statistics on all groups at once.

  3. When I try to use geoms_ds.groupby('label').size() for instance, here is what I observe:

    • Grouping is not lazy, it is evaluated immediately,
    • Grouping is not performed through dask distributed, only the master node is working, on a single thread,
    • The grouping operation takes a large amount of time and eats a large amount of memory (nearly 30 GB, which is a lot more than what is required to store the full dataset in memory),
    • Most of the time, the grouping fails with the following errors and warnings:
distributed.utils_perf - WARNING - full garbage collections took 52% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 47% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 48% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 50% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 53% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 56% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 56% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 58% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 58% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 59% CPU time recently (threshold: 10%)
WARNING:dask_jobqueue.core:Worker tcp://10.135.39.92:51747 restart in Job 2758934. This can be due to memory issue.
distributed.utils - ERROR - 'tcp://10.135.39.92:51747'
Traceback (most recent call last):
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/utils.py", line 648, in log_errors
    yield
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/scheduler.py", line 1360, in add_worker
    yield self.handle_worker(comm=comm, worker=address)
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/tornado/gen.py", line 1133, in run
    value = future.result()
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/scheduler.py", line 2220, in handle_worker
    worker_comm = self.stream_comms[worker]
KeyError: ...

I assume this comes from the fact that the process is killed by PBS for excessive memory usage.

Expected Output

I would expect the following:

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.6.7 | packaged by conda-forge | (default, Nov 21 2018, 03:09:43) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-327.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.2
xarray: 0.11.3
pandas: 0.24.1
numpy: 1.16.1
scipy: 1.2.0
netCDF4: 1.4.2
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.0.3.4
PseudonetCDF: None
rasterio: 1.0.15
cfgrib: None
iris: None
bottleneck: None
cyordereddict: None
dask: 1.1.1
distributed: 1.25.3
matplotlib: 3.0.2
cartopy: 0.17.0
seaborn: 0.9.0
setuptools: 40.7.1
pip: 19.0.1
conda: None
pytest: None
IPython: 7.1.1
sphinx: None

rabernat commented 5 years ago

    label    (y, x) uint16 dask.array<shape=(10980, 10980), chunksize=(200, 10980)>
    ...
    geoms_ds.groupby('label')

It is very hard to make this sort of groupby lazy, because you are grouping over the variable label itself. Groupby uses a split-apply-combine paradigm to transform the data. The apply and combine steps can be lazy. But the split step cannot. Xarray uses the group variable to determine how to index the array, i.e. which items belong in which group. To do this, it needs to read the whole variable into memory.
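
To make the split step concrete, here is a conceptual NumPy sketch of split-apply-combine. It is an illustration only, not xarray's actual implementation, and `grouped_mean_by_split`, `values` and `labels` are made-up names. Building the per-group index lists needs the concrete label values, which is why a lazy `label` variable would have to be loaded first:

```python
import numpy as np


def grouped_mean_by_split(values, labels):
    """Conceptual split-apply-combine (hypothetical helper, not xarray code)."""
    # Split: deciding which positions belong to which group needs the
    # actual label values, so `labels` must already be in memory.
    keys, inverse = np.unique(labels, return_inverse=True)
    inverse = inverse.ravel()
    flat = values.ravel()
    groups = [np.flatnonzero(inverse == i) for i in range(keys.size)]

    # Apply: reduce each group independently.
    means = [flat[idx].mean() for idx in groups]

    # Combine: collect the per-group results along a new "label" axis.
    return keys, np.asarray(means)


# Toy data (made up for illustration).
values = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
labels = np.array([[7, 7, 9], [9, 9, 7]], dtype="uint16")
keys, means = grouped_mean_by_split(values, labels)
```

Only the apply and combine parts could be deferred here; the `np.unique` call is the part that forces the labels to be read.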

In this specific example, it sounds like what you want is to compute the histogram of labels. That could be accomplished without groupby. For example, you could use apply_ufunc together with dask.array.histogram.

So my recommendation is to think of a way to accomplish what you want that does not involve groupby.
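
As a rough sketch of the histogram idea, the snippet below counts pixels per label with dask.array.histogram directly on a dask array (without the apply_ufunc wrapper). It assumes the labels are the integers 1..n_labels and that n_labels is known in advance; the toy array and its chunking are made up:

```python
import numpy as np
import dask.array as da

# Toy stand-in for the `label` variable (values and chunking are made up).
labels = da.from_array(
    np.array([[1, 1, 2], [3, 2, 2]], dtype="uint16"), chunks=(1, 3)
)

# Assumption: labels are the integers 1..n_labels, known in advance.
n_labels = 3

# One bin per integer label: edges at 0.5, 1.5, ..., n_labels + 0.5.
counts, edges = da.histogram(
    labels.ravel(), bins=n_labels, range=(0.5, n_labels + 0.5)
)

# `counts` is still lazy here; compute() runs the per-chunk histograms.
label_counts = counts.compute()
```

Because the bin edges are fixed up front, the output size is known without reading the labels, so the whole computation can stay lazy until compute() is called.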

shoyer commented 5 years ago

The current design of GroupBy.apply() in xarray is entirely ignorant of dask: it simply uses a for loop over the grouped variable to build up a computation with high-level array operations.

This makes operations that group over large keys stored in dask inefficient. This could be done efficiently (dask.dataframe does this, and might be worth trying in your case), but it's a more challenging distributed computing problem, and xarray's current data model would not know how large a dimension to create for the returned arrays (doing this properly would require supporting arrays with unknown dimension sizes).
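
The unknown-size problem can be seen in a small dask sketch. The array below is a made-up stand-in for the labels; the point is that the number of unique labels, and therefore the length of the output dimension, is not known until the data has been computed:

```python
import math

import numpy as np
import dask.array as da

# Toy stand-in for a lazy label variable (made up for illustration).
labels = da.from_array(np.array([3, 1, 3, 2, 1], dtype="uint16"), chunks=2)

# The unique labels are only known after computing, so dask reports an
# unknown (NaN) length for the lazy result.
lazy_unique = da.unique(labels)
unknown_length = lazy_unique.shape[0]

# Only after compute() is the size of the would-be "label" dimension known.
n_groups = lazy_unique.compute().size
```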

jmichel-otb commented 5 years ago

Many thanks for your answers, @shoyer and @rabernat.

I am relatively new to xarray and dask, and I am trying to determine whether they can fit our need to analyze large stacks of Sentinel data on our cluster.

I will give dask.array.histogram a try, as @rabernat suggested.

I also had the following idea: I do not actually need the discovery of unique labels that groupby() performs. What I really need is an efficient way to perform multiple where() aggregate operations at once, to avoid traversing the data multiple times.

Maybe there is already something like that in xarray, or maybe this is something I can derive from the implementation of where() ?
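
One way to sketch the "many where() aggregations in one pass" idea is with np.bincount, assuming the labels are small non-negative integers with a known upper bound (`n_labels` here). This is a plain NumPy illustration, not an existing xarray feature, and `partial_sums` is a hypothetical helper. Per-label sums and counts are additive, so each chunk can be processed independently and the partial results added together:

```python
import numpy as np


def partial_sums(values, labels, n_labels):
    """Per-label sums and counts for one block, in a single pass (hypothetical helper)."""
    flat_labels = labels.ravel()
    # minlength keeps the output length identical for every block,
    # even if some labels do not occur in that block.
    sums = np.bincount(flat_labels, weights=values.ravel(), minlength=n_labels + 1)
    counts = np.bincount(flat_labels, minlength=n_labels + 1)
    return sums, counts


# Toy data (made up); labels are assumed to be the integers 1..n_labels.
values = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
labels = np.array([[1, 1, 2], [3, 2, 2]], dtype="uint16")
n_labels = 3

# Treat each row as a separate "chunk" and accumulate the partial results.
sums = np.zeros(n_labels + 1)
counts = np.zeros(n_labels + 1, dtype=np.int64)
for block_values, block_labels in zip(values, labels):
    block_sums, block_counts = partial_sums(block_values, block_labels, n_labels)
    sums += block_sums
    counts += block_counts

# Index 0 is unused because the labels start at 1. A label that never
# occurs would have a zero count and produce NaN here.
group_means = sums[1:] / counts[1:]
```

The data is traversed once regardless of the number of labels, at the cost of having to state the label range up front.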

dcherian commented 5 years ago

It sounds like there is an apply_ufunc solution to your problem, but I don't know how to write it! ;)

shoyer commented 5 years ago

Roughly how many unique labels do you have?

jmichel-otb commented 5 years ago

That's a tough question ;) In the current dataset I have 950 unique labels, but in my use cases it can be a lot more (e.g. agricultural crops) or a lot fewer (administrative boundaries or regions).

C-H-Simpson commented 4 years ago

I'm going to share a code snippet that might be useful to people reading this issue. I wanted to group my data by month and year, and take the mean for each group.

I did not want to use resample, as I wanted the dimensions to be ('month', 'year') rather than ('time'). The obvious way of doing this is to use a pd.MultiIndex to create a 'year_month' stacked coordinate, but I found this did not have good performance.

My solution was to use xr.apply_ufunc, as suggested above. I think it should be OK with dask chunked data, provided it is not chunked in time.

Here is the code:

import numpy as np
import xarray as xr


def _grouped_mean(
        data: np.ndarray,
        months: np.ndarray,
        years: np.ndarray) -> np.ndarray:
    """Similar to grouping by a year_month MultiIndex, but faster.

    Should be used wrapped by _wrapped_grouped_mean.
    """
    # np.unique already returns its result sorted.
    unique_months = np.unique(months)
    unique_years = np.unique(years)

    # The trailing time axis is replaced by (month, year) axes.
    new_shape = data.shape[:-1] + (unique_months.size, unique_years.size)
    output = np.zeros(new_shape)

    for i_month, j_year in np.ndindex(unique_months.size, unique_years.size):
        # Time steps that fall in this (month, year) pair.
        in_group = (
            (months == unique_months[i_month])
            & (years == unique_years[j_year])
        )
        output[..., i_month, j_year] = np.mean(data[..., in_group], axis=-1)

    return output


def _wrapped_grouped_mean(da: xr.DataArray) -> xr.DataArray:
    """Similar to grouping by a year_month MultiIndex, but faster.

    Wraps a numpy-style function with xr.apply_ufunc.
    """
    # Core dims are moved to the end, so _grouped_mean sees time last.
    Y = xr.apply_ufunc(
        _grouped_mean,
        da,
        da.time.dt.month,
        da.time.dt.year,
        input_core_dims=[['lat', 'lon', 'time'], ['time'], ['time']],
        output_core_dims=[['lat', 'lon', 'month', 'year']],
    )
    # Label the new axes with the month and year values they represent.
    Y = Y.assign_coords(
        {'month': np.unique(da.time.dt.month),
         'year': np.unique(da.time.dt.year)})
    return Y

rabernat commented 4 years ago

👀 cc @chiaral

stale[bot] commented 2 years ago

In order to maintain a list of currently relevant issues, we mark issues as stale after a period of inactivity

If this issue remains relevant, please comment here or remove the stale label; otherwise it will be marked as closed automatically

dcherian commented 2 years ago

You can do this with flox now. Eventually we can update xarray to support grouping by a dask variable.

The limitation will be that the user will have to provide "expected groups" so that we can construct the output coordinate.

riley-brady commented 3 days ago

Bringing in a related MVE from another thread with @dcherian on https://github.com/xarray-contrib/flox/issues/398.

Here's an example comparing a high-resolution dummy dataset between flox and xarray.GroupBy(). Trying to implicitly run UniqueGrouper() on my grid with 18 unique integers is crashing the cluster due to the underlying np.unique() call. Meanwhile, using flox.xarray.xarray_reduce with expected_groups can handle this whole aggregation in just a few seconds.

At least in this example, the required expected_groups kwarg is a very minor headache, since I know the confines of my integer mask grid.

import flox.xarray
import xarray as xr
import numpy as np
import dask.array as da

np.random.seed(123)

# Simulating 1km global grid
lat = np.linspace(-89.1, 89.1, 21384)
lon = np.linspace(-180, 180, 43200)

# Simulating data we'll be aggregating
data = da.random.random((lat.size, lon.size), chunks=(3600, 3600))
data = xr.DataArray(data, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Simulating 18 unique groups on the grid to aggregate over
integer_mask = da.random.choice(np.arange(1, 19), size=(lat.size, lon.size), chunks=(3600, 3600))
integer_mask = xr.DataArray(integer_mask, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Add as coordinate
data = data.assign_coords(dict(label1=integer_mask))

# Try with groupby (usually will spike scheduler memory, crash cluster, etc.). Haven't done a lot
# of looking at what's going on to wreck the cluster, just get impatient and give up.
# gb = data.groupby("label1")

# Versus, with expected groups. Runs extremely quickly to set up graph + execute.
res = flox.xarray.xarray_reduce(data, "label1", func="mean", skipna=True, expected_groups=np.arange(1, 19))
dcherian commented 2 days ago

Fixed in #9522 for reductions with flox. Everything else will fail :)

@bradyrx your example takes 15-20s to set up on my machine due to some useless stacking of dimensions, that we don't need to do. Something to fix in the future...

riley-brady commented 2 days ago

Awesome, @dcherian , thanks for jumping on this!! Looks like a long-time issue that needed a nice MVE and some more push. I can also git checkout your branch and run with my cluster setup for comparison. Might not be til early next week.