pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

DataArray.quantile is very slow for Dask applications (and generally) #9692

Closed · phofl closed this issue 1 week ago

phofl commented 3 weeks ago

What happened?

I've recently run into this with a real-world workload where calling quantile on a DataArray was approx. 30x-40x slower than just calling median (2.5 hours vs. 6 minutes). We profiled this a little, and it basically comes down to the NumPy implementation of quantiles and the shapes of our chunks. We had a relatively small time axis (50-120 elements).

NumPy only has a 1-D quantile kernel, so for N-D input it calls apply_along_axis and iterates over all the other dimensions, which blocks up the GIL quite badly (it's also not very fast generally speaking, but running single-threaded doesn't make it that bad). Median has a special implementation that avoids this problem to some degree.
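For illustration, a minimal sketch of what this pattern boils down to (the array shape here is made up, but mirrors the chunk shapes described above):

import numpy as np

# Shape mirrors the problem: many (x, y) slices, a short reduction axis.
arr = np.random.random((500, 500, 50))

# np.nanquantile over the last axis effectively reduces to this: a
# Python-level loop over 250,000 tiny 1-D slices, each call holding the GIL.
result = np.apply_along_axis(np.nanquantile, -1, arr, 0.75)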

Here is a snapshot of the Dask dashboard while running quantile:

[Screenshot: Dask dashboard task stream, 2024-10-28 22:33]

The purple tasks are np.nanquantile; the green ones are the custom implementation. These are workers with 2 threads, and it gets rapidly worse with more threads because they all lock each other up: 4 threads makes the runtime of a single task balloon to 220s.
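For anyone reproducing this, the thread count per worker is set when the cluster is created (the worker count below is arbitrary):

from dask.distributed import Client, LocalCluster

# two threads per worker, as in the dashboard snapshot above; bumping
# this to 4 is what made a single task balloon to ~220s
cluster = LocalCluster(n_workers=4, threads_per_worker=2)
client = Client(cluster)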

cc @dcherian

What did you expect to happen?

Things should be faster.

We can do a couple of things on the Dask side to make this work better.

I made some timings, shown in the screenshot above: the custom implementation (the green tasks) runs in 1.3 seconds. It would only cover the most common cases and dispatch to the NumPy quantile function otherwise; a rough sketch of the idea is below.
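To illustrate the idea (this is only a sketch under stated simplifications, not the actual Dask code; it handles the no-NaN, linear-interpolation case only):

import numpy as np

def _linear_quantile_last_axis(arr, q):
    """Vectorized quantile over the last axis via a single sort.

    One array-level sort replaces the per-slice Python loop, so the GIL
    is released for most of the work. NaNs and the other interpolation
    methods would still dispatch to np.(nan)quantile.
    """
    srt = np.sort(arr, axis=-1)
    pos = q * (srt.shape[-1] - 1)        # fractional index, "linear" method
    lo, hi = int(np.floor(pos)), int(np.ceil(pos))
    frac = pos - lo
    return srt[..., lo] * (1 - frac) + srt[..., hi] * frac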

If I understand correctly, this will by extension also help groupby quantile, since that just calls the regular quantile per group.

Minimal Complete Verifiable Example

import dask.array as da
import xarray as xr

darr = xr.DataArray(
    da.random.random((8944, 7531, 50), chunks=(904, 713, -1)),
    dims=["x", "y", "time"],
)
# short "time" axis in a single chunk: the pathological case described above
darr.quantile(q=0.75, dim="time")
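For the 30x-40x comparison, an analogous median call looks like this (timing harness mine, reusing darr from above):

import time

start = time.perf_counter()
darr.quantile(q=0.75, dim="time").compute()
print(f"quantile: {time.perf_counter() - start:.1f}s")

start = time.perf_counter()
darr.median(dim="time").compute()  # hits dask's dedicated median path
print(f"median: {time.perf_counter() - start:.1f}s")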


Relevant log output

No response

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 15:57:01) [Clang 17.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.3.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: de_DE.UTF-8
LOCALE: ('de_DE', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: None

xarray: 2024.9.0
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.14.0
netCDF4: None
pydap: None
h5netcdf: 1.3.0
h5py: 3.12.1
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: 1.4.1
dask: 2024.10.0+8.g6a18fe07
distributed: 2024.10.0+11.gafa6e8db
matplotlib: None
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.1
cupy: None
pint: None
sparse: 0.15.4
flox: 0.9.9
numpy_groupies: 0.11.2
setuptools: 75.1.0
pip: 24.2
conda: None
pytest: 8.3.3
mypy: None
IPython: 8.28.0
sphinx: None
max-sixty commented 3 weeks ago

Does installing numbagg help? (I'm not sure it will in the dask case and would need to look more, but it's worth checking whether that eliminates all of this...)

phofl commented 3 weeks ago

This helps indeed (10s instead of 60s), and it also largely removes the GIL issue, since 4 threads now perform the same as 2. So I guess this is only relevant if numbagg is not installed. Thanks!

The downside without numbagg is bad enough that I would be open to adding a Dask implementation anyway; it made the real workload I had basically unrunnable.

slevang commented 3 weeks ago

np.nanquantile is known to be horrendously slow for N-D arrays when the aggregation dimension is small. I didn't realize how much numbagg has improved this; very nice. On a subset of your example I get (wall times):

numbagg=True,  skipna=True:  12.2 s
numbagg=True,  skipna=False: 2.45 s
numbagg=False, skipna=True:  9min 27s
numbagg=False, skipna=False: 2.77 s
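These combinations can be reproduced with something like the following (reusing darr from the MVCE; the subset size is a guess, and use_numbagg is assumed to be the relevant option in recent xarray versions):

import xarray as xr

sub = darr.isel(x=slice(0, 1000))  # arbitrary subset of the example array
for use_numbagg in (True, False):
    for skipna in (True, False):
        with xr.set_options(use_numbagg=use_numbagg):
            sub.quantile(q=0.75, dim="time", skipna=skipna).compute()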

So if there is room for improvement, it seems like:

  1. a better numpy implementation
  2. adding the other interpolation methods to numbagg
phofl commented 3 weeks ago

np.nanquantile is known to be horrendously slow (https://github.com/numpy/numpy/issues/16575) for N-D arrays when the aggregation dimension is small.

Thanks for linking the issue.

So if there is room for improvement, it seems like: (1) a better numpy implementation, (2) adding the other interpolation methods to numbagg

While improving NumPy would be great, that's not something I can take on myself. numbagg itself is great, but only a small share of xarray users seem to use it: it has 70k monthly PyPI downloads compared to xarray's 6 million (download counts aren't a great metric, but they're at least a very rough estimate).

I am not that worried about raw performance by itself; the main problem for Dask is that nanquantile in NumPy holds the GIL, so the threads lock each other up and the performance issues compound rapidly. Using 4 threads per worker increases the runtime per task to 220s in the example above, which is pretty bad.

What I am curious about is whether xarray would call a Dask array quantile / nanquantile implementation, if we added one, instead of going through apply_ufunc? That's basically what xarray does for all other reductions like median and friends.

dcherian commented 3 weeks ago

What I am curious about is whether xarray would call a Dask array quantile / nanquantile implementation, if we added one, instead of going through apply_ufunc? That's basically what xarray does for all other reductions like median and friends.

Yes. Looks like I added the apply_ufunc implementation to work around dask not having it? https://github.com/pydata/xarray/pull/3559
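For context, the approach from that PR is roughly this shape (a simplified sketch reusing darr from the MVCE, not the exact xarray code):

import numpy as np
import xarray as xr

def _nanquantile_last_axis(a, q):
    # apply_ufunc moves the core dim ("time") to the end of each block
    return np.nanquantile(a, q, axis=-1)

result = xr.apply_ufunc(
    _nanquantile_last_axis,
    darr,
    input_core_dims=[["time"]],
    kwargs={"q": 0.75},
    dask="parallelized",
    output_dtypes=[darr.dtype],
)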

dcherian commented 3 weeks ago

The easiest way to do this would be to adapt the existing wrapper to handle dask and pass dask="allowed" in the current apply_ufunc call.

The wrapper should go in https://github.com/pydata/xarray/blob/main/xarray/core/duck_array_ops.py, and any dask-specific back-compat code can go in https://github.com/pydata/xarray/blob/main/xarray/core/dask_array_ops.py.
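A rough sketch of that dispatch shape (the function name and the duck check are illustrative, not xarray's actual code):

import numpy as np

def quantile(values, q, axis, skipna=True):
    # route dask arrays to a dask-aware kernel, everything else to NumPy
    if hasattr(values, "dask"):  # crude stand-in for a proper duck-array check
        # a dedicated dask-level (nan)quantile would be called here once it
        # exists; today dask only ships da.percentile, and only for 1-D arrays
        raise NotImplementedError("dask-level quantile not added yet")
    func = np.nanquantile if skipna else np.quantile
    return func(values, q, axis=axis)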

max-sixty commented 3 weeks ago

numbagg itself is great, but only a small share of xarray users seem to use it: it has 70k monthly PyPI downloads compared to xarray's 6 million (download counts aren't a great metric, but they're at least a very rough estimate).

I would also be open to increasing this number :)

(And jokes aside, there are discussions about moving xarray to xarray-core and having an xarray package with optional-but-recommended dependencies, which would make the default install of xarray include libraries that give it decent performance, such as numbagg.)