Optimize matmuls involving block diagonal matrices #1493
Conversation
Pull Request Overview
This PR introduces an optimization that rewrites matrix multiplications involving block diagonal matrices into separate smaller multiplications and concatenations, yielding significant performance gains. It also adds tests to verify the rewrite and benchmarks to measure its impact.
- Implement the `local_block_diag_dot_to_dot_block_diag` rewrite in `math.py`
- Import and wire up the necessary primitives (`split`, `join`, `BlockDiagonal`)
- Add unit tests and benchmarks in `test_math.py` to validate correctness and performance
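For readers unfamiliar with PyTensor graph rewrites, here is a minimal sketch of the general shape such a rewrite can take, covering only the left-multiplication case. The function name, guard, and split logic here are illustrative assumptions, not the PR's actual code; the real implementation also handles `Blockwise`-wrapped ops and right-multiplication.

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import join, split
from pytensor.tensor.math import Dot
from pytensor.tensor.slinalg import BlockDiagonal


@node_rewriter([Dot])
def local_block_diag_dot_sketch(fgraph, node):
    x, y = node.inputs
    # Only rewrite when the left operand is produced by BlockDiagonal.
    if x.owner is None or not isinstance(x.owner.op, BlockDiagonal):
        return None
    components = x.owner.inputs
    # Each block consumes a slab of y's rows equal to that block's column count.
    sizes = [component.shape[1] for component in components]
    y_pieces = split(y, sizes, n_splits=len(components), axis=0)
    # Multiply piecewise, then stack the results back along the rows.
    return [join(0, *(c @ p for c, p in zip(components, y_pieces)))]
```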
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| pytensor/tensor/rewriting/math.py | Added the `local_block_diag_dot_to_dot_block_diag` rewrite and required imports (`split`, `join`, `BlockDiagonal`) |
| tests/tensor/rewriting/test_math.py | Added tests (`test_local_block_diag_dot_to_dot_block_diag`) and benchmarks (`test_block_diag_dot_to_dot_concat_benchmark`) |
Comments suppressed due to low confidence (1)
pytensor/tensor/rewriting/math.py:191
- The name `Blockwise` is referenced but not imported, which will raise a `NameError` if the first condition is false. Add `from pytensor.tensor.slinalg import Blockwise` (or the correct module) at the top of the file.

      or isinstance(x.owner.op, Blockwise)
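For context, a guard with the import in place might read as below. This is a hypothetical sketch (the helper name `is_block_diagonal` is not from the PR), and `Blockwise` is more likely found in `pytensor.tensor.blockwise` than in `pytensor.tensor.slinalg` as the comment suggests.

```python
# Hypothetical guard with the missing import added; assumes Blockwise
# lives in pytensor.tensor.blockwise and exposes a core_op attribute.
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import BlockDiagonal


def is_block_diagonal(x):
    """Return True if x is produced by BlockDiagonal, possibly blockwised."""
    if x.owner is None:
        return False
    op = x.owner.op
    return isinstance(op, BlockDiagonal) or (
        isinstance(op, Blockwise) and isinstance(op.core_op, BlockDiagonal)
    )
```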
Codecov Report

Attention: Patch coverage is 86.20% with target coverage 100.00%.

❌ Your patch check has failed because the patch coverage (86.20%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@ Coverage Diff @@
## main #1493 +/- ##
=======================================
Coverage ? 81.99%
=======================================
Files ? 231
Lines ? 52202
Branches ? 9183
=======================================
Hits ? 42804
Misses ? 7090
Partials ? 2308
Looks great! Some minor optimization questions
Description

This PR adds a rewrite to optimize matrix multiplication involving block diagonal matrices. When we have a matrix `X = BlockDiag(A, B)` and compute `Z = X @ Y`, there is no interaction between terms in the `A` part and the `B` part of `X`. So the dot can instead be computed as `row_stack(A @ Y[:A.shape[1]], B @ Y[A.shape[1]:])`, or in the general case `Y` can be split into `n` pieces with appropriate shapes and we do `row_stack([component @ y_split for component, y_split in zip(block_diag.inputs, split(Y, *args))])`. In the case where the block diagonal matrix is right-multiplying, you instead `col_stack` and slice on `axis=1`.

Anyway, it's a lot faster to do this, because matmuls scale badly in the dimension of the input, so doing two smaller operations is preferred. In the benchmarks, small has `n=10`, medium has `n=100`, and large has `n=1000`; in all cases the rewrite shows at least a 2x speedup. A small numerical check of the identity follows.
Related Issue

- `block_diag(a, b) @ c` #1044

Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1493.org.readthedocs.build/en/1493/