Skip to content

Adding expand_dims for xtensor #1449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from pytensor.graph import node_rewriter
from pytensor.raise_op import Assert
from pytensor.tensor import (
broadcast_to,
get_scalar_constant_value,
gt,
join,
moveaxis,
specify_shape,
squeeze,
)
from pytensor.tensor import (
shape as tensor_shape,
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
Expand Down Expand Up @@ -132,3 +139,32 @@ def local_squeeze_reshape(fgraph, node):

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
    """Lower ExpandDims to tensor.expand_dims, broadcasting when size != 1.

    The new dimension is always inserted at the front (axis 0).  A runtime
    Assert guards against non-positive sizes; when the size is statically
    known to be 1 the expand_dims alone suffices, otherwise the length-1
    leading axis is broadcast up to ``size``.
    """
    from pytensor.tensor import expand_dims as tensor_expand_dims
    from pytensor.tensor.exceptions import NotScalarConstantError

    x, size = node.inputs
    out = node.outputs[0]

    # ExpandDims always places the new dimension first, so axis=0 here.
    expanded = tensor_expand_dims(tensor_from_xtensor(x), 0)

    # Runtime guard: static sizes were already validated in make_node, but a
    # symbolic size can only be checked when the graph runs.
    expanded = Assert(msg="size must be positive")(expanded, gt(size, 0))

    # Broadcast only when the size is not statically known to be 1.
    try:
        static_size = get_scalar_constant_value(size)
    except NotScalarConstantError:
        static_size = None

    if static_size == 1:
        result = expanded
    else:
        # Broadcast the length-1 leading axis to (size, ...).
        new_shape = (size, *tuple(tensor_shape(expanded))[1:])
        result = broadcast_to(expanded, new_shape)

    return [xtensor_from_tensor(result, out.type.dims)]
115 changes: 115 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from types import EllipsisType
from typing import Literal

import numpy as np

from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor

Expand Down Expand Up @@ -380,3 +383,115 @@ def squeeze(x, dim=None):
return x # no-op if nothing to squeeze

return Squeeze(dims=dims)(x)


class ExpandDims(XOp):
    """Add a new leading dimension to an XTensorVariable.

    Parameters
    ----------
    dim : str
        Name of the new dimension; it is inserted at the front of the
        output's dims.
    """

    __props__ = ("dim",)

    def __init__(self, dim):
        self.dim = dim

    def make_node(self, x, size):
        """Create an Apply inserting a front dimension of length ``size``.

        Parameters
        ----------
        x : XTensorVariable (or coercible)
            Input tensor.
        size : int, np.integer, or 0-d symbolic scalar
            Length of the new dimension.

        Raises
        ------
        TypeError
            If ``dim`` is not a string or ``size`` is not an int/scalar.
        ValueError
            If ``dim`` already exists in ``x`` or a static ``size`` is <= 0.
        """
        x = as_xtensor(x)

        if not isinstance(self.dim, str):
            raise TypeError(f"`dim` must be a string or None, got: {type(self.dim)}")

        if self.dim in x.type.dims:
            raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")

        if isinstance(size, int | np.integer):
            if size <= 0:
                raise ValueError(f"size must be positive, got: {size}")
        elif getattr(size, "ndim", None) != 0:
            # Only 0-d (scalar) symbolic variables are accepted as a size;
            # getattr with a default also rejects objects without `ndim`.
            raise TypeError(
                f"size must be an int or scalar variable, got: {type(size)}"
            )

        # Convert size to a 0-d tensor so it can be an Apply input.
        size = as_tensor(size, ndim=0)

        # The new dimension always goes at the front.
        new_dims = (self.dim, *x.type.dims)

        # Use the static size for the output type when known at graph
        # construction time; otherwise leave it symbolic (None).
        try:
            static_size = int(get_scalar_constant_value(size))
        except NotScalarConstantError:
            static_size = None
        new_shape = (static_size, *x.type.shape)

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x, size], [out])


def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs):
    """Add one or more new dimensions to an XTensorVariable.

    Parameters
    ----------
    x : XTensorVariable
        Input tensor.
    dim : str | Sequence[str] | dict[str, int | Sequence] | None
        If str, a single new dimension (sized by the ``size`` keyword,
        default 1). If a sequence of str, new dimensions with size 1.
        If dict, keys are dimension names and values are either:
        - int (or symbolic scalar): the new size
        - sequence: coordinates (length determines size)
    create_index_for_new_dim : bool, default: True
        (Ignored for now) Matches xarray API, reserved for future use.
    **dim_kwargs : int | Sequence
        Alternative to `dim` dict. Only used if `dim` is None.

    Returns
    -------
    XTensorVariable
        A tensor with additional dimensions inserted at the front.

    Raises
    ------
    ValueError
        If both `dim` and `**dim_kwargs` are given.
    TypeError
        If `dim` has an unsupported type.
    """
    x = as_xtensor(x)

    # A `size` keyword only applies when `dim` is a single string.
    size = dim_kwargs.pop("size", 1) if dim_kwargs else 1

    if dim is None:
        dim = dim_kwargs
    elif dim_kwargs:
        raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")

    # Normalize every accepted `dim` form to a name -> size mapping.
    if isinstance(dim, str):
        dims_dict = {dim: size}
    elif isinstance(dim, Sequence) and not isinstance(dim, dict):
        dims_dict = {d: 1 for d in dim}
    elif isinstance(dim, dict):
        dims_dict = {}
        for name, val in dim.items():
            if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
                # A sequence of coordinates: its length determines the size.
                dims_dict[name] = len(val)
            else:
                # int, np.integer, or symbolic scalar — passed through as-is;
                # ExpandDims.make_node validates it.
                dims_dict[name] = val
    else:
        raise TypeError(f"Invalid type for `dim`: {type(dim)}")

    # Insert each new dim at the front; reversing preserves the user's order
    # in the final dims tuple.
    for name, dim_size in reversed(dims_dict.items()):
        x = ExpandDims(dim=name)(x, dim_size)

    return x
21 changes: 21 additions & 0 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,27 @@ def squeeze(
raise NotImplementedError("Squeeze with axis not Implemented")
return px.shape.squeeze(self, dim)

def expand_dims(
self,
dim: str | None = None,
size: int | Variable = 1,
):
"""Add a new dimension to the tensor.

Parameters
----------
dim : str or None
Name of new dimension. If None, returns self unchanged.
size : int or symbolic, optional
Size of the new dimension (default 1)

Returns
-------
XTensorVariable
Tensor with the new dimension inserted
"""
return px.shape.expand_dims(self, dim, size=size)

# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
Expand Down
Loading