pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

datatree: Tree-aware dataset handling/selection #9346

Open keewis opened 3 months ago

What is your issue?

I'm looking for a good way to apply a function to a subset of nodes that share some common characteristics encoded in the subtree path.

Imagine the following data tree:

import xarray as xr
import datatree
from datatree import map_over_subtree

# Two experiments (control, plus4K), each with a low- and a high-resolution dataset
dt = datatree.DataTree.from_dict({
    'control/lowRes': xr.Dataset({'z': ('x', [0, 1, 2])}),
    'control/highRes': xr.Dataset({'z': ('x', [0, 1, 2, 3, 4, 5])}),
    'plus4K/lowRes': xr.Dataset({'z': ('x', [0, 1, 2])}),
    'plus4K/highRes': xr.Dataset({'z': ('x', [0, 1, 2, 3, 4, 5])}),
})

Applying a function to all control or all plus4K nodes is straightforward: just select the specific subtree, e.g. dt['control']. However, if all lowRes datasets should be manipulated, this becomes more elaborate, and I wonder what the best approach would be.

* `dt['control/lowRes','plus4K/lowRes']` is not yet implemented and would also be complex for large data trees

* `dt['*/lowRes']` could be one idea to make the subtree selection more straightforward, where `*` is a wildcard

* `dt.search(regex)` could make this even more general
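
To make the wildcard and regex ideas above concrete, here is a minimal sketch of path-based selection. It deliberately does not use datatree: the tree is stood in for by a plain dict mapping node paths to values, and `select_glob` and `select_regex` are hypothetical helper names, not existing API.

```python
import re
from fnmatch import fnmatchcase

# Stand-in for the tree (an assumption for illustration):
# a flat mapping of node path -> values, mirroring the from_dict input.
tree = {
    "control/lowRes": [0, 1, 2],
    "control/highRes": [0, 1, 2, 3, 4, 5],
    "plus4K/lowRes": [0, 1, 2],
    "plus4K/highRes": [0, 1, 2, 3, 4, 5],
}


def select_glob(nodes, pattern):
    """Keep nodes whose full path matches a shell-style wildcard pattern."""
    return {path: data for path, data in nodes.items() if fnmatchcase(path, pattern)}


def select_regex(nodes, regex):
    """Keep nodes whose path contains a match for the regular expression."""
    compiled = re.compile(regex)
    return {path: data for path, data in nodes.items() if compiled.search(path)}


low_res = select_glob(tree, "*/lowRes")   # all lowRes nodes, any experiment
plus4k = select_regex(tree, r"^plus4K/")  # all nodes below plus4K
```

Note that with `fnmatchcase` the `*` wildcard also matches across `/`, so `*/lowRes` would select lowRes nodes at any depth; a real implementation would have to decide whether that is desired.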

Currently, I use the @map_over_subtree decorator, which has a limitation of its own: the function does not know where in the tree its dataset comes from (as noted in the code), so the origin has to be inferred from the dataset itself. That is sometimes possible (here, from the length of the dataset), but not always.

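@map_over_subtree
def resolution_specific_func(ds):
    # The node's position in the tree is unknown here, so the resolution
    # is guessed from the length of the x dimension.
    if len(ds.x) == 3:      # lowRes
        ds = ds.z * 2
    elif len(ds.x) == 6:    # highRes
        ds = ds.z * 4
    return ds

z = resolution_specific_func(dt)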

I do not know how the tree information could be passed through the decorator, but maybe it would be acceptable for the DatasetView class to have an additional property (e.g. _path) that is filled with dt.path during the call to DatasetView._from_node(). This would lead to

@map_over_subtree
def resolution_specific_func(ds):
    # ds._path is the proposed property; it does not exist yet
    if 'lowRes' in ds._path:
        ds = ds.z * 2
    elif 'highRes' in ds._path:
        ds = ds.z * 4
    return ds

and would allow for tree-aware manipulation of the datasets.
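
As a rough illustration of what tree-aware mapping buys, the sketch below passes each node's path to the mapped function explicitly. It again uses a plain dict as a stand-in for the tree, and `map_with_path` is a hypothetical helper, not part of datatree or xarray.

```python
# Stand-in for the tree (an assumption for illustration):
# a flat mapping of node path -> values.
tree = {
    "control/lowRes": [0, 1, 2],
    "control/highRes": [0, 1, 2, 3, 4, 5],
    "plus4K/lowRes": [0, 1, 2],
    "plus4K/highRes": [0, 1, 2, 3, 4, 5],
}


def map_with_path(func, nodes):
    """Apply func(path, data) to every node, so func can see its tree origin."""
    return {path: func(path, data) for path, data in nodes.items()}


def resolution_specific_func(path, values):
    # Decide on the last path component instead of inspecting the data.
    leaf = path.rsplit("/", 1)[-1]
    if leaf == "lowRes":
        return [v * 2 for v in values]
    if leaf == "highRes":
        return [v * 4 for v in values]
    return values  # leave any other node unchanged


result = map_with_path(resolution_specific_func, tree)
```

Because the decision is based on the path rather than on the data, two nodes with identically shaped datasets can still be treated differently, which the length-based check cannot do.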

What do you think? Happy to open a PR if this makes sense.

Originally posted by @observingClouds in https://github.com/xarray-contrib/datatree/issues/254#issue-1835784457