pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Replace push implementation with map_overlap for Dask #9712

Closed: phofl closed this 1 week ago

phofl commented 2 weeks ago

Our benchmarks here showed us that ffill alone adds 4.5 million tasks to the graph, which isn't great (the dataset has 550k chunks, so a multiplication factor of 9).

Rewriting this with map_overlap gets this down to 1.5 million tasks, which is basically the number of chunks times 3, the minimum we can get to at the moment.

We merged a few map_overlap improvements on the Dask side today to make this possible; it now gives a nice improvement (and also simplifies the code on the xarray side).
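For readers of the thread, here is a rough sketch of the map_overlap idea (a hypothetical push_overlap helper, not the actual diff in this PR): take a limit-deep halo from the preceding chunk along the fill axis, forward-fill each padded block with bottleneck, and let map_overlap trim the halo again. This only works when the fill distance is bounded and the halo reaches the last valid value, which is the limitation discussed below.

import bottleneck
import dask.array as da
import numpy as np


def push_overlap(array, n, axis):
    # halo of n elements on the "before" side of `axis`, nothing elsewhere
    depth = {ax: (n, 0) if ax == axis else (0, 0) for ax in range(array.ndim)}
    return da.map_overlap(
        bottleneck.push,  # per-block forward fill; n and axis are forwarded to it
        array,
        n=n,
        axis=axis,
        depth=depth,
        boundary=np.nan,  # pad the very first chunk with NaN
        dtype=array.dtype,
    )


arr = da.from_array(np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan]), chunks=2)
print(push_overlap(arr, n=2, axis=0).compute())  # [ 1.  1.  1. nan  5.  5.]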

cc @dcherian

dcherian commented 2 weeks ago

This can't work in general.

We switched from map_overlap to cumreduction intentionally: https://github.com/pydata/xarray/pull/6118 . I fully support improving cumreduction. It's a fundamental parallel primitive.
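For anyone following along, a minimal pure-NumPy illustration (not the dask implementation) of the scan structure cumreduction provides for ffill: fill each chunk on its own, then let a binary combine carry the last filled value of the previous chunk into the next one.

import numpy as np


def fill_block(block):
    # forward-fill one chunk in isolation (stand-in for bottleneck.push / _push)
    idx = np.where(~np.isnan(block), np.arange(block.size), 0)
    np.maximum.accumulate(idx, out=idx)
    return block[idx]


def carry(prev_last, block):
    # binop: wherever the chunk is still NaN, take the value carried from the left
    return np.where(np.isnan(block), prev_last, block)


chunks = [np.array([1.0, np.nan, np.nan]), np.array([np.nan, 5.0, np.nan])]
last = np.nan  # identity: nothing to carry into the first chunk
out = []
for block in chunks:
    filled = carry(last, fill_block(block))
    out.append(filled)
    last = filled[-1]  # carried into the next chunk
print(np.concatenate(out))  # [1. 1. 1. 1. 5. 5.]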

dcherian commented 2 weeks ago

See this test: https://github.com/pydata/xarray/blob/a00bc919edffc1e5fb5703b2dcd6a4749acf2619/xarray/tests/test_duck_array_ops.py#L1008-L1028

phofl commented 2 weeks ago

Oh, I didn't consider None as the limit, sorry about that.

dcherian commented 2 weeks ago

You could experiment with method="blelloch" vs method="sequential" in cumreduction.
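Concretely, that is the method= keyword of da.reductions.cumreduction; a hedged sketch of where it plugs into the call xarray makes (the same call quoted further down in this thread), with push_scan as a hypothetical wrapper name:

import dask.array as da
import numpy as np

from xarray.core.duck_array_ops import _push


def _fill_with_last_one(a, b):
    # carry the previous chunk's last values into this chunk's remaining NaNs
    return np.where(np.isnan(b), a, b)


def push_scan(array, axis, method="sequential"):
    # method="sequential" chains the chunk combines one after another;
    # method="blelloch" uses the work-efficient up-sweep/down-sweep tree scan
    return da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
        method=method,
    )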

phofl commented 2 weeks ago

I thought more about this; I hadn't considered large limit values properly. cumreduction itself works OK-ish; the issue that makes the task graph so large is this section:


if n is not None and 0 < n < array.shape[axis] - 1:
    # index along `axis`, used to measure the distance to the last non-NaN value
    arange = da.broadcast_to(
        da.arange(
            array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
        ).reshape(
            tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
        ),
        array.shape,
        array.chunks,
    )
    valid_arange = da.where(da.notnull(array), arange, np.nan)
    valid_limits = (arange - push(valid_arange, None, axis)) <= n
    # omit the forward fills that violate the limit
    return da.where(valid_limits, push(array, None, axis), np.nan)

A bunch of operations that can't be fused properly because of the interactions between the two different arrays. I'll think a bit more about whether we can reduce this somehow, but nothing obvious comes to mind right away (at least not to me).

cumreduction itself is equivalent to map_overlap from a topology perspective if the overlapping part only reaches a single neighbouring chunk. Would you be open to calling overlap in these cases? It makes things a bit uglier, but that seems to be a reasonably common use case from what I have seen so far (I might be totally wrong here).

dcherian commented 2 weeks ago

Can you add push to dask.array? Then you can add whatever optimization you want :). We'd be happy to dispatch instead of vendoring all this if we can.

Also, you should be able to write this as a single cumreduction that takes a 1D array of axis-indices and the input array as inputs. I wrote it for grouped ffill in flox: https://github.com/xarray-contrib/flox/blob/672be8ceeebfa588a15ebdc9861999efa12fa44e/flox/aggregations.py#L651

phofl commented 2 weeks ago

Oh, I'll check that one out.

Sure, there shouldn't be any reason not to add this. I'll check out the flox implementation; getting the number of tasks down would be nice, but adding it in Dask should be a good option anyway.

dcherian commented 2 weeks ago

Ah I guess the core issue is a dependency on numbagg and/or bottleneck.

phofl commented 2 weeks ago

That should be fine; we can just raise if neither is installed, similar to what you are doing here.
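Roughly what that could look like on the Dask side (hypothetical helper name and a sketch of the fallback logic, not dask's actual code):

def _ffill_block(block, n, axis):
    # prefer numbagg, fall back to bottleneck, raise if neither is available
    try:
        import numbagg

        return numbagg.ffill(block, limit=n, axis=axis)
    except ImportError:
        pass
    try:
        import bottleneck

        return bottleneck.push(block, n=n, axis=axis)
    except ImportError:
        raise ImportError(
            "push/ffill requires either numbagg or bottleneck to be installed"
        )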

josephnowak commented 1 week ago

Hi,

When I wrote that code I tried to preserve map_overlap as much as possible because it is, in theory, faster. The main issue is that it is not efficient in the general case: it is always possible that more than one consecutive chunk is entirely NaN, in which case the push operation does not reach the following chunks, and to make it work for the general case you would have to run it N - 1 times, where N is the number of chunks.

If the main issue is that the code I wrote to apply a limit on the forward fill generates many tasks, I think you could replace it with a kind of cumulative-reset logic (I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulated value). This operation would accumulate a count of contiguous NaNs and reset the count to 0 whenever it finds a non-NaN. Please see the code attached below; I think it illustrates the idea and it is also functional. The only issue is that it requires Numba, and I'm not sure you want to force users to have Numba installed.

def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push

    def _fill_with_last_one(a, b):
        # cumreduction applies the push func to every block first, so the only
        # missing part is filling the remaining NaNs with the last data of the
        # previous chunk
        return np.where(~np.isnan(b), b, a)

    # The method parameter makes the tests for Python 3.7 fail.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )

    if n is not None and 0 < n < array.shape[axis] - 1:
        import numba as nb

        @nb.vectorize([nb.int64(nb.int64, nb.int64)])
        def reset_cumsum(x, y):
            return x + y if y else y

        def combine_reset_cumsum(x, y):
            # add the previous block's accumulated count (x) only to the
            # positions before the first zero (reset) in y
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)

        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum.accumulate,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)

    return pushed_array

dcherian commented 1 week ago

(I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulate value)

See https://github.com/pydata/xarray/issues/9229. You can reset the accumulated value at the group edges.

this operation would accumulate a count based on the number of contiguous nan, and if it finds a nonnan then it restart the count to 0,

Yes, pretty sure you can accumulate both the values and the count in a single scan, and then add a blockwise task that applies the mask.
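A rough pure-NumPy sketch of that idea (hypothetical names, sequential and 1-D only for clarity): carry both the last non-NaN value and the length of the trailing NaN run from one chunk to the next, then apply the limit mask in a final blockwise step.

import numpy as np


def scan_chunk(chunk, carry_value, carry_count):
    # fill one chunk, starting from the value/count carried in from the left
    filled = np.empty_like(chunk)
    count = np.empty(chunk.shape, dtype=np.int64)
    val, cnt = carry_value, carry_count
    for i, x in enumerate(chunk):
        if np.isnan(x):
            cnt += 1
        else:
            val, cnt = x, 0
        filled[i], count[i] = val, cnt
    return filled, count, val, cnt


def push_limited(chunks, n):
    out, val, cnt = [], np.nan, 0
    for chunk in chunks:
        filled, count, val, cnt = scan_chunk(chunk, val, cnt)
        out.append(np.where(count <= n, filled, np.nan))  # blockwise mask
    return np.concatenate(out)


print(push_limited([np.array([1.0, np.nan, np.nan]), np.array([np.nan, np.nan, 7.0])], n=2))
# [ 1.  1.  1. nan nan  7.]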

josephnowak commented 1 week ago

I have been waiting for that segmented-scan functionality for some time; it will be quite useful, and it could replace the function I wrote with Numba to reset the cumulative sum. If someone needs it, I can implement the same function using only NumPy and send a PR; as it is, it should cut the number of tasks in half.

josephnowak commented 1 week ago

Hi @phofl, I made the following improvements to the code to reduce the number of tasks produced by the limit parameter and avoid using Numba. I hope this helps to solve the performance issue that you are facing. I can send a PR to Dask if you need it (I saw that a push method was added there).

def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push

    def _fill_with_last_one(a, b):
        # cumreduction applies the push func to every block first, so the
        # only missing part is filling the remaining NaNs with the last
        # data of the previous chunk
        return np.where(np.isnan(b), a, b)

    # The method parameter makes the tests for Python 3.7 fail.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )

    if n is not None and 0 < n < array.shape[axis] - 1:

        def reset_cumsum(x, axis):
            # cumulative sum that restarts from 0 after every 0 in x, i.e. the
            # length of the current run of ones (here: of contiguous NaNs)
            cumsum = np.cumsum(x, axis=axis)
            reset_points = np.maximum.accumulate(
                np.where(x == 0, cumsum, 0), axis=axis
            )
            return cumsum - reset_points

        def combine_reset_cumsum(x, y):
            # add the count carried from the previous block (x) only to the
            # positions before the first reset (zero) in y
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)

        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)

    return pushed_array
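For example, running the push above on a small chunked array (this assumes bottleneck or numbagg is installed so that _push works per block):

import dask.array as da
import numpy as np

data = np.array([np.nan, 1.0, np.nan, np.nan, np.nan, 5.0, np.nan])
arr = da.from_array(data, chunks=2)

print(push(arr, n=2, axis=0).compute())
# [nan  1.  1.  1. nan  5.  5.]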

dcherian commented 1 week ago

Hah very nice work @josephnowak :clap: :clap: :clap:

dcherian commented 1 week ago

It'd be nice to get this to work with method="blelloch", which is the standard work-efficient scan algorithm.

josephnowak commented 1 week ago

The last time I tried to use that method, the Xarray tests failed for Python 3.7, but I think Xarray has already dropped support for that version, so we can probably use it without any issue; I'm going to send a PR soon. Also, take into consideration what is written in the Dask documentation about it; I'm not sure why no one has done a proper benchmark. [screenshot of the Dask documentation]

phofl commented 1 week ago

Nice!

I think it makes sense to add this to both repositories for now, since xarray can only dispatch to Dask if the Dask version is 2024.11 or newer.