diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7e3badc7143..5b0aa37a460 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ Bug Fixes - Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`). By `Julia Signell `_. +- Fix ``DataTree`` bugs related to assigning nodes, variables and coordinates with + path-like names (:issue:`9485`, :issue:`9490`, :issue:`9978`). + By `Ewan Short `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9aa64a57ff2..bbb74770fcb 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1045,6 +1045,23 @@ def _drop_coords(self, coord_names): del self._data._node_coord_variables[name] del self._data._node_indexes[name] + def __setitem__(self, key: Hashable, value: Any) -> None: + """Set a coordinate, optionally at a path to a child node.""" + from xarray.core.treenode import NodePath + # node_path, coord_name = NodePath(key)._get_components() + # target_node = self._data._get_target_object(node_path) + # target_node.coords[coord_name] = value + # Check if key contains a forward slash (path-like access) + if isinstance(key, str) and "/" in key: + # Parse key as NodePath to enforce correct structure + node_path, coord_name = NodePath(key)._get_components() + target_node = self._data._get_target_object(node_path) + target_node.coords[coord_name] = value + else: + super().__setitem__(key, value) + def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] # type: ignore[arg-type] # see https://github.com/pydata/xarray/issues/8836 diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e079332780c..a1bb14cd9a4 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1040,16 +1040,17 @@ def __delitem__(self, key: str) -> None: else: raise KeyError(key) - @overload - 
def update(self, other: Dataset) -> None: ... - - @overload - def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ... - - @overload - def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ... - - def update( + def _get_target_object(self, node_path: NodePath | str) -> DataTree | DataArray: + """Helper function to get object, or create empty node if missing.""" + try: + target_object = self._get_item(node_path) + except KeyError: + # create new empty node, and all nodes along the path + self._set_item(node_path, DataTree(), new_nodes_along_path=True) + target_object = self._get_item(node_path) + return target_object + + def _update_local_node( self, other: ( Dataset @@ -1058,9 +1059,8 @@ def update( ), ) -> None: """ - Update this node's children and / or variables. - - Just like `dict.update` this is an in-place operation. + Helper to update the node's children and / or variables assuming there are no + path-like keys. """ new_children: dict[str, DataTree] = {} new_variables: CoercibleMapping @@ -1092,6 +1092,56 @@ def update( self._replace_node(data, children=merged_children) + @overload + def update(self, other: Dataset) -> None: ... + + @overload + def update( + self, other: Mapping[Hashable, Dataset | DataArray | Variable] + ) -> None: ... + + @overload + def update( + self, other: Mapping[str, DataTree | Dataset | DataArray | Variable] + ) -> None: ... + + def update( + self, + other: ( + Dataset + | Mapping[Hashable, Dataset | DataArray | Variable] + | Mapping[str, DataTree | Dataset | DataArray | Variable] + ), + ) -> None: + """ + Update this node's children and / or variables allowing for path-like keys. + + Just like `dict.update` this is an in-place operation. 
+ """ + + # If other is a dataset assume we want to update just this node + if isinstance(other, Dataset): + self._update_local_node(other) + return + + # Otherwise divide other into groups with unique target nodes + items_by_node: dict[NodePath, dict[str, DataTree | DataArray | Variable]] = {} + for k, v in other.items(): + node_path, object_name = NodePath(k)._get_components() + if isinstance(v, Dataset): + # If v is a dataset, node_path/object_name should be a node + target_node = self._get_target_object(f"{node_path}/{object_name}") + # Update the node immediately + target_node._update_local_node(v) + else: + # Otherwise add to target node's items to update later as a group + items_by_node.setdefault(node_path, {}).update({object_name: v}) + + # Update each target node, creating if necessary + for node_path, node_other in items_by_node.items(): + target_node = self._get_target_object(node_path) + target_node._update_local_node(node_other) + def assign( self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any ) -> DataTree: diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 7eccf09088e..f5f6706bf3f 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -42,6 +42,18 @@ def absolute(self) -> Self: """Convert into an absolute path.""" return type(self)("/", *self.parts) + def _get_components(self) -> tuple[NodePath, str]: + """ + Check the NodePath has a non-empty name, which is required for object + assignment, e.g. tree['/path/to/node/object_name'] = value. + """ + if not self.name: + raise ValueError( + f"Invalid assignment path {self!r}. Assignment paths should have a " + "structure like 'path/to/node/object_name'." + ) + return self.parent, self.name + class TreeNode: """ @@ -172,12 +184,35 @@ def children(self, children: Mapping[str, Self]) -> None: old_children = self.children del self.children + + def handle_child(name, child): + """ + Decide whether to call _set_item or _set_parent. 
If the child name is a + path, we call ._set_item to make sure any intermediate nodes are also + created. When name is a path there will be nested calls of children due to + the setter decorator, and the fact ._set_item assigns to the children + attribute. + """ + if "/" in name: + # Path like node name. Create node at appropriate level of nesting. + self._set_item( + path=name, + item=child, + new_nodes_along_path=True, + allow_overwrite=True, + ) + else: + # Simple name, just call _set_parent directly. + child._set_parent(new_parent=self, child_name=name) + try: self._pre_attach_children(children) for name, child in children.items(): - child._set_parent(new_parent=self, child_name=name) + handle_child(name, child) self._post_attach_children(children) - assert len(self.children) == len(children) + # Check the children were created in the right place (probably redundant) + for path in children.keys(): + self._get_item(path) except Exception: # if something goes wrong then revert to previous children self.children = old_children @@ -206,7 +241,7 @@ def _check_children(children: Mapping[str, TreeNode]) -> None: if not isinstance(child, TreeNode): raise TypeError( f"Cannot add object {name}. 
It is of type {type(child)}, " - "but can only add children of type DataTree" + "but can only add children of type TreeNode" ) childid = id(child) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 0cd888f5782..3543483fb64 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -9,7 +9,7 @@ import pytest import xarray as xr -from xarray import DataArray, Dataset +from xarray import DataArray, Dataset, Variable from xarray.core.coordinates import DataTreeCoordinates from xarray.core.datatree import DataTree from xarray.core.treenode import NotFoundInTreeError @@ -82,6 +82,12 @@ def __deepcopy__(self, memo): dt3 = xr.DataTree.from_dict({"/": ds, "child": ds}) assert_identical(dt2, dt3) + def test_pathlike_children(self) -> None: + dt = xr.DataTree(children={"a/b/c": xr.DataTree(), "a/d": xr.DataTree()}) + assert "c" in dt["a/b"].children + assert dt["a/b/c"].path == "/a/b/c" + assert "d" in dt["a"].children + class TestFamilyTree: def test_dont_modify_children_inplace(self) -> None: @@ -337,6 +343,27 @@ def test_update(self) -> None: assert_equal(dt, expected) assert dt.groups == ("/", "/a") + def test_update_with_paths(self) -> None: + """Test that update() handles paths consistently with __setitem__.""" + dt1 = xr.DataTree(xr.Dataset(coords={"x": [0, 1, 2]})) + new_content: dict[str, DataTree | Dataset | DataArray | Variable] = {} + # Add a DataArray q to child node p + new_content = {"p/q": xr.DataArray(data=[10, 20, 30], dims=("x",))} + # Add a Variable r to child node p + new_content.update({"p/r": xr.Variable(data=2, dims=())}) + # Create new nodes s, t, and assign a Dataset to t + ds = xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [4, 5, 6])}) + new_content.update({"s/t": ds}) + dt1.update(new_content) + + dt2 = DataTree(xr.Dataset(coords={"x": [0, 1, 2]})) + for path, obj in new_content.items(): + # Add the new content with __setitem__ one item at a time + dt2[path] = obj + + # Both should produce the 
same result + assert_equal(dt1, dt2) + def test_update_new_named_dataarray(self) -> None: da = xr.DataArray(name="temp", data=[0, 50]) folder1 = DataTree(name="folder1") @@ -722,6 +749,27 @@ def test_modify(self) -> None: dt2 = DataTree(dataset=dt.coords) assert_identical(dt2.coords, dt.coords) + def test_setitem_coord_on_child_node(self) -> None: + """Test that coordinates can be assigned to child nodes using path-like syntax.""" + tree = DataTree(Dataset(coords={"x": 0}), children={"child": DataTree()}) + + # Assign coordinate to child node using path + tree.coords["/child/y"] = 2 + + # Verify coordinate is on child node, not root + assert "y" in tree["/child"].coords + assert "y" not in tree.coords + assert_array_equal(tree["/child"].coords["y"], 2) + + # Verify root still has its coordinate + assert "x" in tree.coords + assert_array_equal(tree.coords["x"], 0) + + # Test relative path + tree.coords["child/z"] = 3 + assert "z" in tree["/child"].coords + assert_array_equal(tree["/child"].coords["z"], 3) + def test_inherited(self) -> None: ds = Dataset( data_vars={