xarray-contrib / datatree

WIP implementation of a tree-like hierarchical data structure for xarray.
https://xarray-datatree.readthedocs.io
Apache License 2.0
169 stars · 44 forks · source link

Add `dask.delayed` to `map_over_subtree` (#253)

Status: Open — Illviljan opened this pull request 1 year ago

Illviljan commented 1 year ago

Timing on `main`:

%timeit dt.interp(time=new_time)
49.9 s ± 1.3 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

Timing with this PR (about 3× faster):

%timeit dt.interp(time=new_time)
16.7 s ± 297 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```python from functools import partial import time import numpy as np import xarray as xr from datatree import DataTree import datatree import dask import matplotlib.pyplot as plt number_of_files = 25 number_of_groups = 20 number_of_variables = 2000 def create_datatree(number_of_files, number_of_groups, number_of_variables): datasets = {} for f in range(number_of_files): for g in range(number_of_groups): # Create random data: time = np.linspace(0, 50 + f, 100 + g) y = f * time + g # Create dataset: ds = xr.Dataset( data_vars={ f"temperature_{g}{i}": ("time", y) for i in range(number_of_variables // number_of_groups) }, coords={"time": ("time", time)}, ).chunk() # Prepare for Datatree: name = f"file_{f}/group_{g}" datasets[name] = ds dt = DataTree.from_dict(datasets) return dt # %% Interpolate to same time coordinate dt = create_datatree(number_of_files, number_of_groups, number_of_variables) new_time = np.linspace(0, 150, 50) datatree.mapping._map_over_subtree_kwargs.update(parallel=True) dt_interp = dt.interp(time=new_time) # %timeit dt.interp(time=new_time) # %% Usages of mab_over_subtree # Built in: datatree.mapping._map_over_subtree_kwargs.update(parallel=True) dt_interp = dt.interp(time=new_time) # Decorator; @partial(datatree.map_over_subtree, parallel=True) def mean(ds): return ds.mean() mean(dt) # Function: datatree.map_over_subtree(np.mean, parallel=True)(dt) # %% Time a file sweep new_time = np.linspace(0, 150, 50) def time_many_files(n=50, step=5): times = {} for f in range(1, n, step): dt = create_datatree(f, number_of_groups, number_of_variables) start = time.time() dt_interp = dt.interp(time=new_time) end = time.time() diff = end - start print(f"{f} files took {diff:0.5} seconds.") times[f] = end - start return times print("Sequential:") datatree.mapping._map_over_subtree_kwargs.update(parallel=False) times_seq = time_many_files() print("Parallel:") datatree.mapping._map_over_subtree_kwargs.update(parallel=True) times_par = time_many_files() 
plt.figure() fig, ax = plt.subplots(1, 1) ax.plot(list(times_seq.keys()), list(times_seq.values()), label="Sequential") ax.plot(list(times_par.keys()), list(times_par.values()), label="Parallel") ax.set_title( ( "Time to interpolate datatree\n" f"Each file has {number_of_variables} variables and {number_of_groups} groups" ) ) ax.set_ylabel("Time [s]") ax.set_xlabel("Number of files") ax.legend() ```