xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

Significantly faster cohorts detection. #272

Closed. dcherian closed this 1 year ago

dcherian commented 1 year ago

Closes #271

I was iterating over array.blocks to figure out the shape of each chunk.

Indexing that object creates a dask array per chunk, which is slow for many reasons.

Replace it with a function that calculates the chunk shape directly from the chunks metadata (a sketch of the idea follows the snippet below). On the ARCO ERA5 dataset with 92044 time chunks, this takes graph construction from effectively never finishing down to about 840 ms:

import dask.array
import pandas as pd
import xarray as xr
from flox.xarray import xarray_reduce

TIME = 92044
da = xr.DataArray(
    dask.array.ones((TIME, 721, 1440), chunks=(1, -1, -1)),
    dims=("time", "lat", "lon"),
    coords=dict(time=pd.date_range("1959-01-01", freq="6H", periods=TIME)),
)
%time xarray_reduce(da, da.time.dt.day, method="cohorts", func="any")
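
For context, here is a minimal sketch of the difference, assuming a dask-style array. The helper name get_chunk_shape mirrors the one visible in the profile below, but its body here is only an illustrative guess, not flox's actual implementation.

import dask.array
import numpy as np

array = dask.array.ones((10, 6), chunks=(3, 4))

def get_chunk_shape(array_chunks, block_index):
    # Read the block's shape from the chunks tuple instead of building
    # a dask array for the block and asking for its .shape.
    return tuple(array_chunks[axis][i] for axis, i in enumerate(block_index))

# Slow path: indexing .blocks constructs a new dask array per chunk.
slow = [array.blocks[idx].shape for idx in np.ndindex(array.numblocks)]

# Fast path: read the shapes straight out of the chunk metadata.
fast = [get_chunk_shape(array.chunks, idx) for idx in np.ndindex(array.numblocks)]

assert slow == fast
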
dcherian commented 1 year ago

The profile looks like this:

Line #      Hits         Time  Per Hit   % Time  Line Contents
   225     92045   36723000.0    399.0      5.7      for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
   226     92044  169811000.0   1844.9     26.5          chunkshape = get_chunk_shape(array_chunks, blockindex)
   227     92044  142355000.0   1546.6     22.2          blocks[idx] = np.full(chunkshape, idx)
   228         1  189328000.0    2e+08     29.5      which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)

I strongly suspect we can do better. The tolist call copies blocks, which should be unnecessary (the toy version below shows where that copy happens).
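
To make the profiled lines concrete, here is a self-contained toy version of that loop with made-up chunk sizes; the names follow the profile, and which_chunk labels each array element with the flat index of the chunk it lives in.

import numpy as np

array_chunks = ((2, 3), (4, 4))              # dask-style .chunks for a (5, 8) array
shape = tuple(len(c) for c in array_chunks)  # number of blocks per axis: (2, 2)

def get_chunk_shape(array_chunks, blockindex):
    return tuple(array_chunks[axis][i] for axis, i in enumerate(blockindex))

blocks = np.empty(int(np.prod(shape)), dtype=object)
for idx, blockindex in enumerate(np.ndindex(shape)):
    blocks[idx] = np.full(get_chunk_shape(array_chunks, blockindex), idx)

# np.block wants nested lists, so the object array is reshaped to the block
# grid and copied into nested lists by tolist() before being assembled.
which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)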

dcherian commented 1 year ago

Down to

546 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
dcherian commented 1 year ago

I strongly suspect we can do better.

We'd need a way to assign into a nested list directly; this is why I chose the numpy object array route in the first place :) (see the toy comparison below).

Punting to later, since this is already a massive improvement.
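
For reference, a toy comparison of the two routes (all names here are illustrative): assigning by a multi-index into an object array is a one-liner, while assigning into the nested lists that np.block consumes needs a manual walk of the nesting.

import numpy as np
from functools import reduce

numblocks = (2, 3)

# Object-array route: multi-index assignment just works.
blocks = np.empty(numblocks, dtype=object)
blocks[1, 2] = "block (1, 2)"

# Nested-list route: np.block consumes nested lists, but assigning by a
# multi-index means descending through the nesting by hand.
nested = np.empty(numblocks, dtype=object).tolist()
index = (1, 2)
reduce(lambda sub, i: sub[i], index[:-1], nested)[index[-1]] = "block (1, 2)"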

dcherian commented 1 year ago
Change   Before [9f82e19d]   After [fa934069]   Ratio   Benchmark (Parameter)
-        192±1ms             172±4ms            0.9     cohorts.ERA5DayOfYear.time_graph_construct
-        186±3ms             158±2ms            0.85    cohorts.ERA5DayOfYearRechunked.time_graph_construct
-        95.5±0.9ms          68.6±0.9ms         0.72    cohorts.ERA5DayOfYearRechunked.time_find_group_cohorts
-        48.0±0.1ms          25.0±0.1ms         0.52    cohorts.ERA5DayOfYear.time_find_group_cohorts
-        44.5±2ms            18.6±0.6ms         0.42    cohorts.ERA5MonthHourRechunked.time_graph_construct
-        42.3±2ms            17.1±0.2ms         0.4     cohorts.ERA5MonthHour.time_graph_construct
-        9.52±0.04ms         3.72±0.01ms        0.39    cohorts.PerfectMonthly.time_graph_construct
-        9.60±0.2ms          3.72±0.02ms        0.39    cohorts.PerfectMonthlyRechunked.time_graph_construct
-        73.5±0.9ms          25.4±0.5ms         0.35    cohorts.time_cohorts_era5_single
-        31.2±0.2ms          7.58±0.1ms         0.24    cohorts.ERA5MonthHour.time_find_group_cohorts
-        34.1±0.1ms          7.80±0.2ms         0.23    cohorts.ERA5MonthHourRechunked.time_find_group_cohorts
-        6.87±0.03ms         1.02±0.04ms        0.15    cohorts.PerfectMonthlyRechunked.time_find_group_cohorts
-        6.95±0.1ms          1.00±0.01ms        0.14    cohorts.PerfectMonthly.time_find_group_cohorts
dcherian commented 1 year ago

Benchmarks seem to be broken after the numbagg PR. I'll fix them in a new branch.