pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Support segmented scans #9229

Open dcherian opened 3 months ago

dcherian commented 3 months ago

Is your feature request related to a problem?

It is pretty common to want to run cumsum and have the running sum reset wherever a boolean flag array is True (1). This operation, a segmented scan, is common enough to have its own Wikipedia page, and it is discussed in Blelloch (1993), Section 1.5.
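For reference, here is a minimal pure-Python sketch of an inclusive segmented cumsum. It assumes the convention from Blelloch's formulation that a True flag marks the first element of a new segment. The second function uses Blelloch's trick of rewriting the segmented scan as an ordinary scan over (flag, value) pairs with an associative operator. All function names here are illustrative and are not an existing xarray API.

```python
from itertools import accumulate


def segmented_cumsum_loop(values, flags):
    """Inclusive cumsum that restarts at every position where flags is True."""
    out = []
    running = 0
    for value, flag in zip(values, flags):
        # A True flag starts a new segment, so the previous total is discarded.
        running = value if flag else running + value
        out.append(running)
    return out


def segmented_op(left, right):
    """Associative combine on (flag, value) pairs, after Blelloch (1993), Sec. 1.5."""
    left_flag, left_value = left
    right_flag, right_value = right
    if right_flag:
        # The right operand starts a segment: ignore everything to its left.
        return (True, right_value)
    return (left_flag, left_value + right_value)


def segmented_cumsum_scan(values, flags):
    """Same result as the loop, expressed as an ordinary (unsegmented) scan."""
    pairs = zip((bool(f) for f in flags), values)
    return [value for _, value in accumulate(pairs, segmented_op)]


values = [1, 2, 3, 4, 5]
flags = [False, False, True, False, True]
print(segmented_cumsum_loop(values, flags))  # [1, 3, 3, 7, 5]
print(segmented_cumsum_scan(values, flags))  # [1, 3, 3, 7, 5]
```

Because segmented_op is associative, the scan can be split across chunks and recombined. That property is what makes a single-pass chunked implementation plausible.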

Here's a real example of someone trying to implement it in a fairly roundabout way.

time_cumsum = cube.cumsum(dim="time")
# Subtract the running total recorded at the most recent zero so the sum restarts there
cumsum = time_cumsum - time_cumsum.where(cube == 0).ffill(dim="time").fillna(0)
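The snippet above works in four steps. It takes the plain cumsum, records its value at every position where cube == 0, carries that value forward with ffill, and subtracts it, so the sum restarts at each zero. fillna(0) covers the stretch before the first zero. Below is a 1-D NumPy sketch of the same steps, assuming NumPy only. A running maximum over indices stands in for ffill, and the function name is hypothetical.

```python
import numpy as np


def reset_cumsum_at_zeros(x):
    """1-D cumsum that restarts from 0 wherever x == 0."""
    x = np.asarray(x)
    total = np.cumsum(x)  # plays the role of time_cumsum
    is_reset = x == 0     # plays the role of cube == 0
    # Index of the most recent zero at or before each position (0 if none yet).
    last_reset = np.maximum.accumulate(np.where(is_reset, np.arange(x.size), 0))
    # Forward-fill the total recorded at each zero; use 0 before the first zero (fillna(0)).
    seen_reset = np.logical_or.accumulate(is_reset)
    carried = np.where(seen_reset, total[last_reset], 0)
    return total - carried


print(reset_cumsum_at_zeros([1, 2, 0, 3, 4, 0, 5]))  # [1 3 0 3 7 0 5]
```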

We have a few options to implement it:

  1. We could introduce a new method, DataArray.segmented_scan(flags, op="sum"), or perhaps a new class exposed as DataArray.segment.cumsum(). A dask/cubed-friendly version that does all of this in a single scan should be fairly straightforward to write (similar to our ffill and bfill wrappers).

  2. In a way this generalizes resample, and it just struck me that the example above could be written as follows, which should work well once flox adds scans:

    group_idx = (cube == 0).cumsum('time')
    cube.groupby(group_idx).cumsum()
    i. We could use our new Grouper functionality to expose a "flag" grouper that hides the group_idx = (cube == 0).cumsum('time') line.
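On option 1, a single-scan, chunk-friendly version seems feasible because a segmented scan can be expressed with an associative operator (Blelloch 1993, Section 1.5). In practice, each chunk can be scanned on its own and then patched with a carry from the preceding chunks. Only the elements before a chunk's first flag receive the carry. Below is a minimal NumPy sketch of that pattern, with the chunking done by hand rather than by dask or cubed. Both helper names are hypothetical.

```python
import numpy as np


def local_segmented_cumsum(values, flags):
    """Segmented inclusive cumsum within one chunk (True flag = segment start)."""
    total = np.cumsum(values)
    # Running sum just before each element (0 before the first one).
    before = np.concatenate([np.zeros(1, dtype=total.dtype), total[:-1]])
    # Index of the most recent segment start at or before each position.
    positions = np.arange(values.size)
    last_start = np.maximum.accumulate(np.where(flags, positions, 0))
    return total - before[last_start]


def blockwise_segmented_cumsum(values, flags, chunk_size):
    """Scan chunks independently, then propagate a carry between them."""
    values = np.asarray(values)
    flags = np.asarray(flags, dtype=bool)
    pieces = []
    carry = 0
    for start in range(0, values.size, chunk_size):
        chunk_values = values[start:start + chunk_size]
        chunk_flags = flags[start:start + chunk_size]
        local = local_segmented_cumsum(chunk_values, chunk_flags)
        # Elements before the chunk's first flag continue the previous segment.
        still_open = ~np.logical_or.accumulate(chunk_flags)
        local = local + np.where(still_open, carry, 0)
        # The last element holds the running sum of the segment open at the boundary.
        carry = local[-1]
        pieces.append(local)
    return np.concatenate(pieces) if pieces else values.copy()


print(blockwise_segmented_cumsum([1, 2, 3, 4, 5], [False, False, True, False, True], chunk_size=2))
# [1 3 3 7 5]
```

The per-chunk work is independent, and only the scalar carry passes sequentially between chunks. This sketches the idea only; it is not how xarray's ffill/bfill wrappers are actually written.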

My concern with (2) and (2.i) is that they are not at all obvious to most of our user base.
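Here is why (2) works. (cube == 0).cumsum('time') increments its label exactly at each reset, so every segment gets its own integer group label, and a grouped cumsum over those labels is the segmented scan. The 1-D NumPy check below demonstrates this equivalence. cube is a stand-in array, and the hypothetical grouped_cumsum loop stands in for groupby(...).cumsum().

```python
import numpy as np


def grouped_cumsum(values, labels):
    """Cumsum computed separately within each group of equal labels."""
    values = np.asarray(values)
    labels = np.asarray(labels)
    out = np.empty_like(values)
    for label in np.unique(labels):
        mask = labels == label
        out[mask] = np.cumsum(values[mask])
    return out


cube = np.array([1, 2, 0, 3, 4, 0, 5])
group_idx = np.cumsum(cube == 0)        # [0 0 1 1 1 2 2]
print(grouped_cumsum(cube, group_idx))  # [1 3 0 3 7 0 5]
```

For the same input, this matches the result of the where/ffill snippet earlier in the issue.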

TomNicholas commented 3 months ago

@dcherian I feel like you're practically the only person who would have realized that this is expressible as (2) 😅

I like the idea of adding some kind of cumsum syntactic sugar, especially if the underlying implementation can be written in terms of groupby, so that it doesn't add much maintenance burden.

max-sixty commented 3 months ago

Brief reminder that we have .cumulative, so we could use that to accommodate any extra complications if needed!