SciTools / iris

A powerful, format-agnostic, and community-driven Python package for analysing and visualising Earth science data
https://scitools-iris.readthedocs.io/en/stable/
BSD 3-Clause "New" or "Revised" License

Computational performance of `iris.cube.Cube.aggregated_by` with lazy data #5455

Open bouweandela opened 1 year ago

bouweandela commented 1 year ago

Current behaviour

The method iris.cube.Cube.aggregated_by produces many tiny chunks and a very large graph when used with lazy data. When using this method as part of a larger computation, my dask graphs become so large that the computation fails to run. Even if the computation did run, it would be needlessly slow because of the many tiny chunks.

Expected behaviour

The method iris.cube.Cube.aggregated_by should respect the chunks of the input data as much as possible and produce a modestly sized dask graph.

Example code showing the current behaviour

Here is an example script that demonstrates the issue. The cube in the example represents 250 years of monthly data on a 300 x 300 spatial grid, and the script computes the annual mean.

import dask.array as da
import iris.analysis
import iris.coords
import iris.cube
import numpy as np

cube = iris.cube.Cube(
    da.empty((250 * 12, 300, 300), chunks=('auto', None, None)),
    aux_coords_and_dims=[
        (iris.coords.AuxCoord(np.repeat(np.arange(250), 12), var_name="year"), 0)
    ],
)
print("Input data:")
print(f"{cube.lazy_data().chunks=}")
print(f"{cube.lazy_data().dask=}")

result = cube.aggregated_by("year", iris.analysis.MEAN)

print("Result:")
print(f"{result.lazy_data().chunks=}")
print(f"{result.lazy_data().dask=}")

This prints the following:

cube.lazy_data().chunks=((186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 24), (300,), (300,))
cube.lazy_data().dask=HighLevelGraph with 1 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7fc42cb9de10>
 0. empty_like-9ffb0624c777d40491118ccfa8357d00

Result:
result.lazy_data().chunks=((1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (300,), (300,))
result.lazy_data().dask=HighLevelGraph with 752 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7fc42c874990>
 0. empty_like-9ffb0624c777d40491118ccfa8357d00
 1. getitem-875ae1b7554c57fa4b243547465267f6
 2. mean_chunk-ee11e7299515ada537da00382a6512db
 3. mean_agg-aggregate-0e111cc07c3b62b465b43f12f0cd78e3
 4. getitem-50ebcc520f0974e2d456be1155a58224
 5. mean_chunk-87cd31da7a615d73c687b49bc2bd21e5
 6. mean_agg-aggregate-fe86efb73fd4e3936f2c495bcfeb0851

 ... etc
bouweandela commented 1 year ago

I found this project which seems relevant: https://github.com/xarray-contrib/flox

pp-mo commented 1 year ago

@scitools/peloton Our standard answer to these problems is that getting Dask to work well is hard, and that you should take the data out of the cube and use Dask directly (!). But also, we think the problem here is that aggregation can't easily respect chunks anyway. So if there is a repeating structure in the data, you really want to see if you can reshape the input with a year dimension. And then you wouldn't need aggregated_by at all.

So, that approach probably waits on #5398. But we'd be keen to investigate if that can solve this case!
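
Roughly, the idea looks like this in plain dask (a sketch only, not the iris API; the 20-year rechunk size below is purely illustrative):

import dask.array as da

data = da.empty((250 * 12, 300, 300), chunks=("auto", None, None))

# Align chunk boundaries with whole years so the reshape is cheap.
data = data.rechunk({0: 20 * 12})

# Fold the time axis into (year, month) and collapse the month axis.
annual_mean = data.reshape((250, 12, 300, 300)).mean(axis=1)

print(annual_mean.chunks)            # whole-year chunks are preserved
print(len(annual_mean.dask.layers))  # only a handful of graph layers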

bouweandela commented 11 months ago

But also, we think the problem here is that aggregation can't easily respect chunks anyway.

Not generally, but I had a look at the use cases in ESMValCore and they are all time-related and there is always an (almost) repeating structure in the coordinate. In some cases, e.g. day of the month, the structure can easily be made repeating by adding some extra days and then masking out the extra data.

So if there is a repeating structure in the data, you really want to see if you can reshape the input with a year dimension. And then you wouldn't need aggregated_by at all.

Indeed, slicing so all repeats have the same size, masking the extra data points introduced by slicing, and then reshaping and collapsing along the required dimension seems the best solution here. However, I struggled to reshape a cube; is there some feature available to do this that I overlooked? I'm also concerned that getting the right values for the coordinates is tricky when using Cube.collapsed instead of Cube.aggregated_by.

Maybe the best solution would be to use the algorithm described above internally in Cube.aggregated_by when the input data is a Dask array, because the current implementation (iterating over an array index in Python and then using dask.array.stack) leads to task graphs so large that the scheduler hangs, making it unusable for Dask users.
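
For the simplest case, where only the trailing group is incomplete, the idea could look something like this in plain dask (a sketch only, using NaN padding and nanmean instead of a proper mask; a real implementation would also need to handle groups that are short in the middle, e.g. differing month lengths):

import dask.array as da
import numpy as np

n_months = 250 * 12 + 5  # 5 extra months beyond whole years
data = da.random.random((n_months, 300, 300), chunks=("auto", None, None))

group = 12
pad = -n_months % group  # time steps needed to complete the last year

# Pad the time axis so every year has exactly 12 entries, fold the repeats
# into their own axis, and collapse it while ignoring the padding.
padded = da.pad(data, ((0, pad), (0, 0), (0, 0)), constant_values=np.nan)
padded = padded.rechunk({0: 20 * group})  # align chunks with whole years
annual_mean = da.nanmean(padded.reshape((-1, group, 300, 300)), axis=1)

print(annual_mean.shape)  # (251, 300, 300)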

So, that approach probably waits on https://github.com/SciTools/iris/issues/5398.

My feeling is that whether or not that will be beneficial depends on how numpy actually iterates over the input array when it computes the statistics along a dimension.

dcherian commented 11 months ago

The method iris.cube.Cube.aggregated_by produces many tiny chunks and a very large graph when used with lazy data

Sounds familiar to me as an Xarray user :)

and they are all time-related and there is always an (almost) repeating structure in the coordinate.

Absolutely. See https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts

I strongly recommend you use flox. I wrote it a year ago to solve exactly these problems with Xarray.

  1. It vectorizes the core pure-numpy groupby using a number of options (numbagg, numpy-groupies, or its own internal ufunc.reduceat strategy).
  2. It implements the groupby-reduction as a tree reduction which works great with dask.
  3. It lets you opt-in to smarter graph construction that exploits patterns in the variable you're grouping by.

See https://xarray.dev/blog/flox
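
For the example at the top of this issue, a direct flox call might look roughly like this (a sketch only; check the flox docs for the exact signature, and note that flox expects the grouped dimension to be the trailing one, so the time axis is moved last and back):

import dask.array as da
import numpy as np
import flox

data = da.empty((250 * 12, 300, 300), chunks=("auto", None, None))
year = np.repeat(np.arange(250), 12)

result, groups = flox.groupby_reduce(
    da.moveaxis(data, 0, -1),  # put the grouped (time) axis last
    year,
    func="mean",
    method="cohorts",  # exploit the repeating structure in `year`
)
annual_mean = da.moveaxis(result, -1, 0)  # back to (year, lat, lon)

print(annual_mean.shape)             # (250, 300, 300)
print(len(annual_mean.dask.layers))  # much smaller graph than aggregated_by produces above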