Deltares / xugrid

Xarray and unstructured grids
https://deltares.github.io/xugrid/
MIT License

prevent inconsistent chunks error in `xu.merge_partitions()` #278

Closed · veenstrajelmer closed 3 months ago

veenstrajelmer commented 3 months ago

New code was added in https://github.com/Deltares/xugrid/pull/253 to merge partition chunks. This is really helpful, but I noticed that in some cases it raises an additional error. The error occurs already during merging, because `merged.chunks` is called inside the `xu.merge_partitions()` function. The example below raises "ValueError: Object has inconsistent chunks along dimension mesh2d_nEdges. This can be fixed by calling unify_chunks()."
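For context, this behaviour can be reproduced with plain xarray, independent of xugrid: accessing `.chunks` on a dataset whose variables are chunked differently along a shared dimension raises exactly this error. A minimal sketch (illustrative only, not derived from the model files below):

import numpy as np
import xarray as xr

# two variables share dimension "x" but are chunked differently
a = xr.DataArray(np.zeros(10), dims="x").chunk({"x": 5})
b = xr.DataArray(np.zeros(10), dims="x").chunk({"x": 4})
ds = xr.Dataset({"a": a, "b": b})

try:
    ds.chunks  # accessing .chunks on an inconsistently chunked dataset raises
except ValueError as e:
    print(e)  # "Object has inconsistent chunks along dimension x. ..."

print(ds.unify_chunks().chunks)  # after unify_chunks() the access succeeds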

An example:

import os
import glob
import datetime as dt
import xugrid as xu

# merging with these chunks raises "ValueError: Object has inconsistent chunks
# along dimension mesh2d_nEdges. This can be fixed by calling unify_chunks()."
chunks = {'time': 1}
# chunks = 'auto'  # merging works fine

dir_model = r"p:\11210284-011-nose-c-cycling\runs_fine_grid\B05_waq_2012_PCO2_ChlC_NPCratios_DenWat_stats_2023.01\B05_waq_2012_PCO2_ChlC_NPCratios_DenWat_stats_2023.01\DFM_OUTPUT_DCSM-FM_0_5nm_waq"
file_nc = os.path.join(dir_model, "DCSM-FM_0_5nm_waq_000*_map.nc")

file_nc_list = glob.glob(file_nc)
file_nc_list = file_nc_list[:2]  # two partitions are enough to reproduce

partitions = []
for file_nc_one in file_nc_list:
    uds_one = xu.open_dataset(file_nc_one, chunks=chunks)
    partitions.append(uds_one)

print(f'>> xu.merge_partitions() with {len(file_nc_list)} partition(s): ', end='')
dtstart = dt.datetime.now()
uds = xu.merge_partitions(partitions)
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')

I suggest adding `merged = merged.unify_chunks()` to `xu.merge_partitions()` so this error is avoided.

I am not sure whether there are downsides to always calling `unify_chunks()`. There seems to be no property on a merged dataset that indicates whether its chunks are inconsistent, although such a check could be constructed by hand (see the sketch after this paragraph). It would also be possible to call `unify_chunks()` on each separate partition, but that does not seem to make sense to me, since the issue is with the merged result. Since no reproducible example was included in https://github.com/Deltares/xugrid/issues/25 or any of the related issues, I added one in https://github.com/Deltares/xugrid/issues/252#issuecomment-2288080008. The files from that example were used to test my implementation in the testcase below, and unfortunately it gives a warning.
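As an illustration of such a check, a hypothetical helper (not an existing xugrid or xarray API) could compare the per-variable chunk layouts directly:

import numpy as np
import xarray as xr

def has_inconsistent_chunks(ds):
    # hypothetical helper: collect the chunk layout of every variable per
    # dimension and flag any dimension on which two variables disagree
    seen = {}
    for var in ds.variables.values():
        for dim, chunks in var.chunksizes.items():
            if seen.setdefault(dim, chunks) != chunks:
                return True
    return False

# demo on a deliberately inconsistent dataset
a = xr.DataArray(np.zeros(10), dims="x").chunk({"x": 5})
b = xr.DataArray(np.zeros(10), dims="x").chunk({"x": 4})
ds = xr.Dataset({"a": a, "b": b})
print(has_inconsistent_chunks(ds))                 # True
print(has_inconsistent_chunks(ds.unify_chunks()))  # False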

Testcase:

import glob
import datetime as dt
import xugrid as xu

file_nc = r"P:\dflowfm\projects\2021_verschilanalyse\modellen_2022_2023\dcsm_0.5nm\computations\hist\2013-2017_structure_ini_aangepast\2024.01_DFM_OUTPUT_DCSM-FM_0_5nm_goed/*0*_map.nc"
file_nc_list = glob.glob(file_nc)
chunks = {"time": 1}  # the chunks argument used in dfm_tools.open_partitioned_dataset()

print(f'>> xu.open_dataset() with {len(file_nc_list)} partition(s): ', end='')
dtstart = dt.datetime.now()
partitions = []
for ifile, file_one in enumerate(file_nc_list):
    print(ifile + 1, end=' ')
    uds_one = xu.open_dataset(file_one, chunks=chunks)
    partitions.append(uds_one)
print(': ', end='')
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')

print(f'>> xu.merge_partitions() with {len(file_nc_list)} partition(s): ', end='')
dtstart = dt.datetime.now()
uds = xu.merge_partitions(partitions)
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')

Output:

>> xu.open_dataset() with 21 partition(s): 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 : 36.68 sec
>> xu.merge_partitions() with 21 partition(s): 3.78 sec
C:\Users\veenstra\Anaconda3\envs\dfm_tools_env\Lib\site-packages\xarray\core\computation.py:2303: PerformanceWarning: Increasing number of chunks by factor of 400
  _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args)
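
The factor-of-400 warning reflects how `unify_chunks()` reconciles mismatched layouts: the common chunking uses the union of all chunk boundaries per dimension, so unifying variables that are chunked along different dimensions multiplies the total number of chunks. A small illustration with plain xarray (assumed mechanics, independent of these model files):

import numpy as np
import xarray as xr

# "a" is chunked along x only, "b" along y only; unifying the two layouts
# chunks both variables along both dimensions, multiplying the chunk count
a = xr.DataArray(np.zeros((100, 100)), dims=("x", "y")).chunk({"x": 5})
b = xr.DataArray(np.zeros((100, 100)), dims=("x", "y")).chunk({"y": 5})
ds = xr.Dataset({"a": a, "b": b}).unify_chunks()  # warns: chunk count grows 20x
print(ds.chunks)  # both x and y now have 20 chunks of size 5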

It seems that in order to call `unify_chunks()` one first has to optimize the chunks, but if you try to optimize inconsistent chunks you get the error reported in https://github.com/Deltares/xugrid/issues/278, for which this PR was created. Is it possible to get out of this inconvenient loop?
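One possible way out, sketched here as a hypothetical escape hatch rather than a tested fix, would be to call `unify_chunks()` only when the merged result is actually inconsistent (placed inside `xu.merge_partitions()` before `merged.chunks` is accessed), so that consistent results keep their layout and the PerformanceWarning is confined to the cases that need it:

def unify_if_needed(ds):
    # hypothetical helper: unify only when accessing .chunks would fail
    try:
        ds.chunks  # raises ValueError when chunks are inconsistent
    except ValueError:
        ds = ds.unify_chunks()
    return ds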