Deltares / xugrid

Xarray and unstructured grids
https://deltares.github.io/xugrid/
MIT License
Improve performance of `_get_topology()` by not accessing dataArray each time #285

Closed veenstrajelmer closed 3 weeks ago

veenstrajelmer commented 3 weeks ago

_get_topology loops over all data_vars: https://github.com/Deltares/xugrid/blob/3dee693763da1c4c0859a4f53ac38d4b99613a33/xugrid/ugrid/conventions.py#L183-L184

It seems that iterating via `ds.variables.items()` instead of `ds[var]` avoids accessing a DataArray on each iteration, which saves a lot of time when there are many variables. The original method profiles like this: (image: profiler output for the original method)

When replacing the `_get_topology()` code with `[k for k, var in ds.variables.items() if var.attrs.get("cf_role") == "mesh_topology"]` or `[k for k in ds.data_vars if ds.variables[k].attrs.get("cf_role") == "mesh_topology"]` (so only adding `.variables`), the profiler looks like this: (image: profiler output after the change)

So the timing drops from 16 seconds to under 1 second in an example with 5 partitions. This should give a large improvement when using all 256 partitions of the dataset. Note that this case covers a dataset with 2410 variables, so the change will mostly benefit datasets with many variables. Some code to reproduce:

import os
import glob
import xugrid as xu
import xarray as xr
import datetime as dt

dir_model = r"p:\11210284-011-nose-c-cycling\runs_fine_grid\B05_waq_2012_PCO2_ChlC_NPCratios_DenWat_stats_2023.01\B05_waq_2012_PCO2_ChlC_NPCratios_DenWat_stats_2023.01\DFM_OUTPUT_DCSM-FM_0_5nm_waq"
file_nc_pat = os.path.join(dir_model, "DCSM-FM_0_5nm_waq_0*_map.nc")
file_nc_list_all = glob.glob(file_nc_pat)
file_nc_list = file_nc_list_all[:5]

print(f'>> xu.open_dataset() with {len(file_nc_list)} partition(s): ',end='')
dtstart = dt.datetime.now()
partitions = []
for iF, file_nc_one in enumerate(file_nc_list):
    print(iF+1,end=' ')
    ds_one = xr.open_mfdataset(file_nc_one, chunks="auto")
    uds_one = xu.core.wrap.UgridDataset(ds_one)
    partitions.append(uds_one)
print(': ',end='')
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')
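
The reproduction above depends on model output on a network drive. As a rough stand-in, the sketch below builds a synthetic `xarray.Dataset` and compares the lookup variants from this issue. It assumes the original loop has the form `ds[k].attrs.get("cf_role")`; the helper names (`make_dataset`, `topology_via_dataarray`, `topology_via_variables`, `topology_via_items`, `timed`) and the variable names (`var_0`, `mesh2d`) are hypothetical and not part of xugrid. Absolute timings will differ from the profiles above, since the synthetic dataset is small and in memory.

```python
import time

import numpy as np
import xarray as xr


def make_dataset(n_vars=500):
    """Build a synthetic dataset with many data variables and one topology variable."""
    data_vars = {
        f"var_{i}": (("time", "face"), np.zeros((2, 3)))
        for i in range(n_vars)
    }
    # Scalar dummy variable carrying the UGRID topology attribute.
    data_vars["mesh2d"] = xr.Variable((), 0, {"cf_role": "mesh_topology"})
    return xr.Dataset(data_vars, coords={"time": [0, 1]})


def topology_via_dataarray(ds):
    # ds[k] constructs a DataArray (including its coordinates) on every iteration.
    return [k for k in ds.data_vars if ds[k].attrs.get("cf_role") == "mesh_topology"]


def topology_via_variables(ds):
    # ds.variables[k] returns the underlying Variable without building a DataArray.
    return [
        k for k in ds.data_vars
        if ds.variables[k].attrs.get("cf_role") == "mesh_topology"
    ]


def topology_via_items(ds):
    # Iterates all variables, coordinate variables included.
    return [
        k for k, var in ds.variables.items()
        if var.attrs.get("cf_role") == "mesh_topology"
    ]


def timed(func, ds):
    """Return the function result and the elapsed wall-clock time in seconds."""
    start = time.perf_counter()
    result = func(ds)
    return result, time.perf_counter() - start


ds = make_dataset()
result_dataarray, time_dataarray = timed(topology_via_dataarray, ds)
result_variables, time_variables = timed(topology_via_variables, ds)
result_items, time_items = timed(topology_via_items, ds)

print(f"ds[k]:                {time_dataarray:.4f} sec -> {result_dataarray}")
print(f"ds.variables[k]:      {time_variables:.4f} sec -> {result_variables}")
print(f"ds.variables.items(): {time_items:.4f} sec -> {result_items}")
```

One difference between the two proposed replacements: `ds.variables.items()` also visits coordinate variables, whereas iterating `ds.data_vars` does not. The results only match as long as no coordinate carries `cf_role == "mesh_topology"`.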