Skip to content

Adding dot for xtensor #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: labeled_tensors
Choose a base branch
from

Conversation

AllenDowney
Copy link

@AllenDowney AllenDowney commented Jun 6, 2025

Add dot operation to xtensor module

This PR adds support for the dot product operation in the xtensor module. The implementation includes:

New dot method

  • Added a .dot() method to XTensorVariable in pytensor/xtensor/type.py to provide a consistent interface for dot operations, similar to other math functions.

Rewrite rule for dot

  • Implemented a rewrite rule in pytensor/xtensor/rewriting/math.py that converts the XDot operation to a tensor-based dot operation using tensordot. This rule handles dimension alignment and contraction correctly.

Import of math rewriting module

  • Updated pytensor/xtensor/rewriting/__init__.py to import the math rewriting module, ensuring that the dot rewrite rule is registered and available during the rewrite pass.

Unit tests

  • Added a new test function test_dot() in tests/xtensor/test_math.py to verify the basic functionality of the dot operation, including matrix-matrix and matrix-vector dot products, proper dimension handling, and shape validation.

These changes ensure that the xtensor module now supports dot operations, maintaining consistency with other math functions and enabling proper dimension handling for tensor contractions.


📚 Documentation preview 📚: https://pytensor--1450.org.readthedocs.build/en/1450/

@@ -151,3 +151,34 @@ def test_cast():
yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")


def test_dot():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks promising, but needs some tests calling dot with specific dims?

@ricardoV94 ricardoV94 requested a review from OriolAbril June 6, 2025 21:06
y = as_xtensor(y)

# Validate dimensions if specified
if dims is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These checks are better placed in make_node

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants