Open jmichel-otb opened 5 years ago
label    (y, x) uint16 dask.array<shape=(10980, 10980), chunksize=(200, 10980)>
...
geoms_ds.groupby('label')
It is very hard to make this sort of groupby lazy, because you are grouping over the variable label itself. Groupby uses a split-apply-combine paradigm to transform the data. The apply and combine steps can be lazy, but the split step cannot: xarray uses the group variable to determine how to index the array, i.e. which items belong in which group, and to do this it needs to read the whole variable into memory.
In this specific example, it sounds like what you want is to compute the histogram of labels. That could be accomplished without groupby. For example, you could use apply_ufunc together with dask.array.histogram.
So my recommendation is to think of a way to accomplish what you want that does not involve groupby.
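A minimal sketch of that suggestion, with a made-up label array and bin edges (the original data isn't shown, so the sizes and label range here are purely illustrative):

```python
import dask.array as da
import numpy as np

# Hypothetical stand-in for the real label raster
labels = da.random.randint(0, 18, size=(2000, 2000), chunks=(200, 2000))

# Bin edges 0..18 give one bin per integer label; the whole
# computation stays lazy until .compute()
hist, edges = da.histogram(labels, bins=np.arange(19))
counts = hist.compute()
```

Unlike groupby, this never materializes the label variable in one piece: each chunk is histogrammed independently and the per-chunk counts are summed.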
The current design of GroupBy.apply() in xarray is entirely ignorant of dask: it simply uses a for loop over the grouped variable to build up a computation with high-level array operations. This makes operations that group over large keys stored in dask inefficient. This could be done efficiently (dask.dataframe does this, and might be worth trying in your case), but it's a more challenging distributed computing problem, and xarray's current data model would not know how large a dimension to create for the returned arrays (doing this properly would require supporting arrays with unknown dimension sizes).
Many thanks for your answers @shoyer and @rabernat.
I am relatively new to xarray and dask; I am trying to determine whether they can fit our need for analyzing large stacks of Sentinel data on our cluster.
I will give dask.array.histogram a try, as @rabernat suggested.
I also had the following idea. Given that .where(label=xxx).mean('variable') does the job perfectly for one label, I do not actually need the discovery of unique labels that groupby() performs. What I really need is an efficient way to perform multiple where() aggregate operations at once, to avoid traversing the data multiple times. Maybe there is already something like that in xarray, or maybe this is something I can derive from the implementation of where()?
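For reference, the per-label where() loop being described might look like this (toy data and label counts; the real dataset is far larger and dask-backed):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
ds = xr.Dataset(
    {"var": (("y", "x"), rng.random((100, 100))),
     "label": (("y", "x"), rng.integers(0, 5, (100, 100)))})

# One full pass over the data per label: correct, but the data is
# traversed once for every unique label value
means = {lab: ds["var"].where(ds["label"] == lab).mean().item()
         for lab in np.unique(ds["label"].values)}
```

Each where() call masks the non-matching cells with NaN, and mean() skips NaNs by default, so each dictionary entry is the mean of one label's cells.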
It sounds like there is an apply_ufunc solution to your problem, but I don't know how to write it! ;)
Roughly how many unique labels do you have?
That's a tough question ;) In the current dataset I have 950 unique labels, but in my use cases it can be a lot more (e.g. agricultural crops) or a lot less (administrative boundaries or regions).
I'm going to share a code snippet that might be useful to people reading this issue. I wanted to group my data by month and year, and take the mean for each group.
I did not want to use resample, as I wanted the dimensions to be ('month', 'year') rather than ('time'). The obvious way of doing this is to use a pd.MultiIndex to create a 'year_month' stacked coordinate; I found this did not have good performance. My solution was to use xr.apply_ufunc, as suggested above. I think it should be OK with dask-chunked data, provided it is not chunked in time.
Here is the code:
```python
def _grouped_mean(
        data: np.ndarray,
        months: np.ndarray,
        years: np.ndarray) -> np.ndarray:
    """Similar to grouping by a year_month MultiIndex, but faster.

    Should be used wrapped by _wrapped_grouped_mean."""
    unique_months = np.sort(np.unique(months))
    unique_years = np.sort(np.unique(years))
    old_shape = list(data.shape)
    new_shape = old_shape[:-1]
    new_shape.append(unique_months.shape[0])
    new_shape.append(unique_years.shape[0])
    output = np.zeros(new_shape)
    for i_month, j_year in np.ndindex(output.shape[2:]):
        indices = np.intersect1d(
            (months == unique_months[i_month]).nonzero(),
            (years == unique_years[j_year]).nonzero())
        output[:, :, i_month, j_year] = \
            np.mean(data[:, :, indices], axis=-1)
    return output
```
```python
def _wrapped_grouped_mean(da: xr.DataArray) -> xr.DataArray:
    """Similar to grouping by a year_month MultiIndex, but faster.

    Wraps a numpy-style function with xr.apply_ufunc."""
    Y = xr.apply_ufunc(
        _grouped_mean,
        da,
        da.time.dt.month,
        da.time.dt.year,
        input_core_dims=[['lat', 'lon', 'time'], ['time'], ['time']],
        output_core_dims=[['lat', 'lon', 'month', 'year']],
    )
    Y = Y.assign_coords(
        {'month': np.sort(np.unique(da.time.dt.month)),
         'year': np.sort(np.unique(da.time.dt.year))})
    return Y
```
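A self-contained toy run of this pattern, with a condensed grouped-mean kernel and synthetic monthly data (the names and sizes here are illustrative, not from the snippet above):

```python
import numpy as np
import pandas as pd
import xarray as xr

def grouped_mean(data, months, years):
    # Condensed version of the _grouped_mean idea: one mean per
    # (month, year) pair, taken along the trailing time axis
    um, uy = np.unique(months), np.unique(years)
    out = np.zeros(data.shape[:-1] + (um.size, uy.size))
    for i, m in enumerate(um):
        for j, y in enumerate(uy):
            out[..., i, j] = data[..., (months == m) & (years == y)].mean(axis=-1)
    return out

# Two years of monthly data on a small lat/lon grid
times = pd.date_range("2000-01-01", periods=24, freq="MS")
arr = xr.DataArray(
    np.random.rand(3, 4, 24),
    dims=["lat", "lon", "time"], coords={"time": times})

res = xr.apply_ufunc(
    grouped_mean, arr, arr.time.dt.month, arr.time.dt.year,
    input_core_dims=[["lat", "lon", "time"], ["time"], ["time"]],
    output_core_dims=[["lat", "lon", "month", "year"]],
)
```

apply_ufunc moves the listed core dims to the end of each input, calls the plain-numpy function, and infers the sizes of the new 'month' and 'year' dims from the returned array's shape.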
👀 cc @chiaral
You can do this with flox now. Eventually we can update xarray to support grouping by a dask variable.
The limitation will be that the user will have to provide "expected groups" so that we can construct the output coordinate.
Bringing in a related MVE from another thread with @dcherian on https://github.com/xarray-contrib/flox/issues/398.
Here's an example comparing a high-resolution dummy dataset between flox and xarray.GroupBy(). Trying to implicitly run UniqueGrouper() on my grid with 18 unique integers crashes the cluster due to the underlying np.unique() call. Meanwhile, using flox.xarray.xarray_reduce with expected_groups handles the whole aggregation in just a few seconds.
At least in this example, the required expected_groups kwarg is minimal headache, since I know the confines of my integer mask grid.
```python
import flox.xarray
import xarray as xr
import numpy as np
import dask.array as da

np.random.seed(123)

# Simulating a 1 km global grid
lat = np.linspace(-89.1, 89.1, 21384)
lon = np.linspace(-180, 180, 43200)

# Simulating the data we'll be aggregating
data = da.random.random((lat.size, lon.size), chunks=(3600, 3600))
data = xr.DataArray(data, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Simulating 18 unique groups on the grid to aggregate over
integer_mask = da.random.choice(np.arange(1, 19), size=(lat.size, lon.size), chunks=(3600, 3600))
integer_mask = xr.DataArray(integer_mask, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Add as a coordinate
data = data.assign_coords(dict(label1=integer_mask))

# With groupby (usually spikes scheduler memory, crashes the cluster, etc.).
# Haven't done a lot of looking at what's going on to wreck the cluster,
# just get impatient and give up.
# gb = data.groupby("label1")

# Versus, with expected groups: runs extremely quickly to set up the graph + execute.
res = flox.xarray.xarray_reduce(data, "label1", func="mean", skipna=True, expected_groups=np.arange(1, 19))
```
Fixed in #9522 for reductions with flox. Everything else will fail :)
@bradyrx your example takes 15-20s to set up on my machine due to some useless stacking of dimensions, that we don't need to do. Something to fix in the future...
Awesome, @dcherian , thanks for jumping on this!! Looks like a long-time issue that needed a nice MVE and some more push. I can also git checkout your branch and run with my cluster setup for comparison. Might not be til early next week.
Code Sample, a copy-pastable example if possible
I am using xarray in combination with dask distributed on a cluster, so a minimal code sample demonstrating my problem is not easy to come up with.
Problem description
Here is what I observe:
In my environment, dask distributed is correctly set up with auto-scaling. I can verify this by loading data into xarray and using aggregation functions like mean(). This triggers auto-scaling, and the dask dashboard shows that the processing is spread across slave nodes.
I have the following xarray dataset called geoms_ds:
This array holds a finite number of integer values denoting groups (or classes if you like). I would like to perform statistics on the groups (with additional variables), such as the mean value of a given variable for each group.
I can do this perfectly for a single group using .where(label=xxx).mean('variable'); this behaves as expected, triggering auto-scaling and a dask graph of tasks.
The problem is that I have a lot of groups (or classes), and looping through all of them to apply where() is not very efficient. From my reading of the xarray documentation, groupby is what I need to perform stats on all groups at once.
When I try to use geoms_ds.groupby('label').size() for instance, here is what I observe, which I assume comes from the fact that the process is killed by PBS for excessive memory usage.
Expected Output
I would expect groupby to be lazily evaluated, with the computation distributed by dask distributed.
Output of xr.show_versions()