phofl closed this 1 week ago.
This can't work in general.
We switched from map_overlap to cumreduction intentionally: https://github.com/pydata/xarray/pull/6118. I fully support improving cumreduction. It's a fundamental parallel primitive.
Oh, I didn't consider None as a limit, sorry about that.
You could experiment with method="blelloch" vs method="sequential" in cumreduction.
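A minimal sketch of that experiment, using cumsum rather than push so the result is easy to check. It assumes a Dask version whose dask.array.reductions.cumreduction accepts method= and preop=. The "blelloch" method needs preop, the per-block reduction that matches func (np.sum for np.cumsum). The data and chunk sizes are arbitrary.

```python
import operator

import dask.array as da
import numpy as np

data = np.random.default_rng(0).random((6, 40))
x = da.from_array(data, chunks=(3, 8))

# Default: each block waits for the carry from every block before it
sequential = da.reductions.cumreduction(
    func=np.cumsum,
    binop=operator.add,
    ident=0,
    x=x,
    axis=1,
    dtype=x.dtype,
    method="sequential",
)

# Work-efficient parallel scan; it needs the per-block reduction as `preop`
blelloch = da.reductions.cumreduction(
    func=np.cumsum,
    binop=operator.add,
    ident=0,
    x=x,
    axis=1,
    dtype=x.dtype,
    method="blelloch",
    preop=np.sum,
)
```

Comparing len(result.__dask_graph__()) for the two results is one quick way to see how their graph sizes differ.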
I thought more about this; I hadn't considered large limit values properly. cumreduction itself works OK-ish; the issue that makes the task graph that large is this section:
if n is not None and 0 < n < array.shape[axis] - 1:
    arange = da.broadcast_to(
        da.arange(
            array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
        ).reshape(
            tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
        ),
        array.shape,
        array.chunks,
    )
    valid_arange = da.where(da.notnull(array), arange, np.nan)
    valid_limits = (arange - push(valid_arange, None, axis)) <= n
    # omit the forward-filled values that violate the limit
    return da.where(valid_limits, push(array, None, axis), np.nan)
It's a bunch of operations that can't be fused properly because of the interactions between the two different arrays. I'll think a bit more about whether we can reduce this somehow, but nothing obvious comes to mind right away (at least not to me).
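To make the limit logic in that snippet concrete, here is the same arithmetic on a single 1-D NumPy array. This is a sketch; ffill_1d and ffill_limit_1d are hypothetical helpers, not xarray functions. Each position keeps its forward-filled value only if it sits at most n steps after the last valid value. That is why the Dask version needs a broadcast arange plus two separate push calls.

```python
import numpy as np


def ffill_1d(a):
    """Unlimited forward fill of a 1-D float array."""
    idx = np.where(np.isnan(a), 0, np.arange(a.size))
    return a[np.maximum.accumulate(idx)]


def ffill_limit_1d(a, n):
    """Mirror of the arange-based limit logic, on a single NumPy array."""
    arange = np.arange(a.size, dtype=float)
    valid_arange = np.where(~np.isnan(a), arange, np.nan)
    # Distance back to the last valid value (NaN if there is none yet)
    valid_limits = (arange - ffill_1d(valid_arange)) <= n
    return np.where(valid_limits, ffill_1d(a), np.nan)


a = np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan])
result = ffill_limit_1d(a, 2)  # values: 1, 1, 1, nan, 5, 5
```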
cumreduction itself is equivalent to map_overlap from a topology perspective if the overlapping part only reaches a single neighbouring chunk. Would you be open to calling overlap in these cases? It makes the code a bit uglier, but that seems to be a reasonably common use case from what I have seen so far (I might be totally wrong here).
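A sketch of the single-neighbour case, assuming the limit n is no larger than the smallest chunk along axis. Under that assumption each block only needs the last n elements of its left neighbour, so one map_overlap pass is enough. _np_push is a hypothetical NumPy stand-in for bottleneck.push so the example has no optional dependencies; push_map_overlap is likewise a made-up name.

```python
import dask.array as da
import numpy as np


def _np_push(a, n=None, axis=-1):
    """NumPy forward fill along `axis`, optionally limited to n consecutive NaNs."""
    a = np.moveaxis(np.asarray(a, dtype=float), axis, -1)
    idx = np.arange(a.shape[-1])
    # Index of the last valid value at or before each position, -1 if none
    last = np.maximum.accumulate(np.where(np.isnan(a), -1, idx), axis=-1)
    out = np.take_along_axis(a, np.maximum(last, 0), axis=-1)
    keep = last >= 0
    if n is not None:
        keep &= (idx - last) <= n
    out = np.where(keep, out, np.nan)
    return np.moveaxis(out, -1, axis)


def push_map_overlap(array, n, axis):
    """Forward fill with limit n using one neighbouring chunk of overlap."""
    axis = axis % array.ndim
    if n is None or n > min(array.chunks[axis]):
        raise ValueError("sketch only handles a limit no larger than the smallest chunk")
    # boundary="none": no padding at the array edges, so the first block sees nothing
    return array.map_overlap(
        _np_push,
        depth={axis: n},
        boundary="none",
        dtype=float,
        n=n,
        axis=axis,
    )
```

The symmetric depth also pulls in n elements from the right neighbour, which a forward fill never uses. An asymmetric depth such as {axis: (n, 0)} should avoid that where the installed Dask supports it with boundary="none".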
Can you add push to dask.array? Then you can add whatever optimization you want :). We'd be happy to dispatch instead of vendoring all this if we can.
Also, you should be able to write this as a single cumreduction that takes a 1D array of axis-indices and the input array as inputs. I wrote it for grouped ffill in flox: https://github.com/xarray-contrib/flox/blob/672be8ceeebfa588a15ebdc9861999efa12fa44e/flox/aggregations.py#L651
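One hypothetical way to do the whole limited fill with a single cumreduction, not necessarily how the linked flox code does it. The values and the positions of the valid values are stacked on a new leading axis. Both are forward-filled in one scan, and the limit is then applied with a blockwise where. The helper names are made up for this sketch, and it uses a NumPy block fill instead of bottleneck.

```python
import dask.array as da
import numpy as np


def _ffill_block(a, axis):
    """Unlimited forward fill of one block along `axis` (NumPy only)."""
    a = np.moveaxis(a, axis, -1)
    idx = np.arange(a.shape[-1])
    last = np.maximum.accumulate(np.where(np.isnan(a), -1, idx), axis=-1)
    out = np.take_along_axis(a, np.maximum(last, 0), axis=-1)
    return np.moveaxis(np.where(last >= 0, out, np.nan), -1, axis)


def _fill_with_last_one(a, b):
    # Carry the last value of the previous chunk into this chunk's leading NaNs
    return np.where(np.isnan(b), a, b)


def push_single_scan(array, n, axis):
    axis = axis % array.ndim
    # 1D positions along `axis`, expanded with None-indexing so chunks line up
    index = tuple(slice(None) if i == axis else None for i in range(array.ndim))
    arange = da.arange(array.shape[axis], chunks=array.chunks[axis], dtype=float)[index]
    arange = da.broadcast_to(arange, array.shape, chunks=array.chunks)
    # Position of each valid value, NaN where the input is NaN
    valid_index = da.where(da.isnan(array), np.nan, arange)
    # Values and positions share one block per task on the new leading axis
    stacked = da.stack([array.astype(float), valid_index]).rechunk({0: 2})
    filled = da.reductions.cumreduction(
        func=_ffill_block,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=stacked,
        axis=axis + 1,
        dtype=float,
    )
    values, last_index = filled[0], filled[1]
    if n is None:
        return values
    # Blockwise mask: drop fills that are more than n steps from their source
    return da.where(arange - last_index <= n, values, np.nan)
```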
Oh, I'll check that one out.
Sure, there shouldn't be any reason not to add this. I'll check out the flox implementation; getting the number of tasks down would be nice, but adding it to Dask should be a good option anyway.
Ah, I guess the core issue is a dependency on numbagg and/or bottleneck.
That should be fine; we can just raise if neither is installed, similar to what you are doing here.
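A minimal sketch of the "raise if neither is installed" check, using only importlib.util.find_spec. The helper name and the preference for numbagg over bottleneck are assumptions for illustration.

```python
import importlib.util


def require_push_backend():
    """Return the name of an installed push backend, or raise ImportError.

    Hypothetical helper; checking numbagg before bottleneck is an assumption.
    """
    for name in ("numbagg", "bottleneck"):
        # find_spec returns None when the module cannot be imported
        if importlib.util.find_spec(name) is not None:
            return name
    raise ImportError("push requires numbagg or bottleneck to be installed")
```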
Hi,
When I wrote that code, I tried to preserve map_overlap as much as possible because it is, in theory, faster. The main issue is that it is not efficient in the general case: two or more consecutive chunks can be entirely NaN, and then a single overlap does not carry the push into the chunks that follow. To make it work in general, you would have to run it N - 1 times, where N is the number of chunks.
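The failure mode described above can be reproduced without Dask. The sketch below uses hypothetical helpers in plain NumPy. It fills each chunk using only itself and its left neighbour, which is all a single overlap of one chunk can see. With two or more all-NaN chunks in a row, the fill stops short.

```python
import numpy as np


def ffill_1d(a):
    """Unlimited forward fill of a 1-D float array."""
    idx = np.where(np.isnan(a), 0, np.arange(a.size))
    return a[np.maximum.accumulate(idx)]


def single_overlap_ffill(a, chunk):
    """Fill each chunk using only itself plus the whole previous chunk."""
    pieces = []
    for start in range(0, a.size, chunk):
        lo = max(start - chunk, 0)
        window = ffill_1d(a[lo:start + chunk])
        # Keep only the part of the window that belongs to the current chunk
        pieces.append(window[start - lo:])
    return np.concatenate(pieces)


# Chunks of 2: one chunk with data followed by three all-NaN chunks
a = np.array([1.0, 2.0] + [np.nan] * 6)
full = ffill_1d(a)                   # values: 1, 2, 2, 2, 2, 2, 2, 2
approx = single_overlap_ffill(a, 2)  # values: 1, 2, 2, 2, nan, nan, nan, nan
```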
If the main issue is that the code I wrote to apply a limit to the forward fill generates many tasks, I think you could replace it with a kind of cumulative reset logic (I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulated value). This operation would accumulate a count of contiguous NaNs and reset the count to 0 whenever it finds a non-NaN value. Please see the attached code; I think it illustrates the idea, and it is also functional code. The only issue is that it requires numba, and I'm not sure you want to force users to install numba.
def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np
    from xarray.core.duck_array_ops import _push
    def _fill_with_last_one(a, b):
        # cumreduction applies the push func to all the blocks first, so the only
        # missing part is filling the remaining NaNs with the last value of the previous chunk
        return np.where(~np.isnan(b), b, a)
    # Passing the method parameter makes the tests fail on Python 3.7.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )
    if n is not None and 0 < n < array.shape[axis] - 1:
        import numba as nb
        @nb.vectorize([nb.int64(nb.int64, nb.int64)])
        def reset_cumsum(x, y):
            # Running count of consecutive NaNs that restarts at 0 on every valid value
            return x + y if y else y
        def combine_reset_cumsum(x, y):
            # Add the previous result (x) only to the positions before the first zero
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)
        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum.accumulate,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)
    return pushed_array
(I would really love to have a bitmask parameter on the cumulative operations that allows restarting the accumulate value)
See https://github.com/pydata/xarray/issues/9229. You can reset the accumulated value at the group edges.
this operation would accumulate a count based on the number of contiguous nan, and if it finds a nonnan then it restart the count to 0,
Yes, pretty sure you can accumulate both the values and the count in a single scan, and then add a blockwise task that applies the mask.
I have been waiting for that segmented scan functionality for some time. It will be quite useful, and it could replace the numba function I wrote to reset the cumulative sum. In the meantime, if someone needs it, I can reimplement that function using only NumPy and send a PR; as it is, it should cut the number of tasks in half.
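Segmented scans are usually built by scanning (flag, value) pairs with an associative operator that restarts the sum whenever the right-hand element starts a new segment. Below is a small pure-Python sketch of that idea applied to the NaN-count reset. The operator name is hypothetical, and this is the textbook construction, not the API proposed in the linked issue.

```python
from itertools import accumulate


def segmented_add(left, right):
    """Associative operator on (starts_segment, value) pairs.

    If the right element starts a new segment, the running sum restarts at its
    value; otherwise the values add up. The flag records any segment start seen.
    """
    left_flag, left_value = left
    right_flag, right_value = right
    return (left_flag or right_flag, right_value if right_flag else left_value + right_value)


# 1 where the input is NaN, 0 where it is valid
isnan = [1, 0, 1, 1, 1, 1, 0, 1]
# Every valid value starts a new segment, so the NaN count resets there
pairs = [(nan == 0, nan) for nan in isnan]
counts = [value for _, value in accumulate(pairs, segmented_add)]
# counts: 1, 0, 1, 2, 3, 4, 0, 1
```

Because segmented_add is associative, the pairs can be reduced per chunk and the chunk results combined in any grouping, which is what a parallel scan needs.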
Hi @phofl, I made the following improvements to the code to reduce the number of tasks produced by the limit parameter and avoid using Numba. I hope this helps to solve the performance issue that you are facing. I can send a PR to Dask if you need it (I saw that a push method was added there).
def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import dask.array as da
    import numpy as np
    from xarray.core.duck_array_ops import _push
    def _fill_with_last_one(a, b):
        # cumreduction applies the push func to all the blocks first, so
        # the only missing part is filling the remaining NaNs with the
        # last value of the previous chunk
        return np.where(np.isnan(b), a, b)
    # Passing the method parameter makes the tests fail on Python 3.7.
    pushed_array = da.reductions.cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )
    if n is not None and 0 < n < array.shape[axis] - 1:
        def reset_cumsum(x, axis):
            # Count consecutive NaNs (x == 1), restarting at 0 on every valid value (x == 0)
            cumsum = np.cumsum(x, axis=axis)
            reset_points = np.maximum.accumulate(
                np.where(x == 0, cumsum, 0), axis=axis
            )
            return cumsum - reset_points
        def combine_reset_cumsum(x, y):
            # Add the carried count (x) only to the positions before the first reset in y
            bitmask = np.cumprod(y != 0, axis=axis)
            return np.where(bitmask, y + x, y)
        valid_positions = da.reductions.cumreduction(
            func=reset_cumsum,
            binop=combine_reset_cumsum,
            ident=0,
            x=da.isnan(array).astype(int),
            axis=axis,
            dtype=int,
        ) <= n
        pushed_array = da.where(valid_positions, pushed_array, np.nan)
    return pushed_array
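A NumPy-only check of why the two helpers compose across chunks. The loop below mimics the sequential cumreduction: it scans each chunk with reset_cumsum, then folds in the carry from the previous chunk with combine_reset_cumsum. The helpers are copied from the code above, with axis passed explicitly instead of captured from push. The chunk split is arbitrary.

```python
import numpy as np


def reset_cumsum(x, axis):
    # Count consecutive ones, restarting at 0 wherever x == 0
    cumsum = np.cumsum(x, axis=axis)
    reset_points = np.maximum.accumulate(np.where(x == 0, cumsum, 0), axis=axis)
    return cumsum - reset_points


def combine_reset_cumsum(carry, y, axis):
    # Add the carry only to the positions before the first reset in y
    bitmask = np.cumprod(y != 0, axis=axis)
    return np.where(bitmask, y + carry, y)


isnan = np.array([1, 0, 1, 1, 1, 1, 0, 1])
expected = reset_cumsum(isnan, axis=0)  # values: 1, 0, 1, 2, 3, 4, 0, 1

carry = np.zeros(1, dtype=isnan.dtype)  # plays the role of ident=0
pieces = []
for block in np.split(isnan, [3, 5]):  # chunks of sizes 3, 2, 3
    scanned = reset_cumsum(block, axis=0)
    pieces.append(combine_reset_cumsum(carry, scanned, axis=0))
    # The next carry combines the old carry with this chunk's last scanned entry
    carry = combine_reset_cumsum(carry, scanned[-1:], axis=0)
chunked = np.concatenate(pieces)
```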
Hah, very nice work @josephnowak :clap: :clap: :clap:
It'd be nice to get this to work with method="blelloch", which is the standard work-efficient scan algorithm.
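For reference, here is a textbook Blelloch exclusive scan (an up-sweep, then a down-sweep) over a Python list. It preserves operand order, so it also works for non-commutative operators such as "take the last non-missing value". This illustrates the algorithm only and assumes a power-of-two length. Dask's method="blelloch" works on whole blocks and needs a preop, so it is not implemented like this internally.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def blelloch_exclusive_scan(xs: List[T], op: Callable[[T, T], T], ident: T) -> List[T]:
    """Exclusive scan via Blelloch's up-sweep / down-sweep (sketch)."""
    n = len(xs)
    if n == 0:
        return []
    if n & (n - 1):
        raise ValueError("this sketch assumes a power-of-two length")
    tree = list(xs)
    # Up-sweep: build partial reductions in place, left operand before right
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            tree[i] = op(tree[i - d], tree[i])
        d *= 2
    # Down-sweep: push prefixes back down the tree
    tree[n - 1] = ident
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left = tree[i - d]
            tree[i - d] = tree[i]
            tree[i] = op(tree[i], left)  # prefix first, then the left subtree
        d //= 2
    return tree


def last_valid(a, b):
    # Forward-fill operator: keep b unless it is missing (None)
    return a if b is None else b


xs = [None, 1, None, None, 5, None, None, None]
prefix = blelloch_exclusive_scan(xs, last_valid, None)
filled = [last_valid(p, x) for p, x in zip(prefix, xs)]  # None, 1, 1, 1, 5, 5, 5, 5
```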
The last time I tried that method, the xarray tests failed on Python 3.7, but I think xarray has already dropped support for that version, so we can probably use it without any issue. I'm going to send a PR soon. Also, keep in mind that this is what the Dask documentation states; I'm not sure why no one has done a proper benchmark.
Nice!
I think it makes sense to add this to both repositories for now since xarray can only dispatch to Dask if the dask version is 2024.11 or newer
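A sketch of that version gate, assuming packaging is available (Dask depends on it). The 2024.11 threshold is the one stated above; the helper name is hypothetical.

```python
from packaging.version import Version

# Threshold taken from the discussion; the constant name is illustrative
MIN_DASK_FOR_PUSH = Version("2024.11")


def dask_can_dispatch_push(dask_version: str) -> bool:
    """Return True if this Dask version is new enough for xarray to dispatch push."""
    return Version(dask_version) >= MIN_DASK_FOR_PUSH
```

In practice this would be called as dask_can_dispatch_push(dask.__version__), falling back to the vendored implementation when it returns False.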
Our benchmarks showed that ffill alone adds 4.5 million tasks to the graph, which isn't great (the dataset has 550k chunks, so that is roughly 8 to 9 tasks per chunk).
Rewriting this with map_overlap brings it down to 1.5 million tasks, roughly three tasks per chunk, which is the minimum we can reach at the moment.
We merged a few map_overlap improvements on the Dask side today to make this possible, and it's now a nice improvement (it also makes the code on the xarray side simpler).
cc @dcherian