Deltares / xugrid

Xarray and unstructured grids
https://deltares.github.io/xugrid/
MIT License

Automatically resize chunks after merging parallel results #252

Closed: Huite closed this issue 2 months ago

Huite commented 2 months ago

The current merging of parallel results can carry a lot of chunking overhead. In many cases you will want to merge results and then reindex them to match an original grid topology. If the original chunks of the parallel partitions persist, the reindexing becomes wildly inefficient, because dask will try to respect those chunks.

[image]

It's clear that the reindexed result is shuffled around, and that the original chunks should not be maintained. The only thing dask can do here -- I guess -- is schedule the operation per chunk, and then repeat it very inefficiently.
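The effect is easy to reproduce with plain dask. A minimal, self-contained illustration (the sizes are made up, not taken from the issue):

import dask.array as da
import numpy as np

# Stand-in for a merged result that still carries per-partition chunks:
# 10_000 "faces" in 20 chunks of 500.
merged = da.ones(10_000, chunks=500)

# Reindexing to the original topology is an out-of-order take: each output
# chunk pulls values from many input chunks, so dask chops the result into
# far more chunks and emits a PerformanceWarning.
rng = np.random.default_rng(0)
shuffled = merged[rng.permutation(10_000)]
print(merged.numblocks, "->", shuffled.numblocks)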

The likely default here is to resize the chunks automatically when merging, so that every UGRID dimension ends up in a single chunk:

from collections import ChainMap

# Map each UGRID dimension name to its full size, across all grids.
ugrid_dims = ChainMap(*[grid.dimensions for grid in uds.ugrid.grids])

# Collapse every UGRID dimension into a single chunk spanning the full
# dimension; leave the chunking of other dimensions (e.g. time) as-is.
chunks = dict(uds.chunks)
for key in chunks:
    if key in ugrid_dims:
        chunks[key] = ugrid_dims[key]

rechunked = uds.chunk(chunks)

Almost all subsequent operations will run much more smoothly with a single spatial chunk. I was thinking of adding a merge_ugrid_chunks=True argument, but is there really any case where you would want to merge the grids but not merge the chunks? Maybe if you're only interested in individual cells -- but even then, keeping the partition chunks is unlikely to be more efficient in almost all cases, I reckon.
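For context, a sketch of how that could look as a post-merge step; the merge_ugrid_chunks flag and the wrapper name are hypothetical, not existing xugrid API:

from collections import ChainMap

import xugrid as xu


def merge_partitions_rechunked(partitions, merge_ugrid_chunks=True):
    # Hypothetical wrapper around xu.merge_partitions: after merging,
    # collapse every UGRID dimension into a single chunk.
    merged = xu.merge_partitions(partitions)
    if not merge_ugrid_chunks:
        return merged
    ugrid_dims = ChainMap(*[grid.dimensions for grid in merged.ugrid.grids])
    chunks = {
        dim: ugrid_dims.get(dim, sizes)  # full size -> one chunk
        for dim, sizes in merged.chunks.items()
    }
    return merged.chunk(chunks)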

veenstrajelmer commented 1 month ago

A reproducible example from our experiences with the Verschillentool:

import dfm_tools as dfmt

# First partitioned model output (21 partitions).
file_nc1 = r"P:\dflowfm\projects\2021_verschilanalyse\modellen_2022_2023\dcsm_0.5nm\computations\hist\2013-2017_structure_ini_aangepast\2024.01_DFM_OUTPUT_DCSM-FM_0_5nm_goed/*0*_map.nc"
uds1 = dfmt.open_partitioned_dataset(file_nc1)

# Second partitioned model output to reindex against.
file_nc2 = r"p:/dflowfm/projects/2021_verschilanalyse/modellen_2022_2023/dcsm_0.5nm/computations/hist/2013-2017/202301_DFM_OUTPUT_DCSM-FM_0_5nm/*0*_map.nc"
uds2 = dfmt.open_partitioned_dataset(file_nc2)

# Reindex the first dataset onto the grid topology of the second.
uds1_new = uds1.ugrid.reindex_like(uds2)

This gives the following output; note the PerformanceWarnings about out-of-order indexing:

>> xu.open_dataset() with 21 partition(s): 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 [nodomainfname] : 52.73 sec
>> xu.merge_partitions() with 21 partition(s): 3.83 sec
>> dfmt.open_partitioned_dataset() total: 56.68 sec
>> xu.open_dataset() with 21 partition(s): 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 [nodomainfname] : 8.88 sec
>> xu.merge_partitions() with 21 partition(s): 3.73 sec
>> dfmt.open_partitioned_dataset() total: 12.65 sec
C:\Users\veenstra\Anaconda3\envs\dfm_tools_env\Lib\site-packages\xarray\core\indexing.py:1620: PerformanceWarning: Slicing with an out-of-order index is generating 480 times more chunks
  return self.array[key]
C:\Users\veenstra\Anaconda3\envs\dfm_tools_env\Lib\site-packages\xarray\core\indexing.py:1620: PerformanceWarning: Slicing with an out-of-order index is generating 442 times more chunks
  return self.array[key]
C:\Users\veenstra\Anaconda3\envs\dfm_tools_env\Lib\site-packages\xarray\core\indexing.py:1620: PerformanceWarning: Slicing with an out-of-order index is generating 564 times more chunks
  return self.array[key]
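Rechunking the merged datasets to a single chunk per UGRID dimension before the reindex, per the snippet above, avoids the chunk explosion. A sketch of the workaround (the helper name is made up):

from collections import ChainMap


def single_spatial_chunk(uds):
    # Huite's snippet above as a helper: one chunk per UGRID dimension.
    ugrid_dims = ChainMap(*[grid.dimensions for grid in uds.ugrid.grids])
    chunks = {
        dim: ugrid_dims.get(dim, chunk)
        for dim, chunk in uds.chunks.items()
    }
    return uds.chunk(chunks)


uds1_new = single_spatial_chunk(uds1).ugrid.reindex_like(uds2)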