From 95a1a3b808c2de03fc2c59a6c91f8b0db60dee87 Mon Sep 17 00:00:00 2001 From: Ewan Short Date: Thu, 11 Dec 2025 11:14:07 +1100 Subject: [PATCH 1/3] fix DataTree path like assignment GH9485 GH9490 GH9978 --- doc/whats-new.rst | 3 ++ xarray/core/coordinates.py | 17 ++++++++++ xarray/core/datatree.py | 54 ++++++++++++++++++++++++++++--- xarray/core/treenode.py | 41 ++++++++++++++++++++++-- xarray/tests/test_datatree.py | 60 +++++++++++++++++++++++++++++++++++ 5 files changed, 168 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7e3badc7143..5b0aa37a460 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ Bug Fixes - Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`). By `Julia Signell `_. +- Fix ``DataTree`` bugs related to assigning nodes, variables and coordinates with +path-like names (:issue:`9485`, :issue:`9490`, :issue:`9978`). + By `Ewan Short `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9aa64a57ff2..cb80362bae4 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1045,6 +1045,23 @@ def _drop_coords(self, coord_names): del self._data._node_coord_variables[name] del self._data._node_indexes[name] + def __setitem__(self, key: Hashable, value: Any) -> None: + """Set a coordinate, optionally at a path to a child node.""" + from xarray.core.treenode import NodePath + + # node_path, coord_name = NodePath(key)._get_components() + # target_node = self._data._get_target_node(node_path) + # target_node.coords[coord_name] = value + + # # # Check if key contains a forward slash (path-like access) + if isinstance(key, str) and "/" in key: + # Parse key as NodePath to enforce correct structure + node_path, coord_name = NodePath(key)._get_components() + target_node = self._data._get_target_node(node_path) + target_node.coords[coord_name] = value + else: + super().__setitem__(key, value) + def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] # type: ignore[arg-type] # see https://github.com/pydata/xarray/issues/8836 diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e079332780c..277d7c6f018 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1049,7 +1049,17 @@ def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ... @overload def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ... 
- def update( + def _get_target_node(self, node_path: NodePath) -> DataTree: + """Helper function to get node or create if missing.""" + try: + target_node = self._get_item(node_path) + except KeyError: + # create new nodes along the path + self._set_item(node_path, DataTree(), new_nodes_along_path=True) + target_node = self._get_item(node_path) + return target_node + + def _update_local_node( self, other: ( Dataset @@ -1058,9 +1068,8 @@ def update( ), ) -> None: """ - Update this node's children and / or variables. - - Just like `dict.update` this is an in-place operation. + Helper to update the node's children and / or variables assuming there are no + path-like keys. """ new_children: dict[str, DataTree] = {} new_variables: CoercibleMapping @@ -1092,6 +1101,43 @@ def update( self._replace_node(data, children=merged_children) + def update( + self, + other: ( + Dataset + | Mapping[Hashable, DataArray | Dataset | Variable] + | Mapping[str, DataTree | Dataset | DataArray | Variable] + ), + ) -> None: + """ + Update this node's children and / or variables allowing for path-like keys. + + Just like `dict.update` this is an in-place operation. 
+ """ + + # If other is a dataset assume we want to update just this node + if isinstance(other, Dataset): + self._update_local_node(other) + return + + # Otherwise divide other into groups with unique target nodes + items_by_node = {} + for k, v in other.items(): + node_path, object_name = NodePath(k)._get_components() + if isinstance(v, Dataset): + # If v is a dataset, node_path/object_name should be a node + target_node = self._get_target_node(f"{node_path}/{object_name}") + # Update the node immediately + target_node._update_local_node(v) + else: + # Otherwise add to target nodes items to update later as a group + items_by_node.setdefault(node_path, {}).update({object_name: v}) + + # Update each target node, creating if necessary + for node_path, node_other in items_by_node.items(): + target_node = self._get_target_node(node_path) + target_node._update_local_node(node_other) + def assign( self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any ) -> DataTree: diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 7eccf09088e..464df975a6d 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -42,6 +42,18 @@ def absolute(self) -> Self: """Convert into an absolute path.""" return type(self)("/", *self.parts) + def _get_components(self) -> None: + """ + Check the NodePath has a non-empty name, which is required for object + assignment, e.g. tree['/path/to/node/object_name'] = value. + """ + if not self.name: + raise ValueError( + f"Invalid assignment path {self!r}. Assignment paths should have a " + "structure like 'path/to/node/object_name'." + ) + return self.parent, self.name + class TreeNode: """ @@ -172,12 +184,35 @@ def children(self, children: Mapping[str, Self]) -> None: old_children = self.children del self.children + + def handle_child(name, child): + """ + Decide whether to call _set_item or _set_parent. If the child name is a + path, we call ._set_item to make sure any intermediate nodes are also + created. 
When name is a path, this setter is invoked again in a nested fashion, + because ._set_item itself assigns to the children + attribute. + """ + if "/" in name: + # Path-like node name. Create node at appropriate level of nesting. + self._set_item( + path=name, + item=child, + new_nodes_along_path=True, + allow_overwrite=True, + ) + else: + # Simple name, just call _set_parent directly. + child._set_parent(new_parent=self, child_name=name) + try: self._pre_attach_children(children) for name, child in children.items(): - child._set_parent(new_parent=self, child_name=name) + handle_child(name, child) self._post_attach_children(children) - assert len(self.children) == len(children) + # Check that all the children were created in the right place + for path in children.keys(): + self._get_item(path) except Exception: # if something goes wrong then revert to previous children self.children = old_children @@ -206,7 +241,7 @@ def _check_children(children: Mapping[str, TreeNode]) -> None: if not isinstance(child, TreeNode): raise TypeError( f"Cannot add object {name}. 
It is of type {type(child)}, " - "but can only add children of type DataTree" + "but can only add children of type TreeNode" ) childid = id(child) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 0cd888f5782..b9477643fee 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -82,6 +82,12 @@ def __deepcopy__(self, memo): dt3 = xr.DataTree.from_dict({"/": ds, "child": ds}) assert_identical(dt2, dt3) + def test_pathlike_children(self) -> None: + dt = xr.DataTree(children={"a/b/c": xr.DataTree(), "a/d": xr.DataTree()}) + assert "c" in dt["a/b"].children + assert dt["a/b/c"].path == "/a/b/c" + assert "d" in dt["a"].children + class TestFamilyTree: def test_dont_modify_children_inplace(self) -> None: @@ -337,6 +343,39 @@ def test_update(self) -> None: assert_equal(dt, expected) assert dt.groups == ("/", "/a") + def test_update_with_paths(self) -> None: + """Test that update() handles paths consistently with __setitem__.""" + ds = Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [4, 5, 6])}) + dt1 = DataTree() + # Add a DataArray x to node child + new_content = {"child/x": xr.DataArray(data=[10, 20], dims=("z",))} + # Add a variable y to node child + new_content.update({"child/y": xr.Variable(data=2, dims=())}) + # Create a new node z with data ds and parent step_child + new_content.update({"step_child/z": ds}) + dt1.update(new_content) + + dt2 = DataTree() + for path, obj in new_content.items(): + # Add the new content with __setitem__ + dt2[path] = obj + + # Both should produce the same result + assert_equal(dt1, dt2) + + # Test with multiple path assignments + dt3 = DataTree() + dt3.update( + { + "a/x": xr.Variable(data=1, dims=()), + "a/y": xr.Variable(data=2, dims=()), + "b/z": xr.Variable(data=3, dims=()), + } + ) + assert "x" in dt3["/a"].data_vars + assert "y" in dt3["/a"].data_vars + assert "z" in dt3["/b"].data_vars + def test_update_new_named_dataarray(self) -> None: da = xr.DataArray(name="temp", data=[0, 
50]) folder1 = DataTree(name="folder1") @@ -722,6 +761,27 @@ def test_modify(self) -> None: dt2 = DataTree(dataset=dt.coords) assert_identical(dt2.coords, dt.coords) + def test_setitem_coord_on_child_node(self) -> None: + """Test that coordinates can be assigned to child nodes using path-like syntax.""" + tree = DataTree(Dataset(coords={"x": 0}), children={"child": DataTree()}) + + # Assign coordinate to child node using path + tree.coords["/child/y"] = 2 + + # Verify coordinate is on child node, not root + assert "y" in tree["/child"].coords + assert "y" not in tree.coords + assert_array_equal(tree["/child"].coords["y"], 2) + + # Verify root still has its coordinate + assert "x" in tree.coords + assert_array_equal(tree.coords["x"], 0) + + # Test relative path + tree.coords["child/z"] = 3 + assert "z" in tree["/child"].coords + assert_array_equal(tree["/child"].coords["z"], 3) + def test_inherited(self) -> None: ds = Dataset( data_vars={ From f9c7538c9c30ca4a221ab1568f6457e73bce1be7 Mon Sep 17 00:00:00 2001 From: Ewan Short Date: Thu, 11 Dec 2025 11:29:21 +1100 Subject: [PATCH 2/3] minor comment format fix --- xarray/core/coordinates.py | 2 +- xarray/core/treenode.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cb80362bae4..93cdf947150 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1053,7 +1053,7 @@ def __setitem__(self, key: Hashable, value: Any) -> None: # target_node = self._data._get_target_node(node_path) # target_node.coords[coord_name] = value - # # # Check if key contains a forward slash (path-like access) + # Check if key contains a forward slash (path-like access) if isinstance(key, str) and "/" in key: # Parse key as NodePath to enforce correct structure node_path, coord_name = NodePath(key)._get_components() diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 464df975a6d..bea2e30be90 100644 --- a/xarray/core/treenode.py 
+++ b/xarray/core/treenode.py @@ -210,7 +210,7 @@ def handle_child(name, child): for name, child in children.items(): handle_child(name, child) self._post_attach_children(children) - # Check that all the children were created in the right place + # Check the children were created in the right place (probably redundant) for path in children.keys(): self._get_item(path) except Exception: From c99c3033f26bc34d85e3f1941f14643f986ecadf Mon Sep 17 00:00:00 2001 From: Ewan Short Date: Thu, 11 Dec 2025 12:52:07 +1100 Subject: [PATCH 3/3] remove redundant test and fix mypy errors --- xarray/core/coordinates.py | 4 ++-- xarray/core/datatree.py | 44 +++++++++++++++++++---------------- xarray/core/treenode.py | 2 +- xarray/tests/test_datatree.py | 36 ++++++++++------------------ 4 files changed, 39 insertions(+), 47 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 93cdf947150..bbb74770fcb 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1050,14 +1050,14 @@ def __setitem__(self, key: Hashable, value: Any) -> None: from xarray.core.treenode import NodePath # node_path, coord_name = NodePath(key)._get_components() - # target_node = self._data._get_target_node(node_path) + # target_node = self._data._get_target_object(node_path) # target_node.coords[coord_name] = value # Check if key contains a forward slash (path-like access) if isinstance(key, str) and "/" in key: # Parse key as NodePath to enforce correct structure node_path, coord_name = NodePath(key)._get_components() - target_node = self._data._get_target_node(node_path) + target_node = self._data._get_target_object(node_path) target_node.coords[coord_name] = value else: super().__setitem__(key, value) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 277d7c6f018..a1bb14cd9a4 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1040,24 +1040,15 @@ def __delitem__(self, key: str) -> None: else: raise KeyError(key) - 
@overload - def update(self, other: Dataset) -> None: ... - - @overload - def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ... - - @overload - def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ... - - def _get_target_node(self, node_path: NodePath) -> DataTree: - """Helper function to get node or create if missing.""" + def _get_target_object(self, node_path: NodePath | str) -> DataTree | DataArray: + """Helper function to get object, or create empty node if missing.""" try: - target_node = self._get_item(node_path) + target_object = self._get_item(node_path) except KeyError: - # create new nodes along the path + # create new empty node, and all nodes along the path self._set_item(node_path, DataTree(), new_nodes_along_path=True) - target_node = self._get_item(node_path) - return target_node + target_object = self._get_item(node_path) + return target_object def _update_local_node( self, @@ -1101,11 +1092,24 @@ def _update_local_node( self._replace_node(data, children=merged_children) + @overload + def update(self, other: Dataset) -> None: ... + + @overload + def update( + self, other: Mapping[Hashable, Dataset | DataArray | Variable] + ) -> None: ... + + @overload + def update( + self, other: Mapping[str, DataTree | Dataset | DataArray | Variable] + ) -> None: ... 
+ def update( self, other: ( Dataset - | Mapping[Hashable, DataArray | Dataset | Variable] + | Mapping[Hashable, Dataset | DataArray | Variable] | Mapping[str, DataTree | Dataset | DataArray | Variable] ), ) -> None: @@ -1121,21 +1125,21 @@ def update( return # Otherwise divide other into groups with unique target nodes - items_by_node = {} + items_by_node: dict[NodePath, dict[str, DataTree | DataArray | Variable]] = {} for k, v in other.items(): node_path, object_name = NodePath(k)._get_components() if isinstance(v, Dataset): # If v is a dataset, node_path/object_name should be a node - target_node = self._get_target_node(f"{node_path}/{object_name}") + target_node = self._get_target_object(f"{node_path}/{object_name}") # Update the node immediately target_node._update_local_node(v) else: - # Otherwise add to target nodes items to update later as a group + # Otherwise add to target node's items to update later as a group items_by_node.setdefault(node_path, {}).update({object_name: v}) # Update each target node, creating if necessary for node_path, node_other in items_by_node.items(): - target_node = self._get_target_node(node_path) + target_node = self._get_target_object(node_path) target_node._update_local_node(node_other) def assign( diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index bea2e30be90..f5f6706bf3f 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -42,7 +42,7 @@ def absolute(self) -> Self: """Convert into an absolute path.""" return type(self)("/", *self.parts) - def _get_components(self) -> None: + def _get_components(self) -> tuple[NodePath, str]: """ Check the NodePath has a non-empty name, which is required for object assignment, e.g. tree['/path/to/node/object_name'] = value. 
diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index b9477643fee..3543483fb64 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -9,7 +9,7 @@ import pytest import xarray as xr -from xarray import DataArray, Dataset +from xarray import DataArray, Dataset, Variable from xarray.core.coordinates import DataTreeCoordinates from xarray.core.datatree import DataTree from xarray.core.treenode import NotFoundInTreeError @@ -345,37 +345,25 @@ def test_update(self) -> None: def test_update_with_paths(self) -> None: """Test that update() handles paths consistently with __setitem__.""" - ds = Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [4, 5, 6])}) - dt1 = DataTree() - # Add a DataArray x to node child - new_content = {"child/x": xr.DataArray(data=[10, 20], dims=("z",))} - # Add a variable y to node child - new_content.update({"child/y": xr.Variable(data=2, dims=())}) - # Create a new node z with data ds and parent step_child - new_content.update({"step_child/z": ds}) + dt1 = xr.DataTree(xr.Dataset(coords={"x": [0, 1, 2]})) + new_content: dict[str, DataTree | Dataset | DataArray | Variable] = {} + # Add a DataArray q to child node p + new_content = {"p/q": xr.DataArray(data=[10, 20, 30], dims=("x",))} + # Add a Variable r to child node p + new_content.update({"p/r": xr.Variable(data=2, dims=())}) + # Create new nodes s, t, and assign a Dataset to t + ds = xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [4, 5, 6])}) + new_content.update({"s/t": ds}) dt1.update(new_content) - dt2 = DataTree() + dt2 = DataTree(xr.Dataset(coords={"x": [0, 1, 2]})) for path, obj in new_content.items(): - # Add the new content with __setitem__ + # Add the new content with __setitem__ one item at a time dt2[path] = obj # Both should produce the same result assert_equal(dt1, dt2) - # Test with multiple path assignments - dt3 = DataTree() - dt3.update( - { - "a/x": xr.Variable(data=1, dims=()), - "a/y": xr.Variable(data=2, dims=()), - "b/z": 
xr.Variable(data=3, dims=()), - } - ) - assert "x" in dt3["/a"].data_vars - assert "y" in dt3["/a"].data_vars - assert "z" in dt3["/b"].data_vars - def test_update_new_named_dataarray(self) -> None: da = xr.DataArray(name="temp", data=[0, 50]) folder1 = DataTree(name="folder1")