xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

Rework cohorts algorithm and heuristics #396

Open dcherian opened 2 months ago

dcherian commented 2 months ago

Example: ERA5

import dask.array
import flox.xarray
import pandas as pd
import xarray as xr

dims = ("time", "level", "lat", "lon")
# nyears is number of years, adjust to make bigger,
# full dataset is 60-ish years.
nyears = 10
shape = (nyears * 365 * 24, 37, 721, 1440)
chunks = (1, -1, -1, -1)  # one chunk per hourly time step, no chunking elsewhere

ds = xr.Dataset(
    {"U": (dims, dask.array.random.random(shape, chunks=chunks))},
    coords={"time": pd.date_range("2001-01-01", periods=shape[0], freq="H")},
)
ds.resample(time="D").mean()  # daily means over hourly data

All the overhead is in subset_to_blocks

dcherian commented 2 months ago

Now the overhead is in creating 1000s of dask arrays. Here's a line profile of dask.array.Array.__new__ --- all of this work is useless since I'm already passing in valid chunks and meta!

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1331     14606  124899000.0   8551.2     10.5          meta = meta_from_array(meta, dtype=dtype)
  1332                                           
  1333                                                   if (
  1334     14606    2636000.0    180.5      0.2              isinstance(chunks, str)
  1335     14606    2500000.0    171.2      0.2              or isinstance(chunks, tuple)
  1336      3656     510000.0    139.5      0.0              and chunks
  1337      3656    5453000.0   1491.5      0.5              and any(isinstance(c, str) for c in chunks)
  1338                                                   ):
  1339                                                       dt = meta.dtype
  1340                                                   else:
  1341     14606    1583000.0    108.4      0.1              dt = None
  1342     14606  259910000.0  17794.7     21.9          self._chunks = normalize_chunks(chunks, shape, dtype=dt)
  1343     14606    5396000.0    369.4      0.5          if self.chunks is None:
  1344                                                       raise ValueError(CHUNKS_NONE_ERROR_MESSAGE)
  1345     14606  587524000.0  40224.8     49.5          self._meta = meta_from_array(meta, ndim=self.ndim, dtype=dtype)

cc @phofl

The current approach in flox is pretty flawed --- it will create 10,000-ish layers, hehe --- but perhaps there are some quick wins.
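
For context, here is a rough sketch of the per-cohort Array construction pattern that hits dask.array.Array.__new__ so many times. It is illustrative only (the names and details are simplified stand-ins, not flox's actual helper) and assumes the array has a single chunk along every axis except the first, like the ERA5 example above.

import dask.array as da
from dask.highlevelgraph import HighLevelGraph

def subset_to_blocks_sketch(array, block_indices):
    # Illustrative stand-in for flox's per-cohort block selection; assumes a
    # single chunk along every axis except the first.
    name = f"cohort-subset-{array.name}"
    rest = (0,) * (array.ndim - 1)
    # Alias each output block key to the corresponding input block key.
    dsk = {(name, j) + rest: (array.name, i) + rest for j, i in enumerate(block_indices)}
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[array])
    chunks = (tuple(array.chunks[0][i] for i in block_indices),) + array.chunks[1:]
    # chunks and meta are already valid here, yet Array.__new__ re-runs
    # normalize_chunks and meta_from_array for every cohort -- that is the
    # overhead in the profile above.
    return da.Array(graph, name, chunks, meta=array._meta)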

dcherian commented 1 month ago

Alternatively, we need a better approach. The current way is to select out the blocks for a cohort, make that a new dask array, apply the tree reduction, and then concatenate across cohorts.
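
As a simplified sketch (not flox's actual implementation; the plain mean below stands in for the real per-group tree reduction):

import dask.array as da

def reduce_by_cohorts(array, cohorts):
    # cohorts: mapping from a tuple of chunk indices along axis 0 to the
    # group labels contained in those chunks.
    results = []
    for chunk_indices, labels in cohorts.items():
        # 1. Select this cohort's blocks --> a brand-new dask array per cohort.
        subset = array.blocks[list(chunk_indices)]
        # 2. Tree-reduce within the cohort (stand-in: a plain mean over axis 0).
        results.append(subset.mean(axis=0, keepdims=True))
    # 3. Concatenate the per-cohort results back together.
    return da.concatenate(results, axis=0)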

phofl commented 1 week ago

Just linking my other comment here: https://github.com/dask/dask/issues/11026#issuecomment-2455922728

dcherian commented 1 week ago

Yes, some clever use of shuffle might be the way to go here.

phofl commented 5 days ago

Can you help me understand a bit better what high-level API would be useful here?

My understanding of the steps involved is as follows:

The pattern of creating a new array per group comes with significant overhead and also makes life harder for the scheduler. A type of high-level API that would be useful here (based on my understanding):

Basically, each array would then become a branch, kind of like a shuffle but with no guarantee that each group ends up in a single array. Is that understanding roughly correct?
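
To make that concrete, a toy in-memory stand-in (every name here is made up for illustration; nothing dask-specific): whole blocks are bucketed into one branch per cohort, and a group may still be spread over several blocks within its branch.

import numpy as np

def shuffle_blocks_into_branches(blocks, label_per_block, cohorts):
    # blocks: one array per chunk along the grouped axis
    # label_per_block: the group label contained in each block
    # cohorts: list of label sets; each set becomes one branch
    return [
        [blk for blk, lab in zip(blocks, label_per_block) if lab in cohort]
        for cohort in cohorts
    ]

blocks = [np.full(3, i, dtype=float) for i in range(6)]
label_per_block = [0, 0, 1, 1, 2, 2]
cohorts = [{0, 1}, {2}]  # two branches
branches = shuffle_blocks_into_branches(blocks, label_per_block, cohorts)
# each branch can now be tree-reduced independently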

dcherian commented 5 days ago

Yes, some kind of shuffle will help, but the harder bit is the "tree reduction on each branch". That API basically takes Arrays and uses Arrays internally.
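
For reference, the public dask.array.reduction has the shape of API I mean: it takes an Array and builds the chunk/combine/aggregate tree out of Array operations (a simplified sum example, not flox's actual code path):

import dask.array as da

def chunk_sum(x, axis=None, keepdims=False, **kwargs):
    # per-block partial sum; **kwargs absorbs extras dask may pass
    return x.sum(axis=axis, keepdims=keepdims)

x = da.random.random((365 * 24, 10), chunks=(24, 10))  # stand-in for one cohort's Array
total = da.reduction(
    x,
    chunk=chunk_sum,      # applied to every block
    combine=chunk_sum,    # merges partial results in a tree
    aggregate=chunk_sum,  # final combine step
    axis=0,
    dtype=x.dtype,
    split_every=16,       # tree fan-in
)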

PS: I'm low on bandwidth at the moment, and can only really engage here in two weeks. Too many balls in the air!

phofl commented 5 days ago

No worries, let's just tackle this a bit later then.

dcherian commented 5 days ago

While working on my talk, I just remembered that you all fixed a major scheduling problem, and that problem was what made "cohorts" better than "map-reduce" in so many cases. That isn't true anymore, so we should revisit the algorithm and the heuristics for when "cohorts" is a good automatic choice.

phofl commented 5 days ago

Can you elaborate a bit more? Did we break something in Dask since we fixed the ordering?

Edit: I think I misread your comment. You mean that the improved ordering behavior made map-reduce the better choice, not that we broke something after fixing the problem, correct?

dcherian commented 5 days ago

No, your fix makes my heuristics useless :P We need to update them and choose "map-reduce" in more cases than we currently do. The graph is significantly better and easier to schedule with "map-reduce".
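
As a side note, while the heuristic is being reworked the strategy can always be pinned explicitly through the method argument of flox.xarray.xarray_reduce; a tiny self-contained example:

import dask.array
import flox.xarray
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {"U": (("time",), dask.array.random.random(8760, chunks=24))},
    coords={"time": pd.date_range("2001-01-01", periods=8760, freq="H")},
)
# method can be "map-reduce", "cohorts", or "blockwise"
daily = flox.xarray.xarray_reduce(
    ds.U, ds.time.dt.dayofyear, func="mean", method="map-reduce"
)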

phofl commented 5 days ago

Puuuh, you had me worried there for a bit 😂