Skip to content

Commit f134792

Browse files
authored
Implement expand_dims for XTensorVariables (#1449)
1 parent 7b8877b commit f134792

File tree

4 files changed

+307
-3
lines changed

4 files changed

+307
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pytensor.graph import node_rewriter
22
from pytensor.tensor import (
33
broadcast_to,
4+
expand_dims,
45
join,
56
moveaxis,
67
specify_shape,
@@ -10,6 +11,7 @@
1011
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
1112
from pytensor.xtensor.shape import (
1213
Concat,
14+
ExpandDims,
1315
Squeeze,
1416
Stack,
1517
Transpose,
@@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):
121123

122124
@register_lower_xtensor
123125
@node_rewriter([Squeeze])
124-
def local_squeeze_reshape(fgraph, node):
126+
def lower_squeeze(fgraph, node):
125127
"""Rewrite Squeeze to tensor.squeeze."""
126128
[x] = node.inputs
127129
x_tensor = tensor_from_xtensor(x)
@@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):
132134

133135
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134136
return [new_out]
137+
138+
139+
@register_lower_xtensor
140+
@node_rewriter([ExpandDims])
141+
def lower_expand_dims(fgraph, node):
142+
"""Rewrite ExpandDims using tensor operations."""
143+
x, size = node.inputs
144+
out = node.outputs[0]
145+
146+
# Convert inputs to tensors
147+
x_tensor = tensor_from_xtensor(x)
148+
size_tensor = tensor_from_xtensor(size)
149+
150+
# Get the new dimension name and position
151+
new_axis = 0 # Always insert at front
152+
153+
# Use tensor operations
154+
if out.type.shape[0] == 1:
155+
# Simple case: just expand with size 1
156+
result_tensor = expand_dims(x_tensor, new_axis)
157+
else:
158+
# Otherwise broadcast to the requested size
159+
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))
160+
161+
# Preserve static shape information
162+
result_tensor = specify_shape(result_tensor, out.type.shape)
163+
164+
# Convert result back to xtensor
165+
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
166+
return [result]

pytensor/xtensor/shape.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import warnings
2-
from collections.abc import Sequence
2+
from collections.abc import Hashable, Sequence
33
from types import EllipsisType
44
from typing import Literal
55

6+
import numpy as np
7+
68
from pytensor.graph import Apply
79
from pytensor.scalar import discrete_dtypes, upcast
810
from pytensor.tensor import as_tensor, get_scalar_constant_value
911
from pytensor.tensor.exceptions import NotScalarConstantError
12+
from pytensor.tensor.type import integer_dtypes
1013
from pytensor.xtensor.basic import XOp
1114
from pytensor.xtensor.type import as_xtensor, xtensor
1215

@@ -380,3 +383,121 @@ def squeeze(x, dim=None):
380383
return x # no-op if nothing to squeeze
381384

382385
return Squeeze(dims=dims)(x)
386+
387+
388+
class ExpandDims(XOp):
389+
"""Add a new dimension to an XTensorVariable."""
390+
391+
__props__ = ("dim",)
392+
393+
def __init__(self, dim):
394+
if not isinstance(dim, str):
395+
raise TypeError(f"`dim` must be a string, got: {type(self.dim)}")
396+
397+
self.dim = dim
398+
399+
def make_node(self, x, size):
400+
x = as_xtensor(x)
401+
402+
if self.dim in x.type.dims:
403+
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")
404+
405+
size = as_xtensor(size, dims=())
406+
if not (size.dtype in integer_dtypes and size.ndim == 0):
407+
raise ValueError(f"size should be an integer scalar, got {size.type}")
408+
try:
409+
static_size = int(get_scalar_constant_value(size))
410+
except NotScalarConstantError:
411+
static_size = None
412+
413+
# If size is a constant, validate it
414+
if static_size is not None and static_size < 0:
415+
raise ValueError(f"size must be 0 or positive, got: {static_size}")
416+
new_shape = (static_size, *x.type.shape)
417+
418+
# Insert new dim at front
419+
new_dims = (self.dim, *x.type.dims)
420+
421+
out = xtensor(
422+
dtype=x.type.dtype,
423+
shape=new_shape,
424+
dims=new_dims,
425+
)
426+
return Apply(self, [x, size], [out])
427+
428+
429+
def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
430+
"""Add one or more new dimensions to an XTensorVariable."""
431+
x = as_xtensor(x)
432+
433+
# Store original dimensions for axis handling
434+
original_dims = x.type.dims
435+
436+
# Warn if create_index_for_new_dim is used (not supported)
437+
if create_index_for_new_dim is not None:
438+
warnings.warn(
439+
"create_index_for_new_dim=False has no effect in pytensor.xtensor",
440+
UserWarning,
441+
stacklevel=2,
442+
)
443+
444+
if dim is None:
445+
dim = dim_kwargs
446+
elif dim_kwargs:
447+
raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")
448+
449+
# Check that dim is Hashable or a sequence of Hashable or dict
450+
if not isinstance(dim, Hashable):
451+
if not isinstance(dim, Sequence | dict):
452+
raise TypeError(f"unhashable type: {type(dim).__name__}")
453+
if not all(isinstance(d, Hashable) for d in dim):
454+
raise TypeError(f"unhashable type in {type(dim).__name__}")
455+
456+
# Normalize to a dimension-size mapping
457+
if isinstance(dim, str):
458+
dims_dict = {dim: 1}
459+
elif isinstance(dim, Sequence) and not isinstance(dim, dict):
460+
dims_dict = {d: 1 for d in dim}
461+
elif isinstance(dim, dict):
462+
dims_dict = {}
463+
for name, val in dim.items():
464+
if isinstance(val, str):
465+
raise TypeError(f"Dimension size cannot be a string: {val}")
466+
if isinstance(val, Sequence | np.ndarray):
467+
warnings.warn(
468+
"When a sequence is provided as a dimension size, only its length is used. "
469+
"The actual values (which would be coordinates in xarray) are ignored.",
470+
UserWarning,
471+
stacklevel=2,
472+
)
473+
dims_dict[name] = len(val)
474+
else:
475+
# should be int or symbolic scalar
476+
dims_dict[name] = val
477+
else:
478+
raise TypeError(f"Invalid type for `dim`: {type(dim)}")
479+
480+
# Insert each new dim at the front (reverse order preserves user intent)
481+
for name, size in reversed(dims_dict.items()):
482+
x = ExpandDims(dim=name)(x, size)
483+
484+
# If axis is specified, transpose to put new dimensions in the right place
485+
if axis is not None:
486+
# Wrap non-sequence axis in a list
487+
if not isinstance(axis, Sequence):
488+
axis = [axis]
489+
490+
# require len(axis) == len(dims_dict)
491+
if len(axis) != len(dims_dict):
492+
raise ValueError("lengths of dim and axis should be identical.")
493+
494+
# Insert new dimensions at their specified positions
495+
target_dims = list(original_dims)
496+
for name, pos in zip(dims_dict, axis):
497+
# Convert negative axis to positive position relative to current dims
498+
if pos < 0:
499+
pos = len(target_dims) + pos + 1
500+
target_dims.insert(pos, name)
501+
x = Transpose(dims=tuple(target_dims))(x)
502+
503+
return x

pytensor/xtensor/type.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,47 @@ def squeeze(
481481
raise NotImplementedError("Squeeze with axis not Implemented")
482482
return px.shape.squeeze(self, dim)
483483

484+
def expand_dims(
485+
self,
486+
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
487+
create_index_for_new_dim: bool = True,
488+
axis: int | Sequence[int] | None = None,
489+
**dim_kwargs,
490+
):
491+
"""Add one or more new dimensions to the tensor.
492+
493+
Parameters
494+
----------
495+
dim : str | Sequence[str] | dict[str, int | Sequence] | None
496+
If str or sequence of str, new dimensions with size 1.
497+
If dict, keys are dimension names and values are either:
498+
- int: the new size
499+
- sequence: coordinates (length determines size)
500+
create_index_for_new_dim : bool, default: True
501+
Currently ignored. Reserved for future coordinate support.
502+
In xarray, when True (default), creates a coordinate index for the new dimension
503+
with values from 0 to size-1. When False, no coordinate index is created.
504+
axis : int | Sequence[int] | None, default: None
505+
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
506+
By default (None), new dimensions are inserted at the beginning (axis=0).
507+
Symbolic axis is not supported yet.
508+
Negative values count from the end.
509+
**dim_kwargs : int | Sequence
510+
Alternative to `dim` dict. Only used if `dim` is None.
511+
512+
Returns
513+
-------
514+
XTensorVariable
515+
A tensor with additional dimensions inserted at the front.
516+
"""
517+
return px.shape.expand_dims(
518+
self,
519+
dim,
520+
create_index_for_new_dim=create_index_for_new_dim,
521+
axis=axis,
522+
**dim_kwargs,
523+
)
524+
484525
# ndarray methods
485526
# https://docs.xarray.dev/en/latest/api.html#id7
486527
def clip(self, min, max):

tests/xtensor/test_shape.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11-
import pytest
1211
from xarray import DataArray
1312
from xarray import concat as xr_concat
1413

14+
from pytensor.tensor import scalar
1515
from pytensor.xtensor.shape import (
1616
concat,
1717
squeeze,
@@ -369,3 +369,113 @@ def test_squeeze_errors():
369369
fn2 = xr_function([x2], y2)
370370
with pytest.raises(Exception):
371371
fn2(x2_test)
372+
373+
374+
def test_expand_dims():
375+
"""Test expand_dims."""
376+
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
377+
x_test = xr_arange_like(x)
378+
379+
# Implicit size 1
380+
y = x.expand_dims("country")
381+
fn = xr_function([x], y)
382+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))
383+
384+
# Test with multiple dimensions
385+
y = x.expand_dims(["country", "state"])
386+
fn = xr_function([x], y)
387+
xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))
388+
389+
# Test with a dict of name-size pairs
390+
y = x.expand_dims({"country": 2, "state": 3})
391+
fn = xr_function([x], y)
392+
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))
393+
394+
# Test with kwargs (equivalent to dict)
395+
y = x.expand_dims(country=2, state=3)
396+
fn = xr_function([x], y)
397+
xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))
398+
399+
# Test with a dict of name-coord array pairs
400+
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
401+
fn = xr_function([x], y)
402+
xr_assert_allclose(
403+
fn(x_test),
404+
x_test.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}),
405+
)
406+
407+
# Symbolic size 1
408+
size_sym_1 = scalar("size_sym_1", dtype="int64")
409+
y = x.expand_dims({"country": size_sym_1})
410+
fn = xr_function([x, size_sym_1], y)
411+
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1}))
412+
413+
# Test with symbolic sizes in dict
414+
size_sym_2 = scalar("size_sym_2", dtype="int64")
415+
y = x.expand_dims({"country": size_sym_1, "state": size_sym_2})
416+
fn = xr_function([x, size_sym_1, size_sym_2], y)
417+
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
418+
419+
# Test with symbolic sizes in kwargs
420+
y = x.expand_dims(country=size_sym_1, state=size_sym_2)
421+
fn = xr_function([x, size_sym_1, size_sym_2], y)
422+
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
423+
424+
# Test with axis parameter
425+
y = x.expand_dims("country", axis=1)
426+
fn = xr_function([x], y)
427+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1))
428+
429+
# Test with negative axis parameter
430+
y = x.expand_dims("country", axis=-1)
431+
fn = xr_function([x], y)
432+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1))
433+
434+
# Add two new dims with axis parameters
435+
y = x.expand_dims(["country", "state"], axis=[1, 2])
436+
fn = xr_function([x], y)
437+
xr_assert_allclose(
438+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2])
439+
)
440+
441+
# Add two dims with negative axis parameters
442+
y = x.expand_dims(["country", "state"], axis=[-1, -2])
443+
fn = xr_function([x], y)
444+
xr_assert_allclose(
445+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[-1, -2])
446+
)
447+
448+
# Add two dims with positive and negative axis parameters
449+
y = x.expand_dims(["country", "state"], axis=[-2, 1])
450+
fn = xr_function([x], y)
451+
xr_assert_allclose(
452+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[-2, 1])
453+
)
454+
455+
456+
def test_expand_dims_errors():
457+
"""Test error handling in expand_dims."""
458+
459+
# Expanding existing dim
460+
x = xtensor("x", dims=("city",), shape=(3,))
461+
y = x.expand_dims("country")
462+
with pytest.raises(ValueError, match="already exists"):
463+
y.expand_dims("city")
464+
465+
# Invalid dim type
466+
with pytest.raises(TypeError, match="Invalid type for `dim`"):
467+
x.expand_dims(123)
468+
469+
# Duplicate dimension creation
470+
y = x.expand_dims("new")
471+
with pytest.raises(ValueError, match="already exists"):
472+
y.expand_dims("new")
473+
474+
# Find out what xarray does with a numpy array as dim
475+
# x_test = xr_arange_like(x)
476+
# x_test.expand_dims(np.array([1, 2]))
477+
# TypeError: unhashable type: 'numpy.ndarray'
478+
479+
# Test with a numpy array as dim (not supported)
480+
with pytest.raises(TypeError, match="unhashable type"):
481+
y.expand_dims(np.array([1, 2]))

0 commit comments

Comments
 (0)