pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

map_over_datasets throws error on nodes without datasets #9693

Open dhruvbalwada opened 3 weeks ago

dhruvbalwada commented 3 weeks ago

map_over_datasets -- a way to compute over datatrees -- currently seems to try to operate even on nodes that contain no datasets, and consequently raises an error. This seems to be a new issue; it was not a problem when this function was called map_over_subtree in the experimental datatree versions.

An example to reproduce this problem is below:

## Generate datatree, using the example from the documentation
import numpy as np
import xarray as xr

def time_stamps(n_samples, T):
    """Create an array of evenly-spaced time stamps"""
    return xr.DataArray(
        data=np.linspace(0, 2 * np.pi * T, n_samples), dims=["time"]
    )

def signal_generator(t, f, A, phase):
    """Generate an example electrical-like waveform"""
    return A * np.sin(f * t.data + phase)

time_stamps1 = time_stamps(n_samples=15, T=1.5)

time_stamps2 = time_stamps(n_samples=10, T=1.0)

voltages = xr.DataTree.from_dict(
    {
        "/oscilloscope1": xr.Dataset(
            {
                "potential": (
                    "time",
                    signal_generator(time_stamps1, f=2, A=1.2, phase=0.5),
                ),
                "current": (
                    "time",
                    signal_generator(time_stamps1, f=2, A=1.2, phase=1),
                ),
            },
            coords={"time": time_stamps1},
        ),
        "/oscilloscope2": xr.Dataset(
            {
                "potential": (
                    "time",
                    signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.2),
                ),
                "current": (
                    "time",
                    signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.7),
                ),
            },
            coords={"time": time_stamps2},
        ),
    }
)

## Write some function to add resistance
def add_resistance_only_do(dtree): 
    def calculate_resistance(ds):
        ds_new = ds.copy()

        ds_new['resistance'] = ds_new['potential']/ds_new['current']
        return ds_new 

    dtree = dtree.map_over_datasets(calculate_resistance)

    return dtree

def add_resistance_try(dtree): 
    def calculate_resistance(ds):
        ds_new = ds.copy()
        try:
            ds_new['resistance'] = ds_new['potential']/ds_new['current']
            return ds_new 
        except KeyError:
            return ds_new

    dtree = dtree.map_over_datasets(calculate_resistance)

    return dtree

Calling voltages = add_resistance_only_do(voltages) raises the error:

KeyError: "No variable named 'potential'. Variables on the dataset include []"
Raised whilst mapping function over node with path '.'

This can be easily resolved by putting try statements in (e.g. voltages = add_resistance_try(voltages)), but we know that Yoda would not recommend try (right @TomNicholas).

Can this be built in as a default feature of map_over_datasets? Many real-world datatrees will have nodes without datasets.

shoyer commented 3 weeks ago

This was an intentional change, because a special case to skip empty nodes felt surprising to me.

On the other hand, maybe it does make sense to skip nodes without datasets specifically for a method that maps over datasets (but not for a method that maps over nodes). So I'm open to changing this. The other option would be to add a new keyword argument to map_over_datasets for controlling this, something like skip_empty_nodes=True.

For what it's worth, the canonical way to write this today would be something like:

def add_resistance_try(dtree): 
    def calculate_resistance(ds):
        if not ds:
            return None
        ds_new = ds.copy()
        ds_new['resistance'] = ds_new['potential']/ds_new['current']
        return ds_new 

    dtree = dtree.map_over_datasets(calculate_resistance)
    return dtree
TomNicholas commented 3 weeks ago

Thanks for raising this @dhruvbalwada !

I would be in favor of changing this. It came up before for users, and I'm not surprised it has come up again almost immediately.

I think it's reasonable for "map over datasets" to not map over a node where there is no dataset by default. The subtleties are with inherited variables and attrs. There are multiple issues on the old repo discussing this.

dcherian commented 3 weeks ago

The other option would be to add a new keyword argument to map_over_datasets for controlling this, something like skip_empty_nodes=True.

I like this idea with default False. With deep hierarchies it can be easy to miss that a node might be unexpectedly empty. So it'd be good to force users to opt in.

kmuehlbauer commented 3 weeks ago

I can see use-cases for both skip_empty_nodes=False and skip_empty_nodes=True, so we won't make all users happy with either default.

But I think we should not add that skip_empty_nodes kwarg at all. Instead we could encourage users to work with solutions along the lines of @shoyer's suggestion above. In more complex scenarios users will need such solutions anyway, since their functions might only work on dedicated nodes: their tree layouts might differ significantly, and nodes won't be equivalent in terms of their content.

To assist users with that task, xarray could provide the same functionality the OP is looking for via a simple decorator (update: now tested, finally):

import functools
def skip_empty_nodes(func):
    @functools.wraps(func)
    def _func(ds, *args, **kwargs):
        if not ds:
            return ds
        return func(ds, *args, **kwargs)
    return _func

def add_resistance_try(dtree):
    @skip_empty_nodes
    def calculate_resistance(ds):
        ds_new = ds.copy()
        ds_new['resistance'] = ds_new['potential']/ds_new['current']
        return ds_new 

    dtree = dtree.map_over_datasets(calculate_resistance)
    return dtree

voltages = add_resistance_try(voltages)

Anyway, if the kwarg solution is preferred, I'd opt for skip_empty_nodes=False.

shoyer commented 3 weeks ago

I don't think we need extensive helper functions or options in map_over_datasets. It's a convenience function, which is why I'm OK skipping empty nodes by default.

For cases where users need control, they can just iterate over DataTree.subtree_with_keys or xarray.group_subtrees() themselves.

kmuehlbauer commented 3 weeks ago

Fine with that, too. Are Datasets with only attrs considered empty?

shoyer commented 3 weeks ago

Fine with that, too. Are Datasets with only attrs considered empty?

There are a few different edge cases:

The original map_over_subtree had special logic to propagate attributes forward for empty nodes, without calling the mapped-over function. That seems reasonable to me.

I'm not sure whether or not to call the mapped-over function for nodes that only define coordinates. Certainly I would not blindly copy coordinates from otherwise empty nodes onto the result, because those coordinates may no longer be relevant on the result.
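
As an illustration of these edge cases: a Dataset's truthiness follows its data variables, so under an `if not ds:` check both attrs-only and coords-only datasets count as "empty" (toy datasets, not from the thread):

```python
import xarray as xr

# Dataset.__len__ counts data variables, so truthiness ignores attrs/coords.
attrs_only = xr.Dataset(attrs={"source": "oscilloscope"})
coords_only = xr.Dataset(coords={"time": [0.0, 1.0]})
with_data = xr.Dataset({"potential": ("time", [1.0, 2.0])})
```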

kmuehlbauer commented 3 weeks ago

Thanks @shoyer for the details. Good to see that there are solutions for many use-cases already built-in or available via external helper functions.

I'm diverting a bit from the issue now. I've had to do this kind of wrapping to feed kwargs to my mapping function. What is the canonical way to feed kwargs to map_over_datasets? I should open a separate issue for that.

shoyer commented 3 weeks ago

I'm diverting a bit from the issue now. I've had to do this kind of wrapping to feed kwargs to my mapping function. What is the canonical way to feed kwargs to map_over_datasets? I should open a separate issue for that.

You can pass in a helper function or use functools.wraps. We could also add a kwargs argument like xarray.apply_ufunc.

keewis commented 3 weeks ago

or use functools.wraps

shouldn't that be functools.partial?
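
For reference, a minimal sketch of the functools.partial approach (it is indeed partial that binds arguments; functools.wraps only copies metadata onto a wrapper). The function and keyword names here are illustrative, not from the thread:

```python
import functools

import xarray as xr

def calculate_resistance(ds, name="resistance"):
    """Mapped function that takes an extra keyword argument."""
    ds = ds.copy()
    ds[name] = ds["potential"] / ds["current"]
    return ds

# partial binds the keyword up front, yielding the one-argument
# callable that map_over_datasets expects.
func = functools.partial(calculate_resistance, name="R")
out = func(xr.Dataset({"potential": ("t", [4.0]), "current": ("t", [2.0])}))
```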