Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ Bug Fixes
- Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is
only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`).
By `Julia Signell <https://github.com/jsignell>`_.
- Fix ``DataTree`` bugs related to assigning nodes, variables and coordinates with
path like names (:issue:`9485`, :issue:`9490`, :issue:`9978`).
By `Ewan Short <https://github.com/eshort0401>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
17 changes: 17 additions & 0 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,23 @@ def _drop_coords(self, coord_names):
del self._data._node_coord_variables[name]
del self._data._node_indexes[name]

def __setitem__(self, key: Hashable, value: Any) -> None:
"""Set a coordinate, optionally at a path to a child node."""
from xarray.core.treenode import NodePath

# node_path, coord_name = NodePath(key)._get_components()
# target_node = self._data._get_target_object(node_path)
# target_node.coords[coord_name] = value

# Check if key contains a forward slash (path-like access)
if isinstance(key, str) and "/" in key:
# Parse key as NodePath to enforce correct structure
node_path, coord_name = NodePath(key)._get_components()
target_node = self._data._get_target_object(node_path)
target_node.coords[coord_name] = value
else:
super().__setitem__(key, value)

def __delitem__(self, key: Hashable) -> None:
if key in self:
del self._data[key] # type: ignore[arg-type] # see https://github.com/pydata/xarray/issues/8836
Expand Down
76 changes: 63 additions & 13 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,16 +1040,17 @@ def __delitem__(self, key: str) -> None:
else:
raise KeyError(key)

@overload
def update(self, other: Dataset) -> None: ...

@overload
def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ...

@overload
def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ...

def update(
def _get_target_object(self, node_path: NodePath | str) -> DataTree | DataArray:
"""Helper function to get object, or create empty node if missing."""
try:
target_object = self._get_item(node_path)
except KeyError:
# create new empty node, and all nodes along the path
self._set_item(node_path, DataTree(), new_nodes_along_path=True)
target_object = self._get_item(node_path)
return target_object

def _update_local_node(
self,
other: (
Dataset
Expand All @@ -1058,9 +1059,8 @@ def update(
),
) -> None:
"""
Update this node's children and / or variables.

Just like `dict.update` this is an in-place operation.
Helper to update the node's children and / or variables assuming there are no
path-like keys.
"""
new_children: dict[str, DataTree] = {}
new_variables: CoercibleMapping
Expand Down Expand Up @@ -1092,6 +1092,56 @@ def update(

self._replace_node(data, children=merged_children)

@overload
def update(self, other: Dataset) -> None: ...

@overload
def update(
self, other: Mapping[Hashable, Dataset | DataArray | Variable]
) -> None: ...

@overload
def update(
self, other: Mapping[str, DataTree | Dataset | DataArray | Variable]
) -> None: ...

def update(
self,
other: (
Dataset
| Mapping[Hashable, Dataset | DataArray | Variable]
| Mapping[str, DataTree | Dataset | DataArray | Variable]
),
) -> None:
"""
Update this node's children and / or variables allowing for path-like keys.

Just like `dict.update` this is an in-place operation.
"""

# If other is a dataset assume we want to update just this node
if isinstance(other, Dataset):
self._update_local_node(other)
return

# Otherwise divide other into groups with unique target nodes
items_by_node: dict[NodePath, dict[str, DataTree | DataArray | Variable]] = {}
for k, v in other.items():
node_path, object_name = NodePath(k)._get_components()
if isinstance(v, Dataset):
# If v is a dataset, node_path/object_name should be a node
target_node = self._get_target_object(f"{node_path}/{object_name}")
# Update the node immediately
target_node._update_local_node(v)
else:
# Otherwise add to target node's items to update later as a group
items_by_node.setdefault(node_path, {}).update({object_name: v})

# Update each target node, creating if necessary
for node_path, node_other in items_by_node.items():
target_node = self._get_target_object(node_path)
target_node._update_local_node(node_other)

def assign(
self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any
) -> DataTree:
Expand Down
41 changes: 38 additions & 3 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ def absolute(self) -> Self:
"""Convert into an absolute path."""
return type(self)("/", *self.parts)

def _get_components(self) -> tuple[NodePath, str]:
"""
Check the NodePath has a non-empty name, which is required for object
assignment, e.g. tree['/path/to/node/object_name'] = value.
"""
if not self.name:
raise ValueError(
f"Invalid assignment path {self!r}. Assignment paths should have a "
"structure like 'path/to/node/object_name'."
)
return self.parent, self.name


class TreeNode:
"""
Expand Down Expand Up @@ -172,12 +184,35 @@ def children(self, children: Mapping[str, Self]) -> None:

old_children = self.children
del self.children

def handle_child(name, child):
"""
Decide whether to call _set_item or _set_parent. If the child name is a
path, we call ._set_item to make sure any intermediate nodes are also
created. When name is a path there will be nested calls of children due to
the setter decorator, and the fact ._set_item assigns to the children
attribute.
"""
if "/" in name:
# Path like node name. Create node at appropriate level of nesting.
Comment on lines +188 to +197
Copy link
Member

@TomNicholas TomNicholas Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried that this design change will silently break things somehow. It effectively allows violating an intended design invariant (that node names do not have slashes) then attempts to ensure that it cleans up after itself. I would much prefer to solve this in some way that never violates the invariant, even internally.

self._set_item(
path=name,
item=child,
new_nodes_along_path=True,
allow_overwrite=True,
)
else:
# Simple name, just call _set_parent directly.
child._set_parent(new_parent=self, child_name=name)

try:
self._pre_attach_children(children)
for name, child in children.items():
child._set_parent(new_parent=self, child_name=name)
handle_child(name, child)
self._post_attach_children(children)
assert len(self.children) == len(children)
# Check the children were created in the right place (probably redundant)
for path in children.keys():
self._get_item(path)
except Exception:
# if something goes wrong then revert to previous children
self.children = old_children
Expand Down Expand Up @@ -206,7 +241,7 @@ def _check_children(children: Mapping[str, TreeNode]) -> None:
if not isinstance(child, TreeNode):
raise TypeError(
f"Cannot add object {name}. It is of type {type(child)}, "
"but can only add children of type DataTree"
"but can only add children of type TreeNode"
)

childid = id(child)
Expand Down
50 changes: 49 additions & 1 deletion xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

import xarray as xr
from xarray import DataArray, Dataset
from xarray import DataArray, Dataset, Variable
from xarray.core.coordinates import DataTreeCoordinates
from xarray.core.datatree import DataTree
from xarray.core.treenode import NotFoundInTreeError
Expand Down Expand Up @@ -82,6 +82,12 @@ def __deepcopy__(self, memo):
dt3 = xr.DataTree.from_dict({"/": ds, "child": ds})
assert_identical(dt2, dt3)

def test_pathlike_children(self) -> None:
dt = xr.DataTree(children={"a/b/c": xr.DataTree(), "a/d": xr.DataTree()})
assert "c" in dt["a/b"].children
assert dt["a/b/c"].path == "/a/b/c"
assert "d" in dt["a"].children


class TestFamilyTree:
def test_dont_modify_children_inplace(self) -> None:
Expand Down Expand Up @@ -337,6 +343,27 @@ def test_update(self) -> None:
assert_equal(dt, expected)
assert dt.groups == ("/", "/a")

def test_update_with_paths(self) -> None:
"""Test that update() handles paths consistently with __setitem__."""
dt1 = xr.DataTree(xr.Dataset(coords={"x": [0, 1, 2]}))
new_content: dict[str, DataTree | Dataset | DataArray | Variable] = {}
# Add a DataArray q to child node p
new_content = {"p/q": xr.DataArray(data=[10, 20, 30], dims=("x",))}
# Add a Variable r to child node p
new_content.update({"p/r": xr.Variable(data=2, dims=())})
# Create new nodes s, t, and assign a Dataset to t
ds = xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [4, 5, 6])})
new_content.update({"s/t": ds})
dt1.update(new_content)

dt2 = DataTree(xr.Dataset(coords={"x": [0, 1, 2]}))
for path, obj in new_content.items():
# Add the new content with __setitem__ one item at a time
dt2[path] = obj

# Both should produce the same result
assert_equal(dt1, dt2)

def test_update_new_named_dataarray(self) -> None:
da = xr.DataArray(name="temp", data=[0, 50])
folder1 = DataTree(name="folder1")
Expand Down Expand Up @@ -722,6 +749,27 @@ def test_modify(self) -> None:
dt2 = DataTree(dataset=dt.coords)
assert_identical(dt2.coords, dt.coords)

def test_setitem_coord_on_child_node(self) -> None:
"""Test that coordinates can be assigned to child nodes using path-like syntax."""
tree = DataTree(Dataset(coords={"x": 0}), children={"child": DataTree()})

# Assign coordinate to child node using path
tree.coords["/child/y"] = 2

# Verify coordinate is on child node, not root
assert "y" in tree["/child"].coords
assert "y" not in tree.coords
assert_array_equal(tree["/child"].coords["y"], 2)

# Verify root still has its coordinate
assert "x" in tree.coords
assert_array_equal(tree.coords["x"], 0)

# Test relative path
tree.coords["child/z"] = 3
assert "z" in tree["/child"].coords
assert_array_equal(tree["/child"].coords["z"], 3)

def test_inherited(self) -> None:
ds = Dataset(
data_vars={
Expand Down
Loading