pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Optimize `ffill`, `bfill` with dask when `limit` is specified #9771

Closed · josephnowak closed 1 week ago

josephnowak commented 1 week ago

This improvement comes from the discussion in this PR: https://github.com/pydata/xarray/pull/9712

josephnowak commented 1 week ago

The failing tests look unrelated to the changes in this PR 🤔

josephnowak commented 1 week ago

@dcherian do you think it is a good idea to add a `keepdims` parameter to the `first` and `last` functions implemented in `duck_array_ops.py`? If I add that, I could use them directly in the cumreduction with the Blelloch method; if not, I can just create a dummy wrapper, like the one I did for the `push` method, to receive a type parameter that is not used.
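
For illustration, a minimal sketch of what the `keepdims` addition could look like (this is not xarray's actual `duck_array_ops.first`, which also handles NaN-skipping; it only shows the shape-preserving behaviour a scan-style reduction needs):

    import numpy as np

    def first(values, axis, keepdims=False):
        """Take the first value along ``axis``; ``keepdims`` is the proposed addition."""
        result = np.take(values, 0, axis=axis)
        if keepdims:
            # keep the reduced axis as length 1 so the result can be combined
            # blockwise with arrays of the original dimensionality
            result = np.expand_dims(result, axis=axis)
        return result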

dcherian commented 1 week ago

yes that sounds good to me.

dcherian commented 1 week ago

Great! Can you add a note to whats-new please? This is worth advertising.

josephnowak commented 1 week ago

Hi @dcherian

I think we could add this special case to try to improve performance when the limit is equal to or smaller than the chunk size along the axis (only one chunk would overlap). Is there any kind of benchmark I can use to see whether there is any real benefit from using map_overlap over cumsum when n is small?

    if n is not None and n <= array.chunksize[axis]:
        return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)

dcherian commented 1 week ago

I don't have one at hand. @phofl might have suggestions.

Note that chunks can be non-uniform, so you'll have to check every element of that array.chunks[axis] tuple.
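
For example, a hedged sketch of that check (the helper name is made up): compare `n` against every chunk length along the axis, not just a single per-axis chunk size:

    import dask.array as da

    def _limit_fits_in_every_chunk(array: da.Array, n, axis: int) -> bool:
        # array.chunks[axis] lists the length of each individual chunk along
        # the axis, so non-uniform chunking is handled correctly, unlike a
        # comparison against the single value array.chunksize[axis]
        return n is not None and all(n <= size for size in array.chunks[axis])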

phofl commented 1 week ago

Yes!

The `highlevel_api` function in https://github.com/coiled/benchmarks/blob/main/tests/geospatial/workloads/climatology.py

josephnowak commented 1 week ago

Thanks, I will try to run those benchmarks before modifying this PR, but my understanding is that using map_overlap with a small N should generate fewer tasks.
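
A quick, self-contained sanity check might look like this (a rough sketch, not the coiled workload; the sizes, chunking, and NaN mask are arbitrary, and `ffill` on dask-backed data may require bottleneck or numbagg to be installed):

    import time

    import dask.array as da
    import numpy as np
    import xarray as xr

    # build a dask-backed DataArray with roughly half the values missing
    data = da.random.random((5_000, 10_000), chunks=(1_000, 500))
    data = da.where(data > 0.5, data, np.nan)
    arr = xr.DataArray(data, dims=("x", "time"))

    for limit in (None, 5, 500):
        start = time.perf_counter()
        arr.ffill(dim="time", limit=limit).compute()
        print(f"limit={limit}: {time.perf_counter() - start:.2f}s")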

dcherian commented 1 week ago

Yes, I believe that's why @phofl proposed that optimization. We just need to make sure to trigger it in the right circumstances. Please add a test for non-uniform chunk sizes where map_overlap would give you the wrong answer.

josephnowak commented 1 week ago

After understanding more of the code that @phofl wrote, I think the implementation works in the general case; the main issue is that it can load the whole array into memory when N is equal to the size of the array (based on my understanding of how map_overlap works).

When I saw the PR I thought it was the same implementation as before, but what he did was define the overlap depth based on the limit, whereas the previous algorithm (the one before the cumreduction implementation) applied a map_blocks using the push function first and then a map_overlap with a single element of overlap instead of N, which caused the issue when one or more chunks were completely empty.

In summary, I think the best option is to use his algorithm when N is smaller than the chunk size along the axis, so we avoid loading more than 2 chunks at the same time.

dcherian commented 1 week ago

Doesn't `(n <= array.chunks[axis]).all()` handle that case? It wouldn't trigger if n is greater than the chunksize, so you'd never load the whole thing into memory at once.

In any case, how do you feel about merging this and optimizing later? I believe this is already a massive improvement.

josephnowak commented 1 week ago

Yes, that condition would prevent loading all the chunks. I just wanted to mention that using that algorithm in the general case would be problematic because it would load all the chunks at once, although it would still give the expected result in the end.

I'm okay with merging it without the special-case optimization and adding it in a follow-up PR, but I'm also okay with adding the special case in this PR; I'll leave the final decision to you. I already have it implemented locally.
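
For reference, a rough sketch of the kind of special case I have in mind (`push_cumreduction` is a placeholder name for the scan-based fallback, not the actual function in this PR), reusing the per-chunk check from above:

    import dask.array as da

    def push_with_limit(array: da.Array, n, axis: int):
        # special case: the fill can reach at most n elements back, and n fits
        # inside every chunk along the axis, so each block only needs n
        # trailing elements from its predecessor (at most two chunks per task)
        if n is not None and all(n <= size for size in array.chunks[axis]):
            return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)
        # general case: fall back to the scan-based (cumreduction) approach
        return push_cumreduction(array, n=n, axis=axis)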

dcherian commented 1 week ago

Ah, then why don't you push what you have?

phofl commented 1 week ago

> Doesn't `(n <= array.chunks[axis]).all()` handle that case? It wouldn't trigger if n is greater than the chunksize, so you'd never load the whole thing into memory at once.

Yes, this is what I wanted to do as well (I got side-tracked over the last few days, unfortunately). That is memory-efficient and helps in some cases.

josephnowak commented 1 week ago

I didn't push it because the PR had the plan-to-merge label, so I didn't want to modify it without your approval.

And sorry, @phofl, for not getting the idea of your original algorithm the first time I saw your PR; the good part is that the algorithm now behaves better in all scenarios.

I'm not sure why the tests are not running; do the conflicts need to be resolved first for the tests to run?

dcherian commented 1 week ago

Nice work @josephnowak and @phofl