Skip to content

Commit 67da7eb

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement expand_dims for XTensorVariables (#1449)
1 parent 5f4b4a9 commit 67da7eb

File tree

4 files changed

+307
-3
lines changed

4 files changed

+307
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pytensor.graph import node_rewriter
22
from pytensor.tensor import (
33
broadcast_to,
4+
expand_dims,
45
join,
56
moveaxis,
67
specify_shape,
@@ -10,6 +11,7 @@
1011
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
1112
from pytensor.xtensor.shape import (
1213
Concat,
14+
ExpandDims,
1315
Squeeze,
1416
Stack,
1517
Transpose,
@@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):
121123

122124
@register_lower_xtensor
123125
@node_rewriter([Squeeze])
124-
def local_squeeze_reshape(fgraph, node):
126+
def lower_squeeze(fgraph, node):
125127
"""Rewrite Squeeze to tensor.squeeze."""
126128
[x] = node.inputs
127129
x_tensor = tensor_from_xtensor(x)
@@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):
132134

133135
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134136
return [new_out]
137+
138+
139+
@register_lower_xtensor
140+
@node_rewriter([ExpandDims])
141+
def lower_expand_dims(fgraph, node):
142+
"""Rewrite ExpandDims using tensor operations."""
143+
x, size = node.inputs
144+
out = node.outputs[0]
145+
146+
# Convert inputs to tensors
147+
x_tensor = tensor_from_xtensor(x)
148+
size_tensor = tensor_from_xtensor(size)
149+
150+
# Get the new dimension name and position
151+
new_axis = 0 # Always insert at front
152+
153+
# Use tensor operations
154+
if out.type.shape[0] == 1:
155+
# Simple case: just expand with size 1
156+
result_tensor = expand_dims(x_tensor, new_axis)
157+
else:
158+
# Otherwise broadcast to the requested size
159+
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))
160+
161+
# Preserve static shape information
162+
result_tensor = specify_shape(result_tensor, out.type.shape)
163+
164+
# Convert result back to xtensor
165+
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
166+
return [result]

pytensor/xtensor/shape.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import typing
22
import warnings
3-
from collections.abc import Sequence
3+
from collections.abc import Hashable, Sequence
44
from types import EllipsisType
55
from typing import Literal
66

7+
import numpy as np
8+
79
from pytensor.graph import Apply
810
from pytensor.scalar import discrete_dtypes, upcast
911
from pytensor.tensor import as_tensor, get_scalar_constant_value
1012
from pytensor.tensor.exceptions import NotScalarConstantError
13+
from pytensor.tensor.type import integer_dtypes
1114
from pytensor.xtensor.basic import XOp
1215
from pytensor.xtensor.type import as_xtensor, xtensor
1316

@@ -381,3 +384,121 @@ def squeeze(x, dim=None, drop=False, axis=None):
381384
return x # no-op if nothing to squeeze
382385

383386
return Squeeze(dims=dims)(x)
387+
388+
389+
class ExpandDims(XOp):
390+
"""Add a new dimension to an XTensorVariable."""
391+
392+
__props__ = ("dim",)
393+
394+
def __init__(self, dim):
395+
if not isinstance(dim, str):
396+
raise TypeError(f"`dim` must be a string, got: {type(self.dim)}")
397+
398+
self.dim = dim
399+
400+
def make_node(self, x, size):
401+
x = as_xtensor(x)
402+
403+
if self.dim in x.type.dims:
404+
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")
405+
406+
size = as_xtensor(size, dims=())
407+
if not (size.dtype in integer_dtypes and size.ndim == 0):
408+
raise ValueError(f"size should be an integer scalar, got {size.type}")
409+
try:
410+
static_size = int(get_scalar_constant_value(size))
411+
except NotScalarConstantError:
412+
static_size = None
413+
414+
# If size is a constant, validate it
415+
if static_size is not None and static_size < 0:
416+
raise ValueError(f"size must be 0 or positive, got: {static_size}")
417+
new_shape = (static_size, *x.type.shape)
418+
419+
# Insert new dim at front
420+
new_dims = (self.dim, *x.type.dims)
421+
422+
out = xtensor(
423+
dtype=x.type.dtype,
424+
shape=new_shape,
425+
dims=new_dims,
426+
)
427+
return Apply(self, [x, size], [out])
428+
429+
430+
def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
431+
"""Add one or more new dimensions to an XTensorVariable."""
432+
x = as_xtensor(x)
433+
434+
# Store original dimensions for axis handling
435+
original_dims = x.type.dims
436+
437+
# Warn if create_index_for_new_dim is used (not supported)
438+
if create_index_for_new_dim is not None:
439+
warnings.warn(
440+
"create_index_for_new_dim=False has no effect in pytensor.xtensor",
441+
UserWarning,
442+
stacklevel=2,
443+
)
444+
445+
if dim is None:
446+
dim = dim_kwargs
447+
elif dim_kwargs:
448+
raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")
449+
450+
# Check that dim is Hashable or a sequence of Hashable or dict
451+
if not isinstance(dim, Hashable):
452+
if not isinstance(dim, Sequence | dict):
453+
raise TypeError(f"unhashable type: {type(dim).__name__}")
454+
if not all(isinstance(d, Hashable) for d in dim):
455+
raise TypeError(f"unhashable type in {type(dim).__name__}")
456+
457+
# Normalize to a dimension-size mapping
458+
if isinstance(dim, str):
459+
dims_dict = {dim: 1}
460+
elif isinstance(dim, Sequence) and not isinstance(dim, dict):
461+
dims_dict = {d: 1 for d in dim}
462+
elif isinstance(dim, dict):
463+
dims_dict = {}
464+
for name, val in dim.items():
465+
if isinstance(val, str):
466+
raise TypeError(f"Dimension size cannot be a string: {val}")
467+
if isinstance(val, Sequence | np.ndarray):
468+
warnings.warn(
469+
"When a sequence is provided as a dimension size, only its length is used. "
470+
"The actual values (which would be coordinates in xarray) are ignored.",
471+
UserWarning,
472+
stacklevel=2,
473+
)
474+
dims_dict[name] = len(val)
475+
else:
476+
# should be int or symbolic scalar
477+
dims_dict[name] = val
478+
else:
479+
raise TypeError(f"Invalid type for `dim`: {type(dim)}")
480+
481+
# Insert each new dim at the front (reverse order preserves user intent)
482+
for name, size in reversed(dims_dict.items()):
483+
x = ExpandDims(dim=name)(x, size)
484+
485+
# If axis is specified, transpose to put new dimensions in the right place
486+
if axis is not None:
487+
# Wrap non-sequence axis in a list
488+
if not isinstance(axis, Sequence):
489+
axis = [axis]
490+
491+
# require len(axis) == len(dims_dict)
492+
if len(axis) != len(dims_dict):
493+
raise ValueError("lengths of dim and axis should be identical.")
494+
495+
# Insert new dimensions at their specified positions
496+
target_dims = list(original_dims)
497+
for name, pos in zip(dims_dict, axis):
498+
# Convert negative axis to positive position relative to current dims
499+
if pos < 0:
500+
pos = len(target_dims) + pos + 1
501+
target_dims.insert(pos, name)
502+
x = Transpose(dims=tuple(target_dims))(x)
503+
504+
return x

pytensor/xtensor/type.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,47 @@ def squeeze(
573573
"""
574574
return px.shape.squeeze(self, dim, drop, axis)
575575

576+
def expand_dims(
577+
self,
578+
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
579+
create_index_for_new_dim: bool = True,
580+
axis: int | Sequence[int] | None = None,
581+
**dim_kwargs,
582+
):
583+
"""Add one or more new dimensions to the tensor.
584+
585+
Parameters
586+
----------
587+
dim : str | Sequence[str] | dict[str, int | Sequence] | None
588+
If str or sequence of str, new dimensions with size 1.
589+
If dict, keys are dimension names and values are either:
590+
- int: the new size
591+
- sequence: coordinates (length determines size)
592+
create_index_for_new_dim : bool, default: True
593+
Currently ignored. Reserved for future coordinate support.
594+
In xarray, when True (default), creates a coordinate index for the new dimension
595+
with values from 0 to size-1. When False, no coordinate index is created.
596+
axis : int | Sequence[int] | None, default: None
597+
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
598+
By default (None), new dimensions are inserted at the beginning (axis=0).
599+
Symbolic axis is not supported yet.
600+
Negative values count from the end.
601+
**dim_kwargs : int | Sequence
602+
Alternative to `dim` dict. Only used if `dim` is None.
603+
604+
Returns
605+
-------
606+
XTensorVariable
607+
A tensor with additional dimensions inserted at the front.
608+
"""
609+
return px.shape.expand_dims(
610+
self,
611+
dim,
612+
create_index_for_new_dim=create_index_for_new_dim,
613+
axis=axis,
614+
**dim_kwargs,
615+
)
616+
576617
# ndarray methods
577618
# https://docs.xarray.dev/en/latest/api.html#id7
578619
def clip(self, min, max):

tests/xtensor/test_shape.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11-
import pytest
1211
from xarray import DataArray
1312
from xarray import concat as xr_concat
1413

14+
from pytensor.tensor import scalar
1515
from pytensor.xtensor.shape import (
1616
concat,
1717
stack,
@@ -356,3 +356,113 @@ def test_squeeze_errors():
356356
fn2 = xr_function([x2], y2)
357357
with pytest.raises(Exception):
358358
fn2(x2_test)
359+
360+
361+
def test_expand_dims():
362+
"""Test expand_dims."""
363+
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
364+
x_test = xr_arange_like(x)
365+
366+
# Implicit size 1
367+
y = x.expand_dims("country")
368+
fn = xr_function([x], y)
369+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))
370+
371+
# Test with multiple dimensions
372+
y = x.expand_dims(["country", "state"])
373+
fn = xr_function([x], y)
374+
xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))
375+
376+
# Test with a dict of name-size pairs
377+
y = x.expand_dims({"country": 2, "state": 3})
378+
fn = xr_function([x], y)
379+
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))
380+
381+
# Test with kwargs (equivalent to dict)
382+
y = x.expand_dims(country=2, state=3)
383+
fn = xr_function([x], y)
384+
xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))
385+
386+
# Test with a dict of name-coord array pairs
387+
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
388+
fn = xr_function([x], y)
389+
xr_assert_allclose(
390+
fn(x_test),
391+
x_test.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}),
392+
)
393+
394+
# Symbolic size 1
395+
size_sym_1 = scalar("size_sym_1", dtype="int64")
396+
y = x.expand_dims({"country": size_sym_1})
397+
fn = xr_function([x, size_sym_1], y)
398+
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1}))
399+
400+
# Test with symbolic sizes in dict
401+
size_sym_2 = scalar("size_sym_2", dtype="int64")
402+
y = x.expand_dims({"country": size_sym_1, "state": size_sym_2})
403+
fn = xr_function([x, size_sym_1, size_sym_2], y)
404+
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
405+
406+
# Test with symbolic sizes in kwargs
407+
y = x.expand_dims(country=size_sym_1, state=size_sym_2)
408+
fn = xr_function([x, size_sym_1, size_sym_2], y)
409+
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
410+
411+
# Test with axis parameter
412+
y = x.expand_dims("country", axis=1)
413+
fn = xr_function([x], y)
414+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1))
415+
416+
# Test with negative axis parameter
417+
y = x.expand_dims("country", axis=-1)
418+
fn = xr_function([x], y)
419+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1))
420+
421+
# Add two new dims with axis parameters
422+
y = x.expand_dims(["country", "state"], axis=[1, 2])
423+
fn = xr_function([x], y)
424+
xr_assert_allclose(
425+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2])
426+
)
427+
428+
# Add two dims with negative axis parameters
429+
y = x.expand_dims(["country", "state"], axis=[-1, -2])
430+
fn = xr_function([x], y)
431+
xr_assert_allclose(
432+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[-1, -2])
433+
)
434+
435+
# Add two dims with positive and negative axis parameters
436+
y = x.expand_dims(["country", "state"], axis=[-2, 1])
437+
fn = xr_function([x], y)
438+
xr_assert_allclose(
439+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[-2, 1])
440+
)
441+
442+
443+
def test_expand_dims_errors():
444+
"""Test error handling in expand_dims."""
445+
446+
# Expanding existing dim
447+
x = xtensor("x", dims=("city",), shape=(3,))
448+
y = x.expand_dims("country")
449+
with pytest.raises(ValueError, match="already exists"):
450+
y.expand_dims("city")
451+
452+
# Invalid dim type
453+
with pytest.raises(TypeError, match="Invalid type for `dim`"):
454+
x.expand_dims(123)
455+
456+
# Duplicate dimension creation
457+
y = x.expand_dims("new")
458+
with pytest.raises(ValueError, match="already exists"):
459+
y.expand_dims("new")
460+
461+
# Find out what xarray does with a numpy array as dim
462+
# x_test = xr_arange_like(x)
463+
# x_test.expand_dims(np.array([1, 2]))
464+
# TypeError: unhashable type: 'numpy.ndarray'
465+
466+
# Test with a numpy array as dim (not supported)
467+
with pytest.raises(TypeError, match="unhashable type"):
468+
y.expand_dims(np.array([1, 2]))

0 commit comments

Comments
 (0)