xarray-contrib / datatree

WIP implementation of a tree-like hierarchical data structure for xarray.
https://xarray-datatree.readthedocs.io
Apache License 2.0

API for filtering / subsetting #79

TomNicholas closed this issue 3 months ago.

TomNicholas commented 2 years ago

So far we've only really implemented dictionary-like get/setitem syntax, but we should add a variety of other ways to select nodes from a tree too. Here are some suggestions:

```python
from __future__ import annotations

from typing import Callable, Iterator, Sequence

from xarray import DataArray


class DataTree:
    ...

    def __getitem__(self, key: str) -> DataTree | DataArray:
        """
        Accepts node/variable names, or file-like paths to nodes/variables (including '../var').

        (Also needs to accommodate indexing somehow.)
        """
        ...

    def subset(self, keys: Sequence[str]) -> DataTree:
        """
        Return new tree containing only nodes with names matching keys.

        (Could probably be combined with `__getitem__`.
        Also unsure what the return type should be.)
        """
        ...

    @property
    def subtree(self) -> Iterator[DataTree]:
        """An iterator over all nodes in this tree, including both self and all descendants."""
        ...

    def filter(self, filterfunc: Callable) -> Iterator[DataTree]:
        """Filters subtree by returning only nodes for which `filterfunc(node)` is True."""
        ...
```
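To make the intended semantics of subtree and filter concrete, here is a minimal sketch on a hypothetical ToyNode class (an assumption for illustration, not the datatree implementation). It walks the tree depth-first in pre-order, so self comes first, and filter returns a lazy iterator, matching the return type in the stub above.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Iterator


@dataclass
class ToyNode:
    """Hypothetical stand-in for a tree node (not the datatree class)."""

    name: str
    variables: dict[str, object] = field(default_factory=dict)
    children: list[ToyNode] = field(default_factory=list)

    @property
    def subtree(self) -> Iterator[ToyNode]:
        # Depth-first, pre-order: yield this node, then each child's subtree.
        yield self
        for child in self.children:
            yield from child.subtree

    def filter(self, filterfunc: Callable[[ToyNode], bool]) -> Iterator[ToyNode]:
        # Lazily yield only the nodes for which filterfunc(node) is True.
        return (node for node in self.subtree if filterfunc(node))


tree = ToyNode(
    "root",
    children=[
        ToyNode("posterior", {"x": 1.0}, [ToyNode("diagnostics", {"r_hat": 1.01})]),
        ToyNode("prior", {"x": 0.5}),
    ],
)
names = [node.name for node in tree.subtree]
with_x = [node.name for node in tree.filter(lambda node: "x" in node.variables)]
```

Returning an iterator keeps filter cheap, but it also means the caller gets a flat sequence of nodes rather than a tree.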

Are there other types of access that we're missing here? Filtering by regex match? Getting nodes where at least one part of the path matches ("tag-like" access)? Glob?
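The regex, glob, and "tag-like" options could all be expressed as predicates over node paths. The sketch below assumes nodes are identified by POSIX-style path strings and uses only Python's standard fnmatch and re modules; the helper names select_glob, select_regex, and select_tag are hypothetical.

```python
from __future__ import annotations

import fnmatch
import re
from typing import List


def select_glob(paths: List[str], pattern: str) -> List[str]:
    # Shell-style wildcards; note that fnmatch's "*" also matches "/".
    return [path for path in paths if fnmatch.fnmatchcase(path, pattern)]


def select_regex(paths: List[str], pattern: str) -> List[str]:
    # re.search matches anywhere in the path unless the pattern is anchored.
    regex = re.compile(pattern)
    return [path for path in paths if regex.search(path)]


def select_tag(paths: List[str], tag: str) -> List[str]:
    # "Tag-like" access: keep a path if any one of its components equals tag.
    return [path for path in paths if tag in path.strip("/").split("/")]


paths = ["/", "/posterior", "/posterior/diagnostics", "/prior", "/prior/diagnostics"]
```

Because fnmatch's "*" crosses "/" boundaries, a tree-aware glob would probably want something closer to pathlib's component-wise matching; which semantics to adopt is left open here.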

TomNicholas commented 2 years ago

@oriolabril, do you think these types of functions would be sufficient for ArviZ's use cases? From https://github.com/arviz-devs/arviz/issues/2015#issuecomment-1106957255:

> dt[["posterior", "posterior_predictive"]] is not possible
>
> getting a subset of the datatree that consists of multiple groups

This is what I'm suggesting subset (or __getitem__) should do.
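One way subset (or list-based __getitem__) could behave is sketched below on a hypothetical ToyTree that stores nodes as a flat mapping from "group/subgroup" paths to variables; this flat representation is an assumption chosen to keep the example short, not how datatree stores nodes.

```python
from __future__ import annotations

from typing import Dict, Sequence, Union


class ToyTree:
    """Hypothetical stand-in: a flat mapping of "group/subgroup" paths to variables."""

    def __init__(self, nodes: Dict[str, Dict[str, object]]) -> None:
        self.nodes = nodes

    def __getitem__(self, key: Union[str, Sequence[str]]):
        # A single name returns that node's variables, like dict access.
        if isinstance(key, str):
            return self.nodes[key]
        # A list of names returns a new tree restricted to those groups.
        return self.subset(key)

    def subset(self, keys: Sequence[str]) -> ToyTree:
        missing = [key for key in keys if key not in self.nodes]
        if missing:
            raise KeyError(missing)
        # Keep each requested group plus everything nested below it.
        return ToyTree(
            {
                path: variables
                for path, variables in self.nodes.items()
                if any(path == key or path.startswith(key + "/") for key in keys)
            }
        )


dt = ToyTree(
    {
        "posterior": {"x": 1.0},
        "posterior/diagnostics": {"r_hat": 1.01},
        "posterior_predictive": {"y": 2.0},
        "prior": {"x": 0.5},
    }
)
both = dt[["posterior", "posterior_predictive"]]
```

Matching on key + "/" rather than a bare string prefix is what stops a request for "posterior" from also pulling in "posterior_predictive".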

> applying a function to the variable x that is present in 3 out of 5 groups of the datatree.

I'm imagining enabling that via

```python
dt.filter(lambda node: 'x' in node.variables).map_over_subtree(func)
```

Or we could potentially add an optional filterfunc argument to map_over_subtree.
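The optional filterfunc idea can be sketched without datatree at all. Below, map_over_nodes is a hypothetical helper, and each node is modelled as a plain {variable_name: value} dict (an assumption); func is applied only where filterfunc(node) is True, mirroring the "x in 3 out of 5 groups" case.

```python
from __future__ import annotations

from typing import Callable, Dict, Optional

# A node is modelled as a plain {variable_name: value} dict (an assumption).
Node = Dict[str, float]


def map_over_nodes(
    nodes: Dict[str, Node],
    func: Callable[[Node], Node],
    filterfunc: Optional[Callable[[Node], bool]] = None,
) -> Dict[str, Node]:
    """Hypothetical helper: apply func only to nodes where filterfunc(node) is True.

    Nodes that fail the filter are passed through unchanged, so the result
    has the same paths, and therefore the same tree shape, as the input.
    """
    return {
        path: (func(node) if filterfunc is None or filterfunc(node) else node)
        for path, node in nodes.items()
    }


def double_x(node: Node) -> Node:
    # Return a new node with "x" doubled and every other variable untouched.
    return {**node, "x": node["x"] * 2}


groups = {
    "a": {"x": 1.0},
    "b": {"y": 2.0},
    "c": {"x": 3.0, "y": 4.0},
    "d": {"z": 5.0},
    "e": {"x": 6.0},
}
result = map_over_nodes(groups, double_x, filterfunc=lambda node: "x" in node)
```

Passing non-matching nodes through unchanged, rather than dropping them, sidesteps the question of how to store the results back into a valid tree.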

dcherian commented 2 years ago

One possible point of (approximate) alignment with the Xarray API is this issue: https://github.com/pydata/xarray/issues/3894, about selecting with an iterable of variable names. It seems analogous to selecting nodes using subset.

TomNicholas commented 2 years ago

I had not seen that issue, thanks @dcherian

OriolAbril commented 2 years ago

I think that would cover everything, but I'll try to think of examples so that we can also have things to test on.

We could also provide functions in datatree/xarray/arviz to act as filterfunc for common cases. My main question when thinking about using filter is storing the results back. I guess a merge would do it, maybe with some renaming in the process? It will probably be best to discuss this with some examples.

TomNicholas commented 2 years ago

> My main question when thinking about using filter is storing the results back.

Yes, that's the tricky bit: if you want to return a tree, you might need to retain nodes for which filterfunc(node) is False in order to still have a valid tree structure afterwards.

For example:

```python
from datatree import DataTree


def name_is_lowercase(node):
    return node.name == node.name.lower()


root = DataTree(name="a")
child = DataTree(name="B", parent=root)
grandchild = DataTree(name="c", parent=child)

root.filter(name_is_lowercase)
```

This would return nodes "a" and "c", but it couldn't automatically reconstruct them into a tree without also preserving node "B".

If .filter just returned an iterator of nodes, you wouldn't need to be able to rebuild a tree, but this might not be the most convenient option for the user. This is why I would like to design these functions with some desired usage patterns in mind.
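One way to return a valid tree is to keep every ancestor of a matching node, emptied of its data, so the result stays connected. The sketch below assumes a flat {path: variables} representation; filter_keep_ancestors is a hypothetical name, and as a simplification its predicate receives the node's name (the last path component) rather than the node itself.

```python
from __future__ import annotations

from typing import Callable, Dict

Tree = Dict[str, Dict[str, object]]


def filter_keep_ancestors(nodes: Tree, filterfunc: Callable[[str], bool]) -> Tree:
    """Hypothetical helper over a flat {"a/B/c": variables} mapping.

    Keeps every node whose name (last path component) passes filterfunc,
    plus the ancestors needed to connect it to the root. Ancestors that
    failed the filter are kept as empty nodes.
    """
    matched = {path for path in nodes if filterfunc(path.rsplit("/", 1)[-1])}
    required = set()
    for path in matched:
        parts = path.split("/")
        # Add the node itself and every prefix, e.g. "a", "a/B", "a/B/c".
        for depth in range(1, len(parts) + 1):
            required.add("/".join(parts[:depth]))
    return {
        path: (nodes[path] if path in matched else {})
        for path in nodes  # preserves the input's ordering
        if path in required
    }


def name_is_lowercase(name: str) -> bool:
    return name == name.lower()


nodes = {"a": {"v": 1}, "a/B": {"w": 2}, "a/B/c": {"u": 3}}
filtered = filter_keep_ancestors(nodes, name_is_lowercase)
```

For the example above, "a/B" survives only as an empty connector between "a" and "a/B/c". The iterator-only alternative would just return the matched paths and skip the ancestor bookkeeping.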

TomNicholas commented 1 year ago

I added a method to filter nodes based on some condition in #185

dcherian commented 1 year ago

I've routinely wanted something that says select these variable names from all nodes.

This is way too much typing for that:

```python
dailies.map_over_subtree(lambda n: n[["KT", "eps", "chi"]])
```

Perhaps a DataTree.subset_nodes?
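A subset_nodes-style helper could look like the sketch below. The name select_variables and the flat {path: variables} representation are hypothetical, and tolerating nodes that lack some of the requested names is a design assumption: indexing an xarray Dataset with a list containing a missing variable raises a KeyError instead.

```python
from __future__ import annotations

from typing import Dict, Iterable

Tree = Dict[str, Dict[str, object]]


def select_variables(nodes: Tree, names: Iterable[str]) -> Tree:
    """Hypothetical helper: keep only the listed variables in every node.

    Nodes that lack some of the requested variables are not an error;
    they simply keep whichever of those variables they do have.
    """
    wanted = list(names)  # materialise once so a generator can be reused
    return {
        path: {name: variables[name] for name in wanted if name in variables}
        for path, variables in nodes.items()
    }


dailies = {
    "mooring1": {"KT": 1.0, "eps": 2.0, "chi": 3.0, "T": 4.0},
    "mooring2": {"KT": 5.0, "T": 6.0},
}
subset = select_variables(dailies, ["KT", "eps", "chi"])
```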

OriolAbril commented 1 year ago

I have finally started using DataTree intensively, and I also find I am using map_over_subtree more often than I would like, not only to subset some variables but also with .sel, .mean, and so on.

How would you feel about an accessor or something of the sort (.tree or .treemap for example) that exposes all the methods (or a subset of commonly used ones) via map_over_subtree?

```python
dt.map_over_subtree(lambda node: node.sel(dim="label"))
# would become
dt.tree.sel(dim="label")

# and the same for .map, .drop_sel, .mean and others
```

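The accessor idea can be prototyped with __getattr__ forwarding. In the sketch below, TreeMethodProxy is a hypothetical class (not an existing datatree or xarray accessor): any method looked up on it is called on every node in turn and the results are collected by path. Plain strings stand in for node datasets so the example runs without xarray.

```python
from __future__ import annotations

from typing import Any, Callable, Dict


class TreeMethodProxy:
    """Hypothetical sketch of a .tree-style accessor.

    proxy.some_method(*args) calls node.some_method(*args) on every node,
    standing in for map_over_subtree(lambda n: n.some_method(*args)).
    """

    def __init__(self, nodes: Dict[str, Any]) -> None:
        self._nodes = nodes

    def __getattr__(self, name: str) -> Callable[..., Dict[str, Any]]:
        # Only called for attributes not found normally, i.e. node methods.
        def call_on_every_node(*args: Any, **kwargs: Any) -> Dict[str, Any]:
            return {
                path: getattr(node, name)(*args, **kwargs)
                for path, node in self._nodes.items()
            }

        return call_on_every_node


proxy = TreeMethodProxy({"posterior": "abc", "prior": "xyz"})
upper = proxy.upper()
replaced = proxy.replace("a", "A")
```

Because the method name is resolved only when the returned function is called, a typo surfaces as an AttributeError at call time. A real accessor would probably expose an explicit list of supported methods instead, in line with the "subset of commonly used ones" suggested above.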
TomNicholas commented 3 months ago

Closing in favour of https://github.com/pydata/xarray/issues/9342