pydata / xarray

N-D labeled arrays and datasets in Python
Apache License 2.0
3.63k stars 1.09k forks source link

children nodes are deepcopied when initializing a DataTree #9683

Closed OriolAbril closed 3 weeks ago

OriolAbril commented 3 weeks ago

What happened?

I have xarray objects that contain object-dtype numpy arrays whose elements can't be deepcopied. This has never been an issue, not even when using xarray-datatree. I saw in the source code that child nodes are supposed to be shallow copied, but they seem to be deepcopied.

What did you expect to happen?

A DataTree is created without problems even when its child nodes store non-deepcopiable objects as array elements in their variables.

Minimal Complete Verifiable Example

# minimal setup
import xarray as xr
from copy import copy, deepcopy

class NoDeepCopy:
    """Test object that allows shallow copies but refuses to be deep-copied."""

    def __deepcopy__(self, memo):
        # copy.deepcopy() dispatches to this hook; always refuse.
        reason = "Class can't be deepcopied"
        raise TypeError(reason)

# check things do what we expect
example = NoDeepCopy()
# works: a shallow copy never calls ``__deepcopy__``
copy(example)
# Raises TypeError: Class can't be deepcopied
# (caught here so the script carries on to the xarray part below)
try:
    deepcopy(example)
except TypeError as err:
    print(f"deepcopy(example) raised as expected: {err}")
else:
    raise AssertionError("NoDeepCopy instance was unexpectedly deep-copied")

# On to xarray use. All of these work correctly:
da = xr.DataArray(NoDeepCopy())
ds = xr.Dataset({"var": da})
dt1 = xr.DataTree(ds)
dt2 = xr.DataTree.from_dict({"/": ds})

# However, none of these work, they all end up triggering the `__deepcopy__`
# method of the `NoDeepCopy` class. Each one raises TypeError on the affected
# version (xarray 2024.10.0), so when this is run as a single script only the
# first failing line is reached.
dt3 = xr.DataTree(ds, children={"child": dt1})
dt4 = xr.DataTree.from_dict({"child": ds})
dt5 = xr.DataTree()
dt5.children = {"child": xr.DataTree(ds)}

MVCE confirmation

Relevant log output

TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 dt4 = xr.DataTree.from_dict({"child": ds})
      2 dt4

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/datatree.py:1198, in DataTree.from_dict(cls, d, name)
   1196         else:
   1197             raise TypeError(f"invalid values: {data}")
-> 1198         obj._set_item(
   1199             path,
   1200             new_node,
   1201             allow_overwrite=False,
   1202             new_nodes_along_path=True,
   1203         )
   1205 # TODO: figure out why mypy is raising an error here, likely something
   1206 # to do with the return type of Dataset.copy()
   1207 return obj

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/treenode.py:652, in TreeNode._set_item(self, path, item, new_nodes_along_path, allow_overwrite)
    650         raise KeyError(f"Already a node object at path {path}")
    651 else:
--> 652     current_node._set(name, item)

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/datatree.py:944, in DataTree._set(self, key, val)
    942     new_node = val.copy(deep=False)
    943     new_node.name = key
--> 944     new_node._set_parent(new_parent=self, child_name=key)
    945 else:
    946     if not isinstance(val, DataArray | Variable):
    947         # accommodate other types that can be coerced into Variables

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/treenode.py:115, in TreeNode._set_parent(self, new_parent, child_name)
    113 self._check_loop(new_parent)
    114 self._detach(old_parent)
--> 115 self._attach(new_parent, child_name)

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/treenode.py:152, in TreeNode._attach(self, parent, child_name)
    147 if child_name is None:
    148     raise ValueError(
    149         "To directly set parent, child needs a name, but child is unnamed"
    150     )
--> 152 self._pre_attach(parent, child_name)
    153 parentchildren = parent._children
    154 assert not any(
    155     child is self for child in parentchildren
    156 ), "Tree is corrupt."

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/datatree.py:528, in DataTree._pre_attach(self, parent, name)
    526 node_ds = self.to_dataset(inherit=False)
    527 parent_ds = parent._to_dataset_view(rebuild_dims=False, inherit=True)
--> 528 check_alignment(path, node_ds, parent_ds, self.children)
    529 _deduplicate_inherited_coordinates(self, parent)

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/datatree.py:149, in check_alignment(path, node_ds, parent_ds, children)
    147 if parent_ds is not None:
    148     try:
--> 149         align(node_ds, parent_ds, join="exact")
    150     except ValueError as e:
    151         node_repr = _indented(_without_header(repr(node_ds)))

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/alignment.py:883, in align(join, copy, indexes, exclude, fill_value, *objects)
    687 """
    688 Given any number of Dataset and/or DataArray objects, returns new
    689 objects with aligned indexes and dimension sizes.
    874 """
    875 aligner = Aligner(
    876     objects,
    877     join=join,
    881     fill_value=fill_value,
    882 )
--> 883 aligner.align()
    884 return aligner.results

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/alignment.py:583, in Aligner.align(self)
    581     self.results = self.objects
    582 else:
--> 583     self.reindex_all()

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/alignment.py:558, in Aligner.reindex_all(self)
    557 def reindex_all(self) -> None:
--> 558     self.results = tuple(
    559         self._reindex_one(obj, matching_indexes)
    560         for obj, matching_indexes in zip(
    561             self.objects, self.objects_matching_indexes, strict=True
    562         )
    563     )

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/alignment.py:559, in <genexpr>(.0)
    557 def reindex_all(self) -> None:
    558     self.results = tuple(
--> 559         self._reindex_one(obj, matching_indexes)
    560         for obj, matching_indexes in zip(
    561             self.objects, self.objects_matching_indexes, strict=True
    562         )
    563     )

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/alignment.py:547, in Aligner._reindex_one(self, obj, matching_indexes)
    544 new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes)
    545 dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes)
--> 547 return obj._reindex_callback(
    548     self,
    549     dim_pos_indexers,
    550     new_variables,
    551     new_indexes,
    552     self.fill_value,
    553     self.exclude_dims,
    554     self.exclude_vars,
    555 )

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/dataset.py:3569, in Dataset._reindex_callback(self, aligner, dim_pos_indexers, variables, indexes, fill_value, exclude_dims, exclude_vars)
   3567         reindexed = self._overwrite_indexes(new_indexes, new_variables)
   3568     else:
-> 3569         reindexed = self.copy(deep=aligner.copy)
   3570 else:
   3571     to_reindex = {
   3572         k: v
   3573         for k, v in self.variables.items()
   3574         if k not in variables and k not in exclude_vars
   3575     }

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/dataset.py:1374, in Dataset.copy(self, deep, data)
   1277 def copy(self, deep: bool = False, data: DataVars | None = None) -> Self:
   1278     """Returns a copy of this dataset.
   1280     If `deep=True`, a deep copy is made of each of the component variables.
   1372     pandas.DataFrame.copy
   1373     """
-> 1374     return self._copy(deep=deep, data=data)

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/dataset.py:1410, in Dataset._copy(self, deep, data, memo)
   1408         variables[k] = index_vars[k]
   1409     else:
-> 1410         variables[k] = v._copy(deep=deep, data=data.get(k), memo=memo)
   1412 attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
   1413 encoding = (
   1414     copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
   1415 )

File ~/bin/miniforge3/envs/arviz/lib/python3.11/site-packages/xarray/core/variable.py:940, in Variable._copy(self, deep, data, memo)
    937         ndata = indexing.MemoryCachedArray(data_old.array)  # type: ignore[assignment]
    939     if deep:
--> 940         ndata = copy.deepcopy(ndata, memo)
    942 else:
    943     ndata = as_compatible_data(data)

File ~/bin/miniforge3/envs/arviz/lib/python3.11/copy.py:153, in deepcopy(x, memo, _nil)
    151 copier = getattr(x, "__deepcopy__", None)
    152 if copier is not None:
--> 153     y = copier(memo)
    154 else:
    155     reductor = dispatch_table.get(cls)

File ~/bin/miniforge3/envs/arviz/lib/python3.11/copy.py:153, in deepcopy(x, memo, _nil)
    151 copier = getattr(x, "__deepcopy__", None)
    152 if copier is not None:
--> 153     y = copier(memo)
    154 else:
    155     reductor = dispatch_table.get(cls)

Cell In[2], line 3, in NoDeepCopy.__deepcopy__(self, memo)
      2 def __deepcopy__(self, memo):
----> 3     raise TypeError("Class can't be deepcopied")

TypeError: Class can't be deepcopied

Anything else we need to know?

I added the traceback of the dt4 case; I'm not sure which one would be the most informative. The tracebacks differ at the start, but from TreeNode._set_parent onward they are identical: TreeNode._attach, DataTree._pre_attach, and so on down to __deepcopy__. Each of the different starting points should be reproducible by copy-pasting the example.


INSTALLED VERSIONS ------------------ commit: None python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0] python-bits: 64 OS: Linux OS-release: 5.14.21-150500.55.83-default machine: x86_64 processor: x86_64 byteorder: little LC_ALL: None LANG: ca_ES.UTF-8 LOCALE: ('ca_ES', 'UTF-8') libhdf5: 1.14.3 libnetcdf: None xarray: 2024.10.0 pandas: 2.2.3 numpy: 1.26.4 scipy: 1.14.1 netCDF4: None pydap: None h5netcdf: 1.3.0 h5py: 3.11.0 zarr: 2.16.1 cftime: None nc_time_axis: None iris: None bottleneck: None dask: 2024.5.0 distributed: 2024.5.0 matplotlib: 3.8.2 cartopy: None seaborn: None numbagg: None fsspec: 2024.3.1 cupy: None pint: None sparse: None flox: None numpy_groupies: None setuptools: 68.2.2 pip: 23.3.2 conda: None pytest: 7.4.3 mypy: None IPython: 8.18.1 sphinx: 7.2.6
shoyer commented 3 weeks ago

Thanks for the very clear bug report!

This should be fixed by #9684.