xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

First execution of groupby on Xarray with Flox takes a lot of time #365

Open josephnowak opened 4 months ago

josephnowak commented 4 months ago

Hi, I have been using Flox with Xarray for some experiments and it works nicely in many cases. However, when I write unit tests with small data, execution takes much longer than expected. For example, the code below takes 60 s to run, but only the first time I run it, which I suppose is related to just-in-time compilation. I have also noticed that arrays are printed multiple times (please see the screenshot showing the place in the code with a possibly unexpected print), which I suppose could be affecting performance (and again, this only happens the first time I run it, which is even stranger).

Note: If I uninstall Flox, everything runs faster; that's why I'm reporting this here and not in the Xarray repo.

import xarray as xr
import time
import pandas as pd

dates = pd.to_datetime(
    [
        "2020-12-30",
        "2021-01-02",
        "2021-01-03",
        "2021-01-05",
        "2021-01-07",
        "2021-01-09",
        "2021-01-11",
    ]
)
new_data = xr.DataArray(
    [
        [10, 9.1, 8, 7, 6, 2, 1],
        [6, 5, 7, 7, 4, 2, 1],
        [2, 9, 4, 5, 3, 2, 1],
        [-10, 2, 7, 7, 9, 2, 1],
        [4, 7, 3, -5, 6, 2, 1],
        [10, 1, 2, 7, 3, 2, 1],
        [12, 4, 5, -1, 0, 2, 1],
    ],
    coords={
        "date": dates,
        "a": ["a", "b", "c", "d", "e", "f", "g"],
    },
).chunk({"date": 2, "a": 5})

start = time.time()
new_data.groupby("a").max("a").compute()
print(time.time() - start)
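Timing the same call more than once helps separate a one-time JIT cost from steady-state runtime. `time.perf_counter` is better suited to this than `time.time`. Below is a minimal stdlib sketch. `time_calls` and `fake_jit_op` are hypothetical names, and the cached `time.sleep` only simulates a compile-once cost; it is not how numbagg compiles.

```python
import functools
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], repeats: int = 3) -> List[float]:
    """Time `repeats` consecutive calls of `fn`, returning seconds per call.

    The first duration includes any one-time cost (such as JIT compilation);
    later durations approximate the steady-state runtime.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


@functools.lru_cache(maxsize=None)
def _compiled_kernel() -> Callable:
    # Hypothetical stand-in for JIT compilation: slow on the first call only,
    # because lru_cache returns the cached result on later calls.
    time.sleep(0.2)
    return max


def fake_jit_op() -> float:
    # Hypothetical operation that "compiles" once, then runs quickly.
    return _compiled_kernel()([3.0, 1.0, 2.0])
```

Applied to the example above, the call would be `time_calls(lambda: new_data.groupby("a").max("a").compute())`. A large first entry followed by small ones points at compilation rather than at the groupby itself.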

[screenshot]

Arrays printed during the execution: [screenshot]

INSTALLED VERSIONS

commit: None
python: 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0]
python-bits: 64
OS: Linux
OS-release: 4.14.275-207.503.amzn2.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: en_US.UTF-8
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: None

xarray: 2024.5.0
pandas: 2.2.1
numpy: 1.26.3
scipy: 1.11.4
netCDF4: None
pydap: None
h5netcdf: None
h5py: 3.10.0
zarr: 2.16.1
cftime: None
nc_time_axis: None
iris: None
bottleneck: 1.3.7
dask: 2024.5.1
distributed: 2024.5.1
matplotlib: 3.8.2
cartopy: None
seaborn: 0.13.1
numbagg: 0.8.1
fsspec: 2023.12.2
cupy: None
pint: None
sparse: 0.15.1
flox: 0.9.7
numpy_groupies: 0.10.2
setuptools: 69.0.3
pip: 23.3.2
conda: 23.11.0
pytest: 7.4.4
mypy: None
IPython: 8.20.0
sphinx: None

dcherian commented 4 months ago

You have numbagg installed so this is some JIT overhead. Some background here: https://github.com/numbagg/numbagg/pull/220.

You can avoid this by setting engine="flox" or engine="numpy".

That print statement is in the benchmarking code, so it shouldn't affect what you're seeing.

I did make the choice not to allow selectively disabling numbagg. Would you prefer that this be possible? Or would you prefer that, for now, flox prioritize another approach (probably flox over numbagg)?

josephnowak commented 4 months ago

I have used Numbagg in other projects and usually see a first execution of only a couple of seconds, not a minute, but probably Flox is using many functions of it, which would explain the time. Thanks for the clarification.

For me, since changing the engine (I did not know the engine parameter existed, so thank you) avoids the JIT overhead for the small cases in my tests, I think it's fine to keep the current behavior.

Regarding the printed arrays, I think the problem was caused by the Numbagg package: apparently release 0.7.0 left those prints in, which was later fixed. I updated the package and restarted my Jupyter kernel, and they are no longer printed. [screenshot]

dcherian commented 4 months ago

Flox is using many functions of it

No, I believe it compiles for all combinations of dtypes in one go. @max-sixty has more context on why the groupby aggregation functions take a while to compile.

Your point about "small" datasets is interesting. Perhaps the automatic choice should be to use numbagg only for "big enough" problems for some definition of big enough.
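A size-based automatic choice could look roughly like the sketch below. `choose_engine`, `DEFAULT_THRESHOLD`, and the 1,000,000-element cutoff are hypothetical placeholders, not part of flox; a real heuristic might look at chunk sizes rather than the total shape.

```python
import math
from typing import Sequence

# Hypothetical cutoff: below this many elements, JIT compile time is assumed
# to dominate, so a non-JIT engine is preferred. Not a flox default.
DEFAULT_THRESHOLD = 1_000_000


def choose_engine(
    shape: Sequence[int],
    numbagg_available: bool = True,
    threshold: int = DEFAULT_THRESHOLD,
) -> str:
    """Hypothetical rule: use numbagg only for "big enough" arrays."""
    n_elements = math.prod(shape)
    if numbagg_available and n_elements >= threshold:
        return "numbagg"
    # Small problems (e.g. unit-test data) skip the JIT compile entirely.
    return "flox"
```

The 7×7 array from the issue would get `"flox"`, while a 2000×1000 array would get `"numbagg"`.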

max-sixty commented 4 months ago

On the print statement: yes, apologies. FWIW that was only the most current version for ~36 hours IIRC. There's now a check to ensure this can't happen again.


On the compilation time: this is a flamegraph of the total time for your excellent example. Indeed it's basically all compilation; the execution time is the small sliver at the very end of each block.

I get an execution time of 18 seconds, ~8s of compilation for each.

[Screenshot: flamegraph of the total runtime of the example]

We do have benchmarks for these, but haven't had them running on each commit! I could run them nightly so we have proper measurements.


There are two changes from numba which will make a big difference:

Currently neither of these is compatible with parallel. The first is being worked on upstream, but it's not clear when to expect it; we probably shouldn't plan around them.


Partitioning based on the size of the array may make sense. Of course, if you're doing computations over lots of small arrays, you probably want to pay the upfront cost, but size would still be a reasonable heuristic for spotting this sort of workload.
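The "lots of small arrays" trade-off can be framed as a back-of-the-envelope calculation: a one-time compile cost pays off once the per-call savings add up to it. Below is a minimal sketch. `break_even_calls` is a hypothetical helper, and any per-call timings fed into it are assumptions, not measurements from this thread.

```python
import math


def break_even_calls(compile_s: float, jit_call_s: float, plain_call_s: float) -> float:
    """Number of calls after which a one-time JIT compile cost is recovered.

    JIT total for n calls:      compile_s + n * jit_call_s
    non-JIT total for n calls:  n * plain_call_s
    Break-even n = compile_s / (plain_call_s - jit_call_s).
    """
    saving_per_call = plain_call_s - jit_call_s
    if saving_per_call <= 0:
        # The JIT path is not faster per call, so it never catches up.
        return math.inf
    return compile_s / saving_per_call
```

For example, with the ~8 s of compilation measured above and hypothetical per-call times of 1 ms (JIT) versus 5 ms (non-JIT), the compile cost would be recovered after about 2,000 calls.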