pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev

`rolling(...).construct(...)` blows up chunk size #9550

Open hendrikmakait opened 4 days ago

hendrikmakait commented 4 days ago

What happened?

When using `rolling(...).construct(...)` in https://github.com/coiled/benchmarks/pull/1552, I noticed that my Dask workers died from running out of memory because the chunk sizes get blown up.

What did you expect to happen?

Naively, I would expect `rolling(...).construct(...)` to try to keep chunk sizes roughly constant instead of blowing them up quadratically with the window size.

Minimal Complete Verifiable Example

import dask.array as da
import xarray as xr

# Construct a dataset with chunks of size (400, 400, 1), i.e. 1.22 MiB each
ds = xr.Dataset(
    dict(
        foo=(
            ["latitude", "longitude", "time"],
            da.random.random((400, 400, 400), chunks=(-1, -1, 1)),
        ),
    )
)

# The dataset now has chunks of size (400, 400, 100, 100), i.e. 11.92 GiB each
ds = ds.rolling(time=100, center=True).construct("window")
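
For reference, the chunk sizes in the comments above follow directly from the array shapes and the float64 itemsize (a quick back-of-the-envelope check, not part of the original report):

# (400, 400, 1) float64 values per input chunk
400 * 400 * 1 * 8 / 2**20          # ≈ 1.22 MiB
# (400, 400, 100, 100) float64 values per output chunk
400 * 400 * 100 * 100 * 8 / 2**30  # ≈ 11.92 GiB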

MVCE confirmation

Relevant log output

No response

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.12.6 | packaged by conda-forge | (main, Sep 11 2024, 04:55:15) [Clang 17.0.6]
python-bits: 64
OS: Darwin
OS-release: 23.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: None
libnetcdf: None
xarray: 2024.7.0
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.14.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: 1.4.0
dask: 2024.9.0
distributed: 2024.9.0
matplotlib: None
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.1
cupy: None
pint: None
sparse: 0.15.4
flox: 0.9.9
numpy_groupies: 0.11.2
setuptools: 73.0.1
pip: 24.2
conda: 24.7.1
pytest: 8.3.3
mypy: None
IPython: 8.27.0
sphinx: None

dcherian commented 3 days ago

This is using the `sliding_window_view` trick under the hood, which composes badly with anything that does a memory copy (like `weighted` in your example):

https://github.com/dask/dask/blob/d45ea380eb55feac74e8146e8ff7c6261e93b9d7/dask/array/overlap.py#L808

We actually use this approach for `.rolling.mean`, but we are clever about handling memory copies under the hood (https://github.com/pydata/xarray/pull/4915).
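
To illustrate the point with plain NumPy (a minimal sketch with made-up sizes, not code from the benchmark): the windowed view itself is free, but any operation that forces a copy materializes the full windowed array.

import numpy as np

x = np.random.random(1_000_000)  # ~8 MB of data

# The view has apparent shape (999901, 100), ~800 MB if materialized,
# but it shares memory with `x`, so creating it costs nothing.
view = np.lib.stride_tricks.sliding_window_view(x, window_shape=100)
assert np.shares_memory(view, x)

# Anything that copies (here an explicit .copy(); np.nanmean behaves
# similarly because it copies internally) allocates the full ~800 MB.
materialized = view.copy()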

I'm not sure what the right solution here is.

  1. Perhaps dask can automatically rechunk the dimensions that are being slid over? We'd want the new "window" dimension to be singly chunked by default, I think.
  2. On the xarray side, a lot of the pain stems from automatically padding with NaNs in `rolling(...).construct`. This has downstream consequences (`np.nanmean` uses a memory copy, for example), but it is a more complex fix: https://github.com/pydata/xarray/pull/5603
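
In the meantime, here is a sketch of a user-side mitigation (assuming `ds` is the original, pre-`construct` dataset from the example above; this is a workaround, not one of the proposed fixes): since `construct` multiplies each chunk's byte size by the window length, shrinking chunks along the non-rolled dimensions first keeps each windowed chunk manageable.

# Shrink the non-rolled dimensions before constructing the windowed view
ds_small = ds.chunk({"latitude": 50, "longitude": 50})
rolled = ds_small.rolling(time=100, center=True).construct("window")
# chunks are now roughly (50, 50, 100, 100), i.e. ~190 MiB instead of ~12 GiB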

PS: I chatted with @phofl about this at FOSS4G. He has some context.

phofl commented 3 days ago

Yeah, this is definitely on my to-do list. @hendrikmakait and I chatted briefly about this today; there is definitely something we have to do here.

dcherian commented 3 days ago

I support the approach, but it'd be good to see the impact on `ds.rolling().mean()`, which also uses `construct` but is clever about it to avoid the memory blowup.
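
One quick way to see that difference, assuming `ds` is the original (pre-`construct`) dataset from the example above (exact chunk layouts may vary with dask/xarray versions):

# Compare the dask chunk sizes produced by the two code paths
direct = ds.rolling(time=100, center=True).mean()
windowed = ds.rolling(time=100, center=True).construct("window")
print(direct.foo.data.chunksize)    # reduced result: same dimensions as the input, no "window" dimension
print(windowed.foo.data.chunksize)  # (400, 400, 100, 100): the ~12 GiB chunks reported above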

hendrikmakait commented 3 days ago

I also wonder whether, instead of using `rolling().construct().weighted().mean()`, there should just be something like `rolling().weighted().mean()` or `rolling().mean(weights=...)`. From what I understand, the quadratic explosion of the shape and the chunks is not inherent to this computation; we could also solve it akin to a `map_overlap` computation.
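
For context, the `construct`-based pattern being discussed looks roughly like this (a sketch with illustrative uniform weights; `weights` and its values are made up, and `ds` is the original dataset from the example above):

import numpy as np
import xarray as xr

# Illustrative weights for a window of 100
weights = xr.DataArray(np.ones(100) / 100, dims=["window"])

# Today's pattern: materialize the window dimension, then contract it away.
# The intermediate windowed array is where the chunk blowup happens.
weighted_mean = ds.foo.rolling(time=100, center=True).construct("window").dot(weights)

A fused `rolling().mean(weights=...)` could push the weights into an overlap-style computation so the window dimension never has to be materialized, which is what the `map_overlap` comparison above is getting at.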

dcherian commented 2 days ago

Yes, https://github.com/pydata/xarray/issues/3937, but we've struggled to move on that.

construct is a pretty useful escape hatch for custom workloads, so we should optimize for it behaving sanely.