xarray-contrib / datatree

WIP implementation of a tree-like hierarchical data structure for xarray.
https://xarray-datatree.readthedocs.io
Apache License 2.0

Parallelize map_over_subtree #252

Open Illviljan opened 11 months ago

Illviljan commented 11 months ago

I think there are some good opportunities to run map_over_subtree in parallel using dask.delayed.

Consider this example data:

import numpy as np
import xarray as xr
from datatree import DataTree

number_of_files = 25
number_of_groups = 20
number_of_variables = 2000

datasets = {}
for f in range(number_of_files):
    for g in range(number_of_groups):
        # Create random data:
        time = np.linspace(0, 50 + f, 100 + g)
        y = f * time + g

        # Create dataset:
        ds = xr.Dataset(
            data_vars={
                f"temperature_{g}{i}": ("time", y)
                for i in range(number_of_variables // number_of_groups)
            },
            coords={"time": ("time", time)},
        )  # .chunk()

        # Prepare for Datatree:
        name = f"file_{f}/group_{g}"
        datasets[name] = ds

dt = DataTree.from_dict(datasets)

# %% Interpolate to same time coordinate
new_time = np.linspace(0, 150, 50)
dt_interp = dt.interp(time=new_time)  
# Original 10s, with dask.delayed 6s
# If datasets were chunked: Original 34s, with dask.delayed 10s

Here's my modded map_over_subtree:

```python
def map_over_subtree(func: Callable) -> Callable:
    """
    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

    The function will be applied to any non-empty dataset stored in any of the nodes in the trees. The returned trees
    will have the same structure as the supplied trees.

    `func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after
    mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any
    returned value that is one of these types will be stacked into a separate tree before returning all of them.

    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
    similarly, but all the output trees will have nodes named in the same way as the first tree passed.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset.)
        Function will not be applied to any nodes without datasets.
    *args : tuple, optional
        Positional arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets
        via .ds .
    **kwargs : Any
        Keyword arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets
        via .ds .

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
        each node.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring

    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        from .datatree import DataTree

        parallel = True
        if parallel:
            import dask

            func_ = dask.delayed(func)
        else:
            func_ = func

        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]

        if len(all_tree_inputs) > 0:
            first_tree, *other_trees = all_tree_inputs
        else:
            raise TypeError("Must pass at least one tree object")

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
            check_isomorphic(
                first_tree,
                other_tree,
                require_names_equal=False,
                check_from_root=False,
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
        # We don't know which arguments are DataTrees so we zip all arguments together as iterables
        # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
        out_data_objects = {}
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        n_args = len(args_as_tree_length_iterables)
        kwargs_as_tree_length_iterables = {
            k: v.subtree if isinstance(v, DataTree) else repeat(v)
            for k, v in kwargs.items()
        }
        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *list(kwargs_as_tree_length_iterables.values()),
        ):
            node_args_as_datasets = [
                a.to_dataset() if isinstance(a, DataTree) else a
                for a in all_node_args[:n_args]
            ]
            node_kwargs_as_datasets = dict(
                zip(
                    [k for k in kwargs_as_tree_length_iterables.keys()],
                    [
                        v.to_dataset() if isinstance(v, DataTree) else v
                        for v in all_node_args[n_args:]
                    ],
                )
            )

            # Now we can call func on the data in this particular set of corresponding nodes
            results = (
                func_(*node_args_as_datasets, **node_kwargs_as_datasets)
                if not node_of_first_tree.is_empty
                else None
            )

            # TODO implement mapping over multiple trees in-place using if conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

        if parallel:
            keys, values = dask.compute(
                [k for k in out_data_objects.keys()],
                [v for v in out_data_objects.values()],
            )
            out_data_objects = {k: v for k, v in zip(keys, values)}

        # Find out how many return values we received
        num_return_values = _check_all_return_values(out_data_objects)

        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
        original_root_path = first_tree.path
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for n in first_tree.subtree:
                p = n.path
                if p in out_data_objects.keys():
                    if isinstance(out_data_objects[p], tuple):
                        output_node_data = out_data_objects[p][i]
                    else:
                        output_node_data = out_data_objects[p]
                else:
                    output_node_data = None

                # Discard parentage so that new trees don't include parents of input nodes
                relative_path = str(NodePath(p).relative_to(original_root_path))
                relative_path = "/" if relative_path == "." else relative_path
                out_tree_contents[relative_path] = output_node_data

            new_tree = DataTree.from_dict(
                out_tree_contents,
                name=first_tree.name,
            )
            result_trees.append(new_tree)

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        else:
            return tuple(result_trees)

    return _map_over_subtree
```

I'm a little unsure how to get the parallel-argument down to map_over_subtree though?

TomNicholas commented 11 months ago

Good idea @Illviljan !

> I'm a little unsure how to get the parallel-argument down to map_over_subtree though?

Do you actually need to pass it through at all? Couldn't you just do this:

def map_over_subtree(func: Callable, parallel=False) -> Callable:
    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        from .datatree import DataTree

        if parallel:
            import dask

or ideally just do this optimization automatically (if dask is installed I guess)?
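As a rough illustration of both options above (a `parallel` argument fixed at decoration time, and automatic use of dask when it is installed), here is a minimal sketch. It is not datatree's implementation: `map_over_nodes` and `HAS_DASK` are hypothetical names, and a flat dict of path -> data stands in for a DataTree. Only `dask.delayed` and `dask.compute` are real dask calls.

```python
import functools
from typing import Any, Callable, Dict, Optional

try:
    import dask

    HAS_DASK = True
except ImportError:  # dask is optional in this sketch
    dask = None
    HAS_DASK = False


def map_over_nodes(func: Optional[Callable] = None, *, parallel: Optional[bool] = None) -> Callable:
    """Hypothetical stand-in for map_over_subtree, acting on a flat dict of nodes.

    parallel=None means "use dask if it is installed"; True/False force the choice.
    """

    def decorator(f: Callable) -> Callable:
        use_dask = HAS_DASK if parallel is None else parallel
        if use_dask and not HAS_DASK:
            raise ImportError("parallel=True requires dask to be installed")

        @functools.wraps(f)
        def wrapper(nodes: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
            if use_dask:
                # Build one lazy task per node, then compute them all in one call.
                lazy = {
                    path: dask.delayed(f)(data, *args, **kwargs)
                    for path, data in nodes.items()
                }
                (results,) = dask.compute(lazy)
                return results
            # Serial fallback: call f eagerly on every node.
            return {path: f(data, *args, **kwargs) for path, data in nodes.items()}

        return wrapper

    # Support both @map_over_nodes and @map_over_nodes(parallel=...).
    if func is not None:
        return decorator(func)
    return decorator
```

Because `parallel` is bound when the function is decorated, it never travels through `**kwargs` to the wrapped function. The trade-off is that callers cannot switch it per call.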


I'm wondering how xarray normally does this optimization when you apply an operation to every data variable in a Dataset, for instance. Is it related to #196?

Illviljan commented 11 months ago

I tried a version with parallel as an argument but it isn't passed correctly via the normal methods: dt.interp(time=new_time, parallel=True) errors because it thinks parallel is a coordinate.
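The failure mode described here can be reproduced without xarray. The sketch below uses a hypothetical `interp` stand-in that, like the behaviour reported above, treats every keyword as a coordinate name. The wrapper names `forwards_everything` and `strips_parallel` are also made up for illustration. One possible workaround, shown second, is to have the wrapper consume `parallel` before forwarding the remaining keywords.

```python
import functools


def interp(data, **coords):
    """Stand-in for an interp-like method: every keyword is read as a coordinate."""
    return sorted(coords)


def forwards_everything(func):
    """Wrapper that passes all keywords straight through to func."""

    @functools.wraps(func)
    def wrapper(data, **kwargs):
        return func(data, **kwargs)

    return wrapper


def strips_parallel(func):
    """Wrapper that consumes `parallel` itself and forwards only the rest."""

    @functools.wraps(func)
    def wrapper(data, *, parallel=False, **kwargs):
        # `parallel` would select the execution strategy here; it is not forwarded.
        return func(data, **kwargs)

    return wrapper


leaky = forwards_everything(interp)(None, time=[0.0, 1.0], parallel=True)
clean = strips_parallel(interp)(None, time=[0.0, 1.0], parallel=True)
print(leaky)  # "parallel" shows up among the coordinate names
print(clean)  # only "time" reaches interp
```

The cost of this approach is that `parallel` becomes a reserved keyword, so a coordinate that happens to be named `parallel` could no longer be passed this way.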

Maybe we could always use this optimization. Dask usually adds some overhead though, and I just haven't played around enough to know where that threshold is or if it is significant.

> I'm wondering how xarray normally does this optimization when you apply an operation to every data variable in a Dataset, for instance. Is it related to #196?

I think the only place this trick is used is xr.open_mfdataset. Not sure why though, maybe most xarray methods predate dask.delayed? I also have a feeling that my datasets with 2000+ variables are not the normal setup for most xarray users, so there's probably not been a need to optimize in the variable direction.

I don't fully understand all the changes in #196, but I see that one as being about triggering computation of all the dask arrays inside the DataArrays. My suggestion is earlier in that chain: setting up those chunked DataArrays in parallel.

TomNicholas commented 11 months ago

You have real datasets with 2000+ variables?!?

Now that I understand that this is not about triggering computation of dask arrays but about building the dask arrays in parallel, I'm less sure that this is a good idea.

I guess one way to look at it is through consistency: DataTree.map_over_subtree is very much a generalization of xarray's Dataset.map, just mapping over nested dictionaries of data variables instead of a single-level dict of data variables. As such I think that we should be consistent in how we treat these two implementations: either it makes sense to apply this optimization in both Dataset.map and DataTree.map_over_subtree, or in neither of them, because it's out-of-scope/too much overhead in both cases.
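The analogy can be made concrete with plain dicts. This is only a toy sketch: `map_flat` and `map_nested` are hypothetical helpers, not xarray or datatree functions, and the real methods operate on Datasets rather than bare values.

```python
from typing import Any, Callable, Dict


def map_flat(func: Callable[[Any], Any], variables: Dict[str, Any]) -> Dict[str, Any]:
    """Apply func to every value of a single-level dict (the Dataset.map shape)."""
    return {name: func(value) for name, value in variables.items()}


def map_nested(func: Callable[[Any], Any], tree: Dict[str, Any]) -> Dict[str, Any]:
    """Apply func to every leaf of a nested dict (the map_over_subtree shape)."""
    return {
        name: map_nested(func, value) if isinstance(value, dict) else func(value)
        for name, value in tree.items()
    }


flat = map_flat(abs, {"a": -1, "b": 2})
nested = map_nested(abs, {"group": {"a": -1}, "b": -2})
```

Seen this way, any parallelization strategy for the nested case applies equally to the flat one, which is the consistency argument made above.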

Illviljan commented 11 months ago

Yes, the example code is quite realistic. Those are my kind of datasets, and there's still always something missing...

Dataset.map looks very lightweight compared to Dataset.interp, and DataTree.map_over_subtree handles both. Some functions are heavier and need to be treated differently, and therefore it's good to have the option of parallelization.

TomNicholas commented 11 months ago

> Dataset.map looks very lightweight compared to Dataset.interp and DataTree.map_over_subtree handles both.

Are you saying that we already do some parallelization like this within Dataset.interp?

We discussed this briefly in the xarray dev call today. Stephan had a few comments, chiefly that he would be surprised if this gave significant speedup in most cases because of restrictions imposed by the GIL. Possibly once Python removes the GIL we might want to revisit this question for all of xarray.
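To check the GIL concern on a given machine, a small timing harness like the one below can help. It is a sketch using only the standard library: `busy_sum` and `timed` are hypothetical helpers, and the pure-Python loop is a deliberately CPU-bound workload that holds the GIL, not a model of what `Dataset.interp` does internally.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def busy_sum(n: int) -> int:
    """CPU-bound pure-Python loop; it holds the GIL while it runs."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def timed(label, fn):
    """Run fn once, print the wall-clock time, and return (result, seconds)."""
    start = time.perf_counter()
    out = fn()
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.3f} s")
    return out, elapsed


workloads = [200_000] * 8

# Same work done serially and on a thread pool.
serial, t_serial = timed("serial", lambda: [busy_sum(n) for n in workloads])
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded, t_threaded = timed("threads", lambda: list(pool.map(busy_sum, workloads)))
```

On a standard CPython build the two timings are usually close for a workload like this, since only one thread executes Python bytecode at a time. Code that releases the GIL, as many NumPy operations do, can behave differently, so the result depends on where the mapped function spends its time.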