xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

Set order='F' when raveling group_idx after broadcast #286

Closed dcherian closed 1 year ago

dcherian commented 1 year ago

This substantially speeds up the dim=... case, at least for engine="flox". xref https://github.com/xarray-contrib/flox/issues/281

I'm not sure whether this is a regression for engine="numpy". EDIT: it is a small regression for engine="numpy", but a big one for engine="numbagg" (9ms -> 27ms).

We accept a single bad reshape of array in exchange for not argsorting both array and group_idx, which gives a ~10-20x speedup.

import xarray as xr
ds = xr.tutorial.load_dataset('air_temperature')
# count over all dimensions (dim=...) with the flox engine
ds.groupby('lon').count(..., engine="flox")
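
To illustrate why the ravel order matters, here is a minimal NumPy-only sketch. It is not flox's actual implementation; the array shape, the variable names, and the `is_sorted` helper are assumptions made up for this example. When the group labels vary along the last axis and are broadcast against the array, a C-order ravel interleaves the groups, while an F-order ravel leaves the labels already sorted, so no argsort is needed before a segmented reduction.

```python
import numpy as np

# Hypothetical toy data: 3 rows, 4 columns; labels vary along the last axis.
array = np.arange(12.0).reshape(3, 4)
labels = np.array([0, 1, 2, 3])

# Broadcast the labels to the array's shape, as for a reduction over all axes.
group_idx = np.broadcast_to(labels, array.shape)

c_order = group_idx.ravel(order="C")  # groups interleaved: 0 1 2 3 0 1 2 3 ...
f_order = group_idx.ravel(order="F")  # groups contiguous:  0 0 0 1 1 1 ...


def is_sorted(a):
    """Return True if the 1D array is non-decreasing (illustrative helper)."""
    return bool(np.all(a[:-1] <= a[1:]))


# Ravel the data with the same order so values stay aligned with their labels.
# This is the "bad reshape": it copies the C-contiguous array.
flat = array.ravel(order="F")

# Because the labels are sorted, each group is one contiguous segment,
# so a segmented sum works without any argsort.
starts = np.flatnonzero(np.r_[True, f_order[1:] != f_order[:-1]])
sums_f = np.add.reduceat(flat, starts)

# Reference result computed from the C-order ravel.
sums_c = np.bincount(c_order, weights=array.ravel(order="C"))
```

Both paths give the same per-group sums; the difference is only whether the flattened labels arrive sorted.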

dcherian commented 1 year ago

| Before [0cea9721] | After [375a9f7e] | Ratio | Benchmark (Parameter)                                              |
| <main>            | <opt-ravel>      |       |                                                                    |
|-------------------+------------------+-------+--------------------------------------------------------------------|
| 189±8ms           | 90.0±0.9ms       |  0.48 | reduce.ChunkReduce2DAllAxes.time_reduce('nanmean', 'bins', 'flox') |
| 160±10ms          | 74.7±0.3ms       |  0.47 | reduce.ChunkReduce2DAllAxes.time_reduce('nanmax', 'bins', 'flox')  |
| 160±8ms           | 75.3±0.6ms       |  0.47 | reduce.ChunkReduce2DAllAxes.time_reduce('nansum', 'bins', 'flox')  |
| 174±2ms           | 76.7±3ms         |  0.44 | reduce.ChunkReduce2DAllAxes.time_reduce('count', 'bins', 'flox')   |
| 178±9ms           | 79.2±1ms         |  0.44 | reduce.ChunkReduce2DAllAxes.time_reduce('mean', 'bins', 'flox')    |
| 158±6ms           | 65.9±0.7ms       |  0.42 | reduce.ChunkReduce2DAllAxes.time_reduce('sum', 'bins', 'flox')     |
| 165±9ms           | 66.1±2ms         |   0.4 | reduce.ChunkReduce2DAllAxes.time_reduce('max', 'bins', 'flox')     |
| 506±2ms           | 51.0±1ms         |   0.1 | reduce.ChunkReduce2DAllAxes.time_reduce('nanmean', 'None', 'flox') |
| 510±2ms           | 53.4±1ms         |   0.1 | reduce.ChunkReduce2DAllAxes.time_reduce('mean', 'None', 'flox')    |
| 539±20ms          | 40.5±0.2ms       |  0.08 | reduce.ChunkReduce2DAllAxes.time_reduce('nanmax', 'None', 'flox')  |
| 495±4ms           | 38.1±1ms         |  0.08 | reduce.ChunkReduce2DAllAxes.time_reduce('nansum', 'None', 'flox')  |
| 497±2ms           | 35.3±0.5ms       |  0.07 | reduce.ChunkReduce2DAllAxes.time_reduce('count', 'None', 'flox')   |
| 495±3ms           | 34.9±0.2ms       |  0.07 | reduce.ChunkReduce2DAllAxes.time_reduce('max', 'None', 'flox')     |
| 542±20ms          | 26.8±0.9ms       |  0.05 | reduce.ChunkReduce2DAllAxes.time_reduce('sum', 'None', 'flox')     |