pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Assigning to DataTree.coords should support full paths #9485

Open shoyer opened 6 days ago

shoyer commented 6 days ago

What is your issue?

When assigning to a full path with DataTree.coords, I expected to assign a coordinate to a child, not the root.

Here's the behavior on main:

In [36]: tree = DataTree(Dataset(coords={'x': 0}), children={'child': DataTree()})

In [37]: tree.coords['/child/y'] = 2

In [38]: tree
Out[38]:
<xarray.DataTree>
Group: /
│   Dimensions:   ()
│   Coordinates:
│       x         int64 8B 0
│       /child/y  int64 8B 2
└── Group: /child

Instead, I expected the same result as assigning directly to the child node (without the duplicate inherited coordinate, of course):

In [40]: tree['/child'].coords['y'] = 2

In [41]: tree
Out[41]:
<xarray.DataTree>
Group: /
│   Dimensions:  ()
│   Coordinates:
│       x        int64 8B 0
└── Group: /child
        Dimensions:  ()
        Coordinates:
            x        int64 8B 0
            y        int64 8B 2
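
Until full paths are supported, one possible workaround is to split the path by hand and assign on the child node, as in the session above. A minimal sketch of the split, using only the standard library; `split_node_path` is a hypothetical helper name, not part of xarray:

```python
from pathlib import PurePosixPath


def split_node_path(path: str) -> tuple[str, str]:
    """Split '/child/y' into the node path '/child' and the variable name 'y'."""
    p = PurePosixPath(path)
    if not p.name:
        # e.g. '/' or '' has no final component to use as a variable name
        raise ValueError(f"path {path!r} does not name a variable")
    return str(p.parent), p.name
```

With `node, name = split_node_path('/child/y')`, the assignment becomes `tree[node].coords[name] = 2`, which is the same call as `tree['/child'].coords['y'] = 2` above. For a bare name such as `'y'` the node part comes back as `'.'`, so that case is probably better handled by assigning on the tree directly.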
TomNicholas commented 6 days ago

Oh dear. The problem is that DataTreeCoordinates.update() assumes you only want to modify the local node. In fact the same bug exists for DataTree.update, for the same reason:

In [11]: dt = DataTree(Dataset(coords={'x': 0}), children={'child': DataTree()})

In [12]: dt.update({'child/y': Variable(data=2, dims=())})

In [13]: dt
Out[13]: 
<xarray.DataTree>
Group: /
│   Dimensions:  ()
│   Coordinates:
│       x        int64 8B 0
│   Data variables:
│       child/y  int64 8B 2
└── Group: /child

DataTree.__setitem__ is correct though:

In [15]: dt['child/y'] = Variable(data=2, dims=())

In [16]: dt
Out[16]: 
<xarray.DataTree>
Group: /
│   Dimensions:  ()
│   Coordinates:
│       x        int64 8B 0
└── Group: /child
        Dimensions:  ()
        Data variables:
            y        int64 8B 2

I think some refactoring to direct all these through the same codepath would help. Perhaps the general pattern should be more like:

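class DataTree:
    def update(self, other: Mapping[str, Any]):
        for k, v in other.items():
            path = NodePath(k)
            # Walk to the node that should own the variable, not to the variable itself
            node_to_update = self._walk_to(path.parent)
            # Store under the bare name so the full path never becomes a key
            node_to_update._local_update({path.name: v})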
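
To make the routing idea concrete, here is a self-contained toy version of that pattern. It is only a sketch: `Node` is a stand-in class rather than xarray's `DataTree`, `PurePosixPath` stands in for `NodePath`, and `_walk_to` and `_local_update` are the hypothetical helper names from the snippet above. The walk goes to the parent of the path, and the value is stored under the final path component.

```python
from __future__ import annotations

from pathlib import PurePosixPath
from typing import Any, Mapping


class Node:
    """Toy stand-in for a tree node; not xarray's DataTree."""

    def __init__(self, children: dict[str, Node] | None = None):
        self.variables: dict[str, Any] = {}
        self.children: dict[str, Node] = dict(children or {})

    def _walk_to(self, path: PurePosixPath) -> Node:
        # Follow each path component down from this node, skipping a leading "/".
        node = self
        for part in path.parts:
            if part == "/":
                continue
            node = node.children[part]  # KeyError if the child does not exist
        return node

    def _local_update(self, other: Mapping[str, Any]) -> None:
        # Only ever touches this node's own variables.
        self.variables.update(other)

    def update(self, other: Mapping[str, Any]) -> None:
        for key, value in other.items():
            path = PurePosixPath(key)
            self._walk_to(path.parent)._local_update({path.name: value})

    def __setitem__(self, key: str, value: Any) -> None:
        # Single-item assignment reuses the same code path as update().
        self.update({key: value})
```

In this toy, `root.update({'/child/y': 2})` and `root['child/y'] = 2` both land on the child node, because every entry point funnels through `update`. It does not handle `..` components or create missing intermediate nodes; a real implementation would need to decide on both.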
TomNicholas commented 6 days ago

Note that merging https://github.com/pydata/xarray/pull/9378 would have prevented the variable being assigned, instead raising a ValueError.

TomNicholas commented 5 days ago

> Note that merging https://github.com/pydata/xarray/pull/9378 would have prevented the variable being assigned, instead raising a ValueError.

This wasn't actually true: because the two assignments go through different code paths, #9492 was also necessary to prevent this problem. I will refactor to make all updates go through the same code path in a later PR.