xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

"most common" Aggregator with Dask #263

Closed: BSchilperoort closed this issue 3 days ago

BSchilperoort commented 1 year ago

Hi! Thanks for developing Flox, it has been quite useful in our workflows.

We have been working on a method to reduce categorical data with flox, using a "most common" strategy.

    most_common = Aggregation(
        name="most_common",
        numpy=_custom_grouped_reduction,
        chunk=None,
        combine=None,
    )

    result = flox.xarray.xarray_reduce(
        data,
        *coords,
        func=most_common,
        expected_groups=bounds,
    )

This works very well for in-memory datasets; however, I have not been able to get it to work on Dask data.

If the input is dask data, the following error is returned

NotImplementedError: Aggregation 'most_common' is only implemented for dask arrays when method='blockwise'.

If method="blockwise" is set, then the following error is returned.

InvalidIndexError: Reindexing only valid with uniquely valued Index objects

I have tried to implement chunk= and combine=, but I don't think that is possible with this method.

Memory limit of xarray_reduce

Additionally, because of the size of our input dataset, the group_idx array does not fit in memory: its size is dim1 * dim2 * ... * sizeof(np.int64) bytes.
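For a sense of scale, with made-up grid dimensions (not our actual dataset):

```py
import numpy as np

# hypothetical high-resolution grid: 100,000 x 50,000 cells
dim1, dim2 = 100_000, 50_000

# one int64 group index per cell
group_idx_bytes = dim1 * dim2 * np.dtype(np.int64).itemsize
print(group_idx_bytes / 1e9)  # 40.0 -> roughly 40 GB, far too large to hold in memory
```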

Our current workaround is to apply the reduction on subset slices of the source data and target data, and then merging them. This works OK but is quite ugly and slow.

Is there another way to avoid this limit?


Most common implementation:

```py
def _most_common_label(neighbors: np.ndarray) -> int:
    """Find the most common label in a neighborhood.

    Note that if more than one label has the same (highest) frequency,
    the first label in the list will be picked.
    """
    unique_labels, counts = np.unique(neighbors, return_counts=True)
    return unique_labels[np.argmax(counts)]


def _custom_grouped_reduction(
    group_idx: np.ndarray,
    array: np.ndarray,
    *,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
) -> np.ndarray:
    """Custom grouped reduction for flox.Aggregation to get the most common label.

    Args:
        group_idx : integer codes for group labels (1D)
        array : values to reduce (nD)
        axis : axis of array along which to reduce.
            Requires array.shape[axis] == len(group_idx)
        size : expected number of groups.
            If None, output.shape[-1] == number of uniques in group_idx
        fill_value : fill_value for when the number of groups in group_idx is less than size
        dtype : dtype of output

    Returns:
        np.ndarray with array.shape[-1] == size, containing a single value per group
    """
    return npg.aggregate_numpy.aggregate(
        group_idx,
        array,
        func=_most_common_label,
        axis=axis,
        size=size,
        fill_value=fill_value,
        dtype=dtype,
    )
```
dcherian commented 1 year ago

Ah, now this is interesting!

The general solution is hard and approximate (I think). You'll need to implement something like a count-min sketch (I'm sure there are others).

The easier way: You'll have to decompose the problem and compute each unique item and associated count in chunk (like you already do), merge those intermediates together in combine and then pick the top-most in finalize.
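A minimal sketch of that decomposition on plain Python dictionaries (illustrative only; the helper names are hypothetical, and wiring the paired (value, count) intermediates into a flox Aggregation is the hard part):

```py
import numpy as np


def chunk_step(values: np.ndarray) -> dict:
    # per-chunk intermediate for one group: unique values and their counts
    uniques, counts = np.unique(values, return_counts=True)
    return dict(zip(uniques.tolist(), counts.tolist()))


def combine_step(a: dict, b: dict) -> dict:
    # merge two intermediates by summing counts per unique value
    merged = dict(a)
    for value, count in b.items():
        merged[value] = merged.get(value, 0) + count
    return merged


def finalize_step(intermediate: dict):
    # pick the value with the highest total count
    return max(intermediate, key=intermediate.get)


# toy usage: one group's data split across two chunks
chunk1, chunk2 = np.array([1, 1, 2]), np.array([2, 2, 3])
print(finalize_step(combine_step(chunk_step(chunk1), chunk_step(chunk2))))  # 2
```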

Are there useful properties of the field you're collapsing or the field you're grouping by (example plots of both these fields would be useful)? For example, are there only a small number of unique labels?

BSchilperoort commented 1 year ago

Hi Deepak, thank you for your reply.

The easier way: You'll have to decompose the problem and compute each unique item and associated count in chunk (like you already do), merge those intermediates together in combine and then pick the top-most in finalize.

I attempted to do this; however, I am unsure how exactly to make it work. Below is the code, but I am not sure how to wrap the individual functions so that they are compatible with Aggregation.

I could not make much sense of the documentation/docstrings beyond the very simple examples that are available. I also got some errors about setting an element with a sequence (as each chunk returns an array of unique values instead of a single value).

Code:

```py
def unique_labels(a: np.ndarray) -> np.ndarray:
    labels = np.unique(a)
    return labels


def unique_counts(a: np.ndarray) -> np.ndarray:
    _, counts = np.unique(a, return_counts=True)
    return counts


def most_common_chunked(multi_values: np.ndarray, multi_counts: np.ndarray, **kwargs):
    all_values, index = np.unique(multi_values, return_inverse=True)
    all_counts = np.zeros(all_values.size, np.int64)
    np.add.at(all_counts, index, multi_counts.ravel())  # inplace
    return all_values[all_counts.argmax()]


most_common = Aggregation(
    name="most_common",
    numpy=_custom_grouped_reduction,
    chunk=(unique_labels, unique_counts),  # first compute blockwise
    combine=(wrap_stack, wrap_stack),  # stack these intermediate results
    finalize=most_common_chunked,  # get most common value from the combined result
    fill_value=0,
)
```

Are there useful properties for the field you're collapsing or the field you're grouping by (example plots of both these fields would be useful)? For example are there only a small number of unique labels?

Our specific use case is a high resolution land cover dataset, which we want to be able to regrid easily. Of course once the "most common" strategy works, knowing the 2nd (or n-th) most common is also interesting.

Due to the nature of our dataset it will have a limited number of unique labels.

dcherian commented 1 year ago

Unfortunately, this will need some major thinking. You'll have to handle the unique and count intermediates together, similar to how the argreduction is run. This is not trivial.

Is it possible to rechunk so that a blockwise solution works? Can you describe your problem precisely? A reproducible example would be best...

Due to the nature of our dataset it will have a limited number of unique labels.

OK that's good for the exact nature of this solution.

dcherian commented 1 year ago

Also np.unique flattens the array. Is that OK for your purposes?
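For illustration (a quick check with a small 2-D array):

```py
import numpy as np

a = np.array([[3, 1], [2, 3]])
print(np.unique(a))  # [1 2 3] -- the 2-D structure is lost and duplicates are dropped
```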

dcherian commented 12 months ago

#269 should fix your blockwise problems. I'm thinking I can just add support for mode applied blockwise. Would that solve your use case?

BSchilperoort commented 12 months ago

Hi Deepak, thank you for your replies and the modifications to the code. I am currently too busy with other projects to focus on this, but I will get back to it and try it out as soon as I can make some time!


Also np.unique flattens the array. Is that OK for your purposes?

Yes that is fine.

A reproducible example would be best...

I do have a demo notebook on our repository where I show a basic use case. It is not easily reproducible as it requires a very large dataset at the moment, but it might give you a better idea of what we are trying to achieve.

dcherian commented 5 months ago

Ah I've been dense. We can simply compute the histogram and then get the mode:

```py
import flox
import numpy as np
import pandas as pd

array = np.array([[0, 0, 0, 1, 1, 5, 4, 5, 5, 3], [0, 0, 0, 1, 1, 3, 4, 5, 5, 3]])
array[1, :] += 4
by = np.array([0, 0, 0, 1, 1, 2, 2, 2, 3, 3])
```

The default only works with numpy and problems that can be executed "blockwise"

```py
flox.groupby_reduce(array, by, func="mode", axis=-1)
```
```
(array([[0, 1, 5, 3],
        [4, 5, 7, 7]]),
 array([0, 1, 2, 3]))
```

This is the more general implementation that should always work regardless of chunking

```py
result, uniques, groups = flox.groupby_reduce(
    array, array, np.broadcast_to(by, array.shape), func="count", fill_value=0, axis=-1
)
uniques[result.argmax(axis=-2)], groups
```
```
(array([[0, 1, 5, 3],
        [4, 5, 7, 7]]),
 array([0, 1, 2, 3]))
```

With dask, you'll need to provide the unique values of array in expected_groups (this could be relaxed).
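A sketch of what that might look like with the toy arrays above (untested; in a real workflow you would pass the known category values instead of calling np.unique on in-memory data):

```py
import dask.array as da
import flox
import numpy as np

darr = da.from_array(array, chunks=(1, 5))

result, uniques, groups = flox.groupby_reduce(
    darr,
    darr,  # group by the data values themselves to build the histogram
    np.broadcast_to(by, array.shape),
    func="count",
    fill_value=0,
    axis=-1,
    # with dask input, both sets of group labels must be known up front
    expected_groups=(np.unique(array), np.unique(by)),
)
mode = uniques[result.argmax(axis=-2).compute()]
```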

dcherian commented 3 weeks ago

Responding to https://github.com/xarray-contrib/flox/pull/391#issuecomment-2333640949

Here, estimate the histogram manually and find the mode. You'll need to tell it what category values you expect (here, np.arange(0, 6)).

```py
import flox.xarray
import numpy as np
import pandas as pd
import xarray as xr

histogram = flox.xarray.xarray_reduce(
    xr.ones_like(data, dtype=bool),
    data.astype(int),  # important, needs to be int
    *coords,
    func="count",
    expected_groups=(pd.Index(np.arange(0, 6)), *bounds),
    fill_value=-1,
)
mode = histogram.idxmax("soil_type")
```

I think this is right.

In theory, we could enable passing just the number of expected categories, and discover the values at compute time (https://github.com/xarray-contrib/flox/issues/15), but that's a bit more complexity.

BSchilperoort commented 2 weeks ago

Here, estimate the histogram manually and find the mode.

Nice, this works and seems to compute lazily without too many issues!

The memory usage of the delayed objects does seem to be quite high though. In the ARCO ERA5 example, a time dim size of 700k makes Python use 4 GB of RAM (2 GB for the output of flox.xarray.xarray_reduce, another 2 GB for the object produced by histogram.idxmax("soil_type")). It scales strongly with the size of the time dimension, but not at all with the size of the latitude & longitude bins :thinking:

In theory, we could enable passing just the number of expected categories, and discover the values at compute time (https://github.com/xarray-contrib/flox/issues/15), but that's a bit more complexity.

I think it's fine to have users specify the expected categories. That's no problem for this use-case.

dcherian commented 2 weeks ago

Ah I made a mistake: you'll need to specify dim=("latitude", "longitude"), to preserve the time dimension. Right now it will reduce across time too which would explain your memory usage comment
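With that fix, the earlier snippet becomes something like the following (same hypothetical variable names as before; dim tells xarray_reduce which dimensions to reduce over):

```py
histogram = flox.xarray.xarray_reduce(
    xr.ones_like(data, dtype=bool),
    data.astype(int),
    *coords,
    func="count",
    expected_groups=(pd.Index(np.arange(0, 6)), *bounds),
    fill_value=-1,
    dim=("latitude", "longitude"),  # reduce over space only, keep the time dimension
)
mode = histogram.idxmax("soil_type")
```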

BSchilperoort commented 2 weeks ago

Ah I made a mistake: you'll need to specify dim=("latitude", "longitude"), to preserve the time dimension. Right now it will reduce across time too which would explain your memory usage comment

I did already modify your suggestion by adding the dim kwarg. There still seems to be high memory use depending on the size of the time dimension, though. With some more testing, it seems to scale linearly with the size of the time dim (even though the aggregation only takes place over latitude and longitude).

Here's a new gist: https://gist.github.com/BSchilperoort/134e625fe5a67805883bed9f93aa72b2

Need me to open an issue at flox?

dcherian commented 3 days ago

This was fixed in xarray-regrid. And I think xarray can expose the same thing as histogram