pangeo-data / rechunker

Disk-to-disk chunk transformation for chunked arrays.
https://rechunker.readthedocs.io/
MIT License

Specify chunk plan for all variables via coordinates? #83

Closed rsignell-usgs closed 2 years ago

rsignell-usgs commented 3 years ago

When rechunking with tools like Unidata's nccopy, we can specify the desired chunk sizes for each coordinate, and then all variables that depend on those coordinates will be chunked accordingly, for example:

nccopy -d1 -c time/10,lon/20,lat/30 netcdf.nc netcdf_chunked.nc

Is it possible to accomplish the same with rechunker?

The docs seem to indicate that you need to provide a dictionary that contains every variable that you want to chunk.
I'm currently working on a dataset with 180 data variables!
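For reference, this is roughly what I understand the docs to be describing - a per-variable target_chunks mapping, which gets unwieldy with 180 data variables (variable names below are made up, and ds / target_store / temp_store stand in for the real dataset and store paths):

from rechunker import rechunk

# hypothetical variable names; every data variable would need its own entry
target_chunks = {
    'temp': {'time': 10, 'lon': 20, 'lat': 30},
    'salt': {'time': 10, 'lon': 20, 'lat': 30},
    # ... and so on for the remaining variables
}
rechunked = rechunk(ds, target_chunks, '2GiB', target_store, temp_store=temp_store)

What I'd like is to specify the per-dimension sizes once, nccopy-style.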

jbusecke commented 3 years ago

I have written a wrapper for a project I am working on:

def rechunker_wrapper(source_store, target_store, temp_store, chunks=None, mem="2GiB", consolidated=False, verbose=True):

    # erase target and temp stores
    if temp_store.exists():
        shutil.rmtree(temp_store)

    if target_store.exists():
        shutil.rmtree(target_store)

    if isinstance(source_store, xr.Dataset):
        g = source_store  # trying to work directly with a dataset
        ds_chunk = g
    else:
        g = zarr.group(str(source_store))
        # get the correct shape from loading the store as xr.dataset and parse the chunks
        ds_chunk = xr.open_zarr(str(source_store))

    group_chunks = {}

    for var in ds_chunk.variables:
        # pick appropriate chunks from above, and default to full length chunks for dimensions that are not in `chunks` above.
        group_chunks[var] = []
        for di in ds_chunk[var].dims:
            if di in chunks.keys():
                if chunks[di] > len(ds_chunk[di]):
                    group_chunks[var].append(len(ds_chunk[di]))
                else:
                    group_chunks[var].append(chunks[di])

            else:
                group_chunks[var].append(len(ds_chunk[di]))

        group_chunks[var] = tuple(group_chunks[var])

    rechunked = rechunk(g, group_chunks, mem, target_store, temp_store=temp_store)
    rechunked.execute()

A bunch of stuff in there is not really relevant, but this logic makes it pretty easy to just provide a 'chunk dictionary' (like in xarray).

E.g. you could do: rechunker_wrapper(source_store, target_store, temp_store, chunks={'lon': 4, 'lat': 5}). Perhaps this is something that could be implemented in the rechunk function itself?

rsignell-usgs commented 3 years ago

@jbusecke, this is perfect! I was thinking that I might have to write one of these myself, but you've already done it, and done a better job than I would have! Thank you! (And we should think about adding this to the rechunker repo.)

jbusecke commented 3 years ago

I could try to put in a PR next week. If there are comments on how to modify the above I would appreciate it.

rsignell-usgs commented 3 years ago

@jbusecke, awesome. If you could include the part that handles consolidated=True that would be great.
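(For what it's worth, I'm guessing that part boils down to consolidating the target store's metadata after execute(), roughly like this - but I'd rather use your tested version.)

import zarr

# consolidate the target store's metadata into a single .zmetadata key
zarr.convenience.consolidate_metadata(target_store)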

BTW, in my test case that was failing, I tried looping over all 128 dataset variables trying to find out the culprit:

for var in ds.variables:
    print(var)
    rechunker_wrapper(ds[[var]], target_store, temp_store, mem='3GiB', consolidated=True,
                      chunks={'s_rho': 1, 's_w': 1, 'ocean_time': 180, 'ND': 1, 'Nbed': 1})

And that worked - it died on variable 120, a wave period variable that had units of 'seconds' and was therefore being interpreted as a time interval and converted into timedelta64 instead of float. So I just needed to add decode_timedelta=False in the open_mfdataset call:

with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ds = xr.open_mfdataset(flist, chunks={'ocean_time':1}, decode_timedelta=False, 
                           data_vars="minimal", coords="minimal", compat="override")

and everything worked smoothly writing to zarr, even the scalar vars!

rsignell-usgs commented 3 years ago

@jbusecke, you mentioned you clipped just a part out of that rechunker_wrapper function. Can you point me toward the whole function? Or include it here?

jbusecke commented 3 years ago

This should be the whole thing as a wrapper:

import shutil
import zarr
from rechunker.api import rechunk
import xarray as xr
import pathlib

def rechunker_wrapper(source_store, target_store, temp_store, chunks=None, mem="2GiB", consolidated=False, verbose=True):

    # convert str to paths
    def maybe_convert_to_path(p):
        if isinstance(p, str):
            return pathlib.Path(p)
        else:
            return p

    source_store = maybe_convert_to_path(source_store)
    target_store = maybe_convert_to_path(target_store)
    temp_store = maybe_convert_to_path(temp_store)

    # erase target and temp stores
    if temp_store.exists():
        shutil.rmtree(temp_store)

    if target_store.exists():
        shutil.rmtree(target_store)

    if isinstance(source_store, xr.Dataset):
        g = source_store  # trying to work directly with a dataset
        ds_chunk = g
    else:
        g = zarr.group(str(source_store))
        # get the correct shape from loading the store as xr.dataset and parse the chunks
        ds_chunk = xr.open_zarr(str(source_store))

    # convert all paths to strings
    source_store = str(source_store)
    target_store = str(target_store)
    temp_store = str(temp_store)

    group_chunks = {}
    # newer tuple version that also handles the case where the specified chunks are larger than the array
    for var in ds_chunk.variables:
        # pick appropriate chunks from above, and default to full length chunks for dimensions that are not in `chunks` above.
        group_chunks[var] = []
        for di in ds_chunk[var].dims:
            if di in chunks.keys():
                if chunks[di] > len(ds_chunk[di]):
                    group_chunks[var].append(len(ds_chunk[di]))
                else:
                    group_chunks[var].append(chunks[di])

            else:
                group_chunks[var].append(len(ds_chunk[di]))

        group_chunks[var] = tuple(group_chunks[var])
    if verbose:
        print(f"Rechunking to: {group_chunks}")
    rechunked = rechunk(g, group_chunks, mem, target_store, temp_store=temp_store)
    rechunked.execute()
    if consolidated:
        if verbose:
            print('consolidating metadata')
        zarr.convenience.consolidate_metadata(target_store)
    if verbose:
        print('removing temp store')
    shutil.rmtree(temp_store)
    if verbose:
        print('done')
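A quick usage sketch (the paths and dimension names below are just placeholders):

# hypothetical local paths; any str or pathlib.Path pointing at zarr stores should work
rechunker_wrapper(
    "source.zarr", "target.zarr", "temp.zarr",
    chunks={"time": 180, "lat": 1, "lon": 1},
    mem="2GiB",
    consolidated=True,
)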

rsignell-usgs commented 3 years ago

@jbusecke, tried it with all the options (consolidated=True, verbose=True), and it worked great! The only thing I might suggest is making the execute statement read:

rechunked.execute(retries=10)

to make it clear that retries can be specified. I guess that could be passed in instead of hard-wired.
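Something like this, as a sketch against the wrapper above, with retries exposed as a keyword argument (default 10):

-def rechunker_wrapper(source_store, target_store, temp_store, chunks=None, mem="2GiB", consolidated=False, verbose=True):
+def rechunker_wrapper(source_store, target_store, temp_store, chunks=None, mem="2GiB", consolidated=False, verbose=True, retries=10):

-    rechunked.execute()
+    rechunked.execute(retries=retries)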

jbusecke commented 3 years ago

Cool! Glad it was helpful. Still a bit bogged down with other stuff, but might be able to start a PR tonight. That tip is helpful too!

pmav99 commented 3 years ago

@jbusecke thanks for this, works great!

A minor simplification might be:

    # erase target and temp stores
-    if temp_store.exists():
-        shutil.rmtree(temp_store)
-
-    if target_store.exists():
-        shutil.rmtree(target_store)
+    shutil.rmtree(temp_store, ignore_errors=True)
+    shutil.rmtree(target_store, ignore_errors=True)

which probably removes the need for maybe_convert_to_path(), too.

I also think that chunks should be a mandatory argument. If the default value of None is used, then `if di in chunks.keys()` will raise an AttributeError.
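Alternatively, if the keyword stays, failing fast with a clear message would already help - something along these lines at the top of the wrapper (just a sketch):

    # sketch: fail early with a clear message instead of an AttributeError deep in the loop
    if chunks is None:
        raise ValueError("chunks is required, e.g. chunks={'time': 10, 'lat': 30}")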

pmav99 commented 3 years ago

Extracting the logic to a separate function and writing tests for it:

from collections import defaultdict
from typing import Dict
from typing import Tuple

import numpy as np
import pandas as pd
import pytest
import xarray as xr

def compute_chunks(ds: xr.Dataset, target_chunks: Dict[str, int]) -> Dict[str, Tuple[int]]:
    """ 
    Calculate ``chunks`` suitable for ``rechunker.rechunk()`` using ``target_chunks`` and the actual dimensions from ``ds``.

    - If a dimension is missing from ``target_chunks`` then use the full length from ``ds``.
    - If a chunk in ``target_chunks`` is larger than the full length of the variable in ``ds``, 
      then, again, use the full length from the dataset.

    """
    group_chunks = defaultdict(list)
    for var in ds.variables:
        for dim in ds[var].dims:
            if dim in target_chunks.keys() and target_chunks[dim] <= len(ds[dim]):
                group_chunks[var].append(target_chunks[dim])
            else:
                group_chunks[var].append(len(ds[dim]))

    # rechunk() expects chunks values to be a tuple. So let's convert them
    group_chunks_tuples = {var: tuple(chunks) for (var, chunks) in group_chunks.items()}
    return group_chunks_tuples

@pytest.fixture(scope="session")
def chunk_ds():
    lon = np.arange(-180, 180)
    lat = np.arange(-90, 90)
    timestamps = pd.date_range("2001-01-01", "2001-12-31", name="time", freq="D")
    ds = xr.Dataset(
        data_vars=dict(
            aaa=(
                ["lon", "lat", "time"],
                np.random.randint(0, 101, (len(lon), len(lat), len(timestamps))),
            )
        ),
        coords=dict(
            lon=lon,
            lat=lat,
            time=timestamps,
        ),
    )
    return ds

# fmt: off
@pytest.mark.parametrize(
    "target_chunks,expected",
    [
        pytest.param(dict(lon=10), dict(aaa=(10, 180, 365), lon=(10,), lat=(180,), time=(365,)), id="just lon chunk"),
        pytest.param(dict(lat=10), dict(aaa=(360, 10, 365), lon=(360,), lat=(10,), time=(365,)), id="just lat chunk"),
        pytest.param(dict(time=10), dict(aaa=(360, 180, 10), lon=(360,), lat=(180,), time=(10,)), id="just time chunk"),
        pytest.param(dict(lon=10, lat=10, time=10), dict(aaa=(10, 10, 10), lon=(10,), lat=(10,), time=(10,)), id="all dimensions - equal chunks"),
        pytest.param(dict(lon=10, lat=20, time=30), dict(aaa=(10, 20, 30), lon=(10,), lat=(20,), time=(30,)), id="all dimensions - different chunks"),
        pytest.param(dict(lon=1000), dict(aaa=(360, 180, 365), lon=(360,), lat=(180,), time=(365,)), id="lon chunk greater than size"),
        pytest.param(dict(lat=1000), dict(aaa=(360, 180, 365), lon=(360,), lat=(180,), time=(365,)), id="lat chunk greater than size"),
        pytest.param(dict(time=1000), dict(aaa=(360, 180, 365), lon=(360,), lat=(180,), time=(365,)), id="time chunk greater than size"),
        pytest.param(dict(lon=1000, lat=1000, time=1000), dict(aaa=(360, 180, 365), lon=(360,), lat=(180,), time=(365,)), id="all chunks greater than size"),
    ],
)
# fmt: on
def test_compute_chunks_from_ds(chunk_ds: xr.Dataset, target_chunks, expected) -> None:
    result = compute_chunks(ds=chunk_ds, target_chunks=target_chunks)
    assert expected == result
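And a quick usage check outside pytest, reusing the imports above (the dataset mirrors the fixture):

ds = xr.Dataset(
    {"aaa": (["lon", "lat", "time"], np.zeros((360, 180, 365)))},
    coords={
        "lon": np.arange(-180, 180),
        "lat": np.arange(-90, 90),
        "time": pd.date_range("2001-01-01", "2001-12-31", freq="D"),
    },
)
print(compute_chunks(ds, {"time": 10}))
# -> {'aaa': (360, 180, 10), 'lon': (360,), 'lat': (180,), 'time': (10,)}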

rabernat commented 3 years ago

Thanks for everyone's work here.

We can definitely support this inside rechunker. All the code and even tests appear to be already written. Could someone please start a PR?

jbusecke commented 3 years ago

Is this still outstanding? I recently ran into this issue again, and could take a shot at it.

jbusecke commented 3 years ago

I just took a first shot at implementing this within rechunk. Very keen on feedback.

I also think we should deal with the automatic removal of existing stores in a separate PR.