pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

`concat()` very slow when inserting `NaN` into Dask arrays #9496

Open pschlo opened 1 month ago

pschlo commented 1 month ago

What is your issue?

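Given the following situation: two Dask-backed Datasets along the same dimension `dim1`. The first one is short (10 elements) and contains the seven variables `var1` to `var7`. The second one is long (100,000 elements, in chunks of 20,000) and contains only `var1`.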

When I `concat()` them along `dim1`, xarray extends the variables that appear in the first Dataset but not in the second with `NaN`. I would expect this to be lazy and to execute almost instantly, but it turns out to be very slow on my machine.

Example code:

```python
import xarray as xr
import dask.array as da
import numpy as np

ds1 = xr.Dataset(
    data_vars=dict(
        var1=('dim1', da.arange(10, dtype=np.float64, chunks=-1)),
        var2=('dim1', da.arange(10, dtype=np.float64, chunks=-1)),
        var3=('dim1', da.arange(10, dtype=np.float64, chunks=-1)),
        var4=('dim1', da.arange(10, dtype=np.float64, chunks=-1)),
        var5=('dim1', da.arange(10, dtype=np.float64, chunks=-1)),
        var6=('dim1', da.arange(10, dtype=np.float64, chunks=-1)),
        var7=('dim1', da.arange(10, dtype=np.float64, chunks=-1))
    ),
)

ds2 = xr.Dataset(
    data_vars=dict(
        var1=('dim1', da.arange(100_000, dtype=np.float64, chunks=20_000)),
    ),
)

print(ds1)
print('var1 chunks:', ds1['var1'].chunksizes)

print()
print(ds2)
print('var1 chunks:', ds2['var1'].chunksizes)

print()
concat = xr.concat([ds1, ds2], dim='dim1')
print(concat)
print('var1 chunks:', concat['var1'].chunksizes)
print('var2 chunks:', concat['var2'].chunksizes)
```

Output:

```
<xarray.Dataset> Size: 560B
Dimensions:  (dim1: 10)
Dimensions without coordinates: dim1
Data variables:
    var1     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
    var2     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
    var3     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
    var4     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
    var5     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
    var6     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
    var7     (dim1) float64 80B dask.array<chunksize=(10,), meta=np.ndarray>
var1 chunks: Frozen({'dim1': (10,)})

<xarray.Dataset> Size: 800kB
Dimensions:  (dim1: 100000)
Dimensions without coordinates: dim1
Data variables:
    var1     (dim1) float64 800kB dask.array<chunksize=(20000,), meta=np.ndarray>
var1 chunks: Frozen({'dim1': (20000, 20000, 20000, 20000, 20000)})

<xarray.Dataset> Size: 6MB
Dimensions:  (dim1: 100010)
Dimensions without coordinates: dim1
Data variables:
    var1     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
    var2     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
    var3     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
    var4     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
    var5     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
    var6     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
    var7     (dim1) float64 800kB dask.array<chunksize=(10,), meta=np.ndarray>
var1 chunks: Frozen({'dim1': (10, 20000, 20000, 20000, 20000, 20000)})
var2 chunks: Frozen({'dim1': (10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, ...
```

The last output line is followed by many more 10s.

This takes about 10-20 seconds to run on my machine. Is there any reason for this being so slow? I would have expected the code to run almost instantly, with the `NaN` chunks only materialized lazily, e.g. when calling `compute()`.
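
To check whether the time goes into building the task graph rather than into computing anything, one can time `xr.concat()` on its own and count the chunks of a `NaN`-extended variable. In the output above, `var2` ends up with 100,010 / 10 = 10,001 chunks of size 10, and each chunk is a separate task. The following is a reduced sketch of the example (only `var1` and `var2`); the exact timing and chunk count will depend on the installed xarray and Dask versions.

```python
import time

import dask.array as da
import numpy as np
import xarray as xr

# Reduced version of the report: var2 is missing from the second dataset
ds1 = xr.Dataset(
    {
        "var1": ("dim1", da.arange(10, dtype=np.float64, chunks=-1)),
        "var2": ("dim1", da.arange(10, dtype=np.float64, chunks=-1)),
    }
)
ds2 = xr.Dataset(
    {"var1": ("dim1", da.arange(100_000, dtype=np.float64, chunks=20_000))}
)

start = time.perf_counter()
combined = xr.concat([ds1, ds2], dim="dim1")  # only builds the lazy graph
elapsed = time.perf_counter() - start

# Number of chunks along dim1 for the NaN-extended variable; a large count
# means a large task graph, which is expensive to construct
n_chunks = len(combined["var2"].chunks[0])
print(f"concat took {elapsed:.2f}s, var2 has {n_chunks} chunks")
```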

Here is my output of `xr.show_versions()`:

```
INSTALLED VERSIONS
------------------
commit: None
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
python-bits: 64
OS: Linux
OS-release: 5.15.133.1-microsoft-standard-WSL2
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: C.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.2
libnetcdf: 4.9.3-development

xarray: 2024.9.0
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.14.1
netCDF4: 1.7.1.post2
pydap: None
h5netcdf: 1.3.0
h5py: 3.7.0
zarr: 2.12.0
cftime: 1.6.4
nc_time_axis: 1.4.1
iris: None
bottleneck: 1.4.0
dask: 2024.9.0
distributed: 2024.9.0
matplotlib: 3.9.2
cartopy: None
seaborn: 0.13.2
numbagg: 0.8.1
fsspec: 2024.9.0
cupy: None
pint: None
sparse: None
flox: 0.9.11
numpy_groupies: 0.11.2
setuptools: 74.1.2
pip: 22.0.2
conda: None
pytest: None
mypy: None
IPython: 8.27.0
sphinx: None
```

welcome[bot] commented 1 month ago

Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!

slevang commented 1 month ago

Looks like this was introduced by `dask==2024.8` and the new shuffle algorithm. If run with an earlier version, the example gives a single 100,010-element chunk for the expanded variables, but now it gives size-10 chunks over the whole array. Ideally we want `chunks=((10, 20000, 20000, ...),)`, right?
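
For a `NaN`-extended variable, that chunking is what plain Dask produces if the filler for the second dataset is built with the chunks of that dataset's `var1` and then concatenated. The sketch below uses the sizes from the example and works on bare Dask arrays, not on xarray's internals. The variable names are illustrative only.

```python
import dask.array as da
import numpy as np

# var2 exists only in the first dataset: 10 values in a single chunk
var2_first = da.arange(10, dtype=np.float64, chunks=-1)

# Chunks of var1 in the second dataset: 5 chunks of 20_000
ref_chunks = da.arange(100_000, dtype=np.float64, chunks=20_000).chunks

# Lazy NaN filler for the second dataset, reusing those chunks
filler = da.full((100_000,), np.nan, dtype=np.float64, chunks=ref_chunks)

combined = da.concatenate([var2_first, filler], axis=0)
print(combined.chunks)  # ((10, 20000, 20000, 20000, 20000, 20000),)
```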

pschlo commented 1 month ago

> Looks like this was introduced by dask==2024.8 and the new shuffle algorithm. If run with an earlier version the example gives a single 100,010 chunk for the expanded variables, but now gives size 10 chunking over the whole array. Ideally we want chunks=((10, 20000, 20000, ...),) right?

This does indeed seem more reasonable. I also tested with an earlier version, but found it to be very slow as well. I guess shuffling is required for the more general `concat()` case, but in this particular case the result should be easy to compute, no? I wrote the following function to work around this in my project:

```python
import xarray as xr
import dask.array as da
import numpy as np
from collections.abc import Collection


def xr_prefill_concat(datasets: Collection[xr.Dataset], dim: str, *args, **kwargs):
    """Concatenate Dask-backed Datasets after first ensuring they all have the same data_vars."""
    # Shallow copies, so adding variables does not modify the caller's datasets
    datasets = [ds.copy() for ds in datasets]

    def fill_vars(ds: xr.Dataset, var_names: set[str]):
        missing_vars = var_names - set(ds.data_vars)
        if not missing_vars:
            return

        # Use the chunk size along `dim` of any data variable
        dataarr_chunksizes = next(iter(ds.data_vars.values())).chunksizes
        if not dataarr_chunksizes:
            raise ValueError("Dataset must be backed by Dask")
        chunk_size = dataarr_chunksizes[dim][0]

        # Add every missing variable as a lazy all-NaN Dask array
        for var in missing_vars:
            ds[var] = (dim, da.full((ds.sizes[dim],), np.nan, chunks=chunk_size))

    all_vars: set[str] = set().union(*(ds.data_vars for ds in datasets))
    for ds in datasets:
        fill_vars(ds, all_vars)

    # Pass `dim` positionally so extra positional args cannot collide with it
    return xr.concat(datasets, dim, *args, **kwargs)
```

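
A variant of the same prefill idea could use `xr.full_like`, which, for a Dask-backed template, creates a lazy array with the template's chunks. The sketch assumes that each dataset with missing variables has at least one data variable whose only dimension is `dim`; that assumption is not part of the workaround above. `fill_missing_like` is a hypothetical name.

```python
import dask.array as da
import numpy as np
import xarray as xr


def fill_missing_like(datasets, dim):
    """Hypothetical helper: add all-NaN versions of missing data variables.

    Assumes each dataset that lacks a variable has at least one data variable
    whose only dimension is ``dim``; that variable serves as the template.
    """
    all_vars = set().union(*(ds.data_vars for ds in datasets))
    filled = []
    for ds in datasets:
        missing = all_vars - set(ds.data_vars)
        if not missing:
            filled.append(ds)
            continue
        # First 1-D data variable along `dim` acts as the template
        template = next(v for v in ds.data_vars.values() if v.dims == (dim,))
        # xr.full_like keeps the template's Dask chunks, so the filler stays lazy
        fillers = {
            name: xr.full_like(template, np.nan, dtype=np.float64) for name in missing
        }
        filled.append(ds.assign(fillers))
    return filled


ds1 = xr.Dataset(
    {f"var{i}": ("dim1", da.arange(10, dtype=np.float64, chunks=-1)) for i in range(1, 8)}
)
ds2 = xr.Dataset(
    {"var1": ("dim1", da.arange(100_000, dtype=np.float64, chunks=20_000))}
)

combined = xr.concat(fill_missing_like([ds1, ds2], "dim1"), dim="dim1")
print(combined["var2"].chunks)  # ((10, 20000, 20000, 20000, 20000, 20000),)
```

Because every variable is present in both datasets before `xr.concat()` runs, the concatenation should only have to join the existing chunks.
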
dcherian commented 1 month ago

cc @phofl

phofl commented 1 month ago

Thanks for the ping, will take a look

dcherian commented 1 month ago

Looking at this again, I'm not sure we can do much to choose the "right" chunk sizes here. Each variable is treated independently, so when concatenating `var2` we don't have any context to do anything different.

Also I can't reproduce, the example takes 1.5s on my machine.

phofl commented 1 month ago

Would something like a chunk size hint on the dask side help here? (Note: this might not be a viable suggestion, haven’t had the chance to look yet)
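
Purely as an illustration of what such a hint could mean, here is a hypothetical helper (not an existing xarray or Dask API). It builds the lazy `NaN` filler from the hinted chunks when they cover the required length, and otherwise falls back to a single chunk.

```python
import dask.array as da
import numpy as np


def nan_filler(length, chunk_hint=None):
    """Hypothetical helper: build a lazy, all-NaN 1-D Dask array of ``length``.

    ``chunk_hint`` is a tuple of chunk sizes, e.g. taken from another variable
    along the same dimension. It is used only if it covers exactly ``length``
    elements; otherwise the filler falls back to a single chunk.
    """
    if chunk_hint is not None and sum(chunk_hint) == length:
        chunks = (tuple(chunk_hint),)
    else:
        chunks = ((length,),)
    return da.full((length,), np.nan, dtype=np.float64, chunks=chunks)


# Hint taken from var1 of the second dataset in the report: 5 x 20_000
print(nan_filler(100_000, (20_000,) * 5).chunks)
print(nan_filler(100_000).chunks)  # no hint: a single chunk
```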