1 change: 1 addition & 0 deletions pytensor/xtensor/__init__.py
@@ -6,6 +6,7 @@
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
as_xtensor,
dim,
xtensor,
xtensor_constant,
)
81 changes: 58 additions & 23 deletions pytensor/xtensor/basic.py
@@ -1,9 +1,14 @@
from collections.abc import Sequence

from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op
from pytensor.scalar.basic import uint64
from pytensor.tensor.basic import ones as tensor_ones
from pytensor.tensor.basic import zeros as tensor_zeros
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
from pytensor.xtensor.type import DimVariable, XTensorType, as_xtensor, xtensor


DIM_LENGTH_SCALAR = uint64


class XOp(Op):
@@ -32,6 +37,7 @@ def make_node(self, x):
return Apply(self, [x], [output])

def L_op(self, inputs, outs, g_outs):
# TODO fix
[x] = inputs
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]
@@ -41,46 +47,49 @@ def L_op(self, inputs, outs, g_outs):


class XTensorFromTensor(XTypeCastOp):
__props__ = ("dims",)

def __init__(self, dims: Sequence[str]):
super().__init__()
self.dims = tuple(dims)
__props__ = ()

def make_node(self, x):
def make_node(self, x, *dims):
if not isinstance(x.type, TensorType):
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
return Apply(self, [x], [output])
output = xtensor(dtype=x.type.dtype, dims=dims)
return Apply(self, [x, *dims], [output])

def L_op(self, inputs, outs, g_outs):
# TODO fix
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]


def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)
def xtensor_from_tensor(x, dims, name=None, check: bool = True):
if check:
x = specify_shape(x, [dim.size for dim in dims])
return XTensorFromTensor()(x, *dims, name=name)


class Rename(XTypeCastOp):
__props__ = ("new_dims",)
class MapDims(XTypeCastOp):
__props__ = ("new_dim_indices",)

def __init__(self, new_dims: tuple[str, ...]):
super().__init__()
self.new_dims = new_dims
def __init__(self, new_dim_indices: tuple[int, ...]):
    self.new_dim_indices = new_dim_indices

def make_node(self, x):
def make_node(self, x, *new_dims):
x = as_xtensor(x)
output = x.type.clone(dims=self.new_dims)()
updated_dims = list(x.dims)
for i, idx in enumerate(self.new_dim_indices):
    updated_dims[idx] = new_dims[i]

output = x.type.clone(dims=updated_dims)()
return Apply(self, [x], [output])

def L_op(self, inputs, outs, g_outs):
# TODO fix
[x] = inputs
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]
return [map_dims(g_out, dims=x.type.dims)]


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
def map_dims(x, name_dict: dict[DimVariable, DimVariable] | None = None, **names):
if name_dict is not None:
if names:
raise ValueError("Cannot use both positional and keyword names in rename")
Expand All @@ -97,4 +106,30 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
)

return Rename(tuple(new_names))(x)
return MapDims(tuple(new_names))(x)
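
A hypothetical call pattern for the renamed `map_dims` helper (the collapsed hunk above hides how `new_names` is computed, so the keyword form shown here is an assumption based on the old `rename` signature):

    # assumed: each keyword maps an existing dim name to its replacement
    # DimVariable; `x` and `new_time` are illustrative names
    y = map_dims(x, time=new_time)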


def zeros(*dims, dtype=None, name=None):
"""Create a new XTensor filled with zeros."""
if not dims:
raise ValueError("At least one dimension must be specified")

return xtensor_from_tensor(
tensor_zeros(shape=[dim.size for dim in dims], dtype=dtype),
dims=dims,
name=name,
check=False,
)


def ones(*dims, dtype=None, name=None):
"""Create a new XTensor filled with zeros."""
if not dims:
raise ValueError("At least one dimension must be specified")

return xtensor_from_tensor(
tensor_ones(shape=[dim.size for dim in dims], dtype=dtype),
dims=dims,
name=name,
check=False,
)
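
A rough usage sketch of the new `zeros`/`ones` factories (hedged: the `dim` constructor is only visible as an export in `__init__.py` above, so its signature here is an assumption):

    from pytensor.xtensor import dim, zeros, ones

    # assumed: `dim` builds a DimVariable from a name and a fixed length
    row = dim("row", 3)
    col = dim("col", 4)

    x = zeros(row, col, dtype="float64")  # XTensor with dims (row, col)
    y = ones(row, col)                    # same dims, filled with ones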
173 changes: 173 additions & 0 deletions pytensor/xtensor/dims.py
@@ -0,0 +1,173 @@
from __future__ import annotations

from uuid import uuid4

import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op, Variable
from pytensor.scalar.basic import ScalarVariable
from pytensor.xtensor.type import (
DIM_LENGTH_SCALAR,
BaseDim,
CloneDim,
DimType,
DimVariable,
XTensorVariable,
)


class DimOp(Op):
def perform(self, node, inputs, outputs):
raise NotImplementedError(
f"xtensor operation {self} must be lowered to equivalent tensor operations"
)


# Not a dim op, because it doesn't return a DimVariable
class Length(Op):
__props__ = ()

def make_node(self, *inputs: Variable) -> Apply:
(x,) = inputs
if not isinstance(x, DimVariable):
raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
return Apply(self, [x], [DIM_LENGTH_SCALAR()])

def perform(self, node, inputs, outputs):
outputs[0][0] = inputs[0]


def _dim_size(dim: DimVariable) -> ScalarVariable:
return Length()(dim)


class FromLength(DimOp):
__props__ = ("dim_type",)

def __init__(self, dim_type: DimType):
super().__init__()
self.dim_type = dim_type

def make_node(self, *inputs: Variable) -> Apply:
(length,) = inputs
if not isinstance(length, ScalarVariable):
raise TypeError(f"length must be a ScalarVariable, got {type(length.type)}")
if length.type != DIM_LENGTH_SCALAR:
raise TypeError(
f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
)
return Apply(self, [length], [self.dim_type()])

def perform(self, node, inputs, outputs):
"""Convert the length to a list of lengths."""
outputs[0][0] = inputs[0]


def from_length(length: ScalarVariable, name: str | None = None) -> DimVariable:
# TODO add check for dtype
if not isinstance(length, ScalarVariable):
raise TypeError(f"length must be a ScalarVariable, got {type(length.type)}")
if length.type != DIM_LENGTH_SCALAR:
raise TypeError(
f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
)

uuid = uuid4()
dim_type = DimType(dim=BaseDim(uuid=uuid, name=name))
op = FromLength(dim_type)
return op(length, name=name)


class FromTensor(Op):
__props__ = ("dim_type",)

def __init__(self, dim_type: DimType):
super().__init__()
self.dim_type = dim_type

def make_node(self, *inputs: Variable) -> Apply:
(x,) = inputs
if not isinstance(x, XTensorVariable):
raise TypeError(f"x must be an XTensorVariable, got {type(x.type)}")
return Apply(self, [x], [self.dim_type()])

def perform(self, node, inputs, outputs):
"""Convert the tensor to a dimension variable."""
(x,) = inputs
(x_var,) = node.inputs
for i, dim in enumerate(x_var.type.dims):
if dim == self.dim_type.dim:
outputs[0][0] = x.shape[i]
return
raise ValueError(
f"Dimension {self.dim_type.dim} not found in tensor {x.type.dims}"
)


def _dim_from_tensor(x: XTensorVariable, idx: int) -> DimVariable:
op = FromTensor(dim_type=DimType(x.type.dims[idx]))
return op(x, name=x.type.dims[idx].name)


class Clone(Op):
__props__ = ("dim_type",)

def __init__(self, dim_type):
super().__init__()
self.dim_type = dim_type

def make_node(self, *inputs: Variable) -> Apply:
(x,) = inputs
if not isinstance(x, DimVariable):
raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
return Apply(self, [x], [self.dim_type()])

def perform(self, node, inputs, outputs):
outputs[0][0] = inputs[0]


def _clone_dim(dim: DimVariable, *, name: str | None = None) -> DimVariable:
"""Rename a dimension variable.

Args:
name: The new name for the dimension.

Returns:
A new DimVariable with the updated name.
"""
dim_type = DimType(dim=CloneDim(uuid=uuid4(), base=dim.type.dim))
return Clone(dim_type)(dim, name=name)


class Product(Op):
__props__ = ()

def make_node(self, *dims: Variable) -> Apply:
if not all(isinstance(dim, DimVariable) for dim in dims):
raise TypeError("All inputs must be DimVariables.")
# NOTE: assumed output type; the product of dims gets a fresh base identity
out = DimType(dim=BaseDim(uuid=uuid4(), name=None))()
return Apply(self, list(dims), [out])

def perform(self, node, inputs, outputs):
outputs[0][0] = np.prod(inputs, dtype=DIM_LENGTH_SCALAR.dtype).item()


def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable:
return Product()(*dims, name=name)


def rebase_dim(dim: DimVariable, *tensors: XTensorVariable) -> DimVariable:
if not isinstance(dim, DimVariable):
raise TypeError(f"dim must be a DimVariable, got {type(dim)}")

if not tensors:
raise ValueError("At least one tensor must be provided for rebasing.")

for tensor in tensors:
for i, tensor_dim in enumerate(tensor.type.dims):
if dim.type.dim == tensor_dim:
return _dim_from_tensor(tensor, idx=i)
raise ValueError(
f"Dimension {dim.type.dim} not found in any of the provided tensors."
)
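
A minimal sketch of how the helpers in `dims.py` compose (names are illustrative; `uint64` is the `DIM_LENGTH_SCALAR` scalar type assigned in `basic.py` above):

    from pytensor.scalar.basic import uint64
    from pytensor.xtensor.dims import from_length, rebase_dim

    n = uint64("n")                       # symbolic DIM_LENGTH_SCALAR length
    batch = from_length(n, name="batch")  # fresh DimVariable with a new uuid

    # rebase_dim re-derives `batch` from a tensor that carries it, so the
    # dim's length is read from that tensor's shape at runtime:
    # batch2 = rebase_dim(batch, some_xtensor)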
33 changes: 16 additions & 17 deletions pytensor/xtensor/rewriting/basic.py
@@ -1,35 +1,34 @@
from pytensor.graph import node_rewriter
from pytensor.tensor.basic import register_infer_shape
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
from pytensor.xtensor.basic import (
Rename,
MapDims,
TensorFromXTensor,
XTensorFromTensor,
xtensor_from_tensor,
)
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
# @register_infer_shape
# @register_useless
# @register_canonicalize
# @register_lower_xtensor
# @node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
"""TensorFromXTensor(XTensorFromTensor(x)) -> x"""
[x] = node.inputs
if x.owner and isinstance(x.owner.op, XTensorFromTensor):
return [x.owner.inputs[0]]


@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[XTensorFromTensor])
# @register_infer_shape
# @register_useless
# @register_canonicalize
# @register_lower_xtensor
# @node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
"""XTensorFromTensor(TensorFromXTensor(x)) -> x"""
[x] = node.inputs
# TODO
[x, *dims] = node.inputs
if x.owner and isinstance(x.owner.op, TensorFromXTensor):
return [x.owner.inputs[0]]

@@ -39,13 +38,13 @@ def useless_xtensor_from_tensor(fgraph, node):
def useless_tensor_from_xtensor_of_rename(fgraph, node):
"""TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
[renamed_x] = node.inputs
if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
if renamed_x.owner and isinstance(renamed_x.owner.op, MapDims):
[x] = renamed_x.owner.inputs
return node.op(x, return_list=True)


@register_lower_xtensor
@node_rewriter(tracks=[Rename])
@node_rewriter(tracks=[MapDims])
def useless_rename(fgraph, node):
"""

@@ -54,7 +53,7 @@ def useless_rename(fgraph, node):
"""
[renamed_x] = node.inputs
if renamed_x.owner:
if isinstance(renamed_x.owner.op, Rename):
if isinstance(renamed_x.owner.op, MapDims):
[x] = renamed_x.owner.inputs
return [node.op(x)]
elif isinstance(renamed_x.owner.op, TensorFromXTensor):