Speedup AdvancedSubtensor1 and AdvancedIncSubtensor1 in C backend #1346
Conversation
Force-pushed from c0045c3 to c211405
Codecov Report
❌ Your patch status has failed because the patch coverage (87.65%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@ Coverage Diff @@
## main #1346 +/- ##
=======================================
Coverage 82.01% 82.02%
=======================================
Files 207 207
Lines 49250 49302 +52
Branches 8734 8748 +14
=======================================
+ Hits 40394 40438 +44
- Misses 6692 6697 +5
- Partials 2164 2167 +3
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (3)
tests/tensor/test_subtensor.py:1277
- The generator expression used for 'inc_var_static_shape' should be converted to a tuple (e.g. tuple(...)) to ensure it produces a concrete shape tuple.
inc_var_static_shape = (1 if dim_length == 1 else None for dim_length in inc_shape)
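For reference, a minimal standalone sketch of the distinction the comment points at (the inc_shape value here is made up, not taken from the test): a generator expression is a lazy, single-use iterable, while tuple(...) gives a concrete shape tuple.

```python
inc_shape = (1, 5, 1)  # example value for illustration only

# A generator expression is lazy and can only be consumed once:
lazy_shape = (1 if dim_length == 1 else None for dim_length in inc_shape)

# Wrapping the same expression in tuple(...) yields a concrete static shape:
inc_var_static_shape = tuple(1 if dim_length == 1 else None for dim_length in inc_shape)
print(inc_var_static_shape)  # (1, None, 1)
```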
pytensor/link/pytorch/dispatch/subtensor.py:112
- The _check_runtime_broadcasting method expects four arguments (including the node), so the node argument should be passed to ensure correct runtime checking.
if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(x, y, indices)
pytensor/link/jax/dispatch/subtensor.py:70
- Similar to the PyTorch dispatch, the node argument is missing when calling _check_runtime_broadcasting. Ensure that the node is passed as the first parameter after self.
if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(x, y, indices)
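A hedged sketch of the fix both comments describe, assuming the four-argument signature they state (node as the first argument after self); the helper name below is hypothetical and this is not the PR's actual dispatch code:

```python
from pytensor.tensor.subtensor import AdvancedIncSubtensor1


def _incsubtensor_dispatch_sketch(op, node, x, y, indices):
    # Hypothetical helper: pass `node` along with the runtime values,
    # matching the signature described in the review comments.
    if isinstance(op, AdvancedIncSubtensor1):
        op._check_runtime_broadcasting(node, x, y, indices)
```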
Useless AI strikes again!
It was actually correct. I'm surprised no tests failed; I guess we are not really covering the dispatch of these in JAX/PyTorch because the Op is only introduced during rewrites.
Force-pushed from c211405 to 38f9036
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
pytensor/tensor/subtensor.py:2247
- The condition validating index values in _idx_may_be_invalid is non-obvious; please verify that it correctly handles negative indices and reflects the intended bounds check.
return not (min_idx >= 0 or min_idx >= -shape0) and (max_idx < 0 or max_idx < shape0)
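For reference, a minimal standalone sketch of the bounds rule such a check is meant to encode (not the PR's exact code): an index i is valid for a dimension of length n iff -n <= i < n, so the indices may be out of bounds when min_idx < -n or max_idx >= n.

```python
def indices_may_be_invalid(min_idx: int, max_idx: int, shape0: int) -> bool:
    # Valid index range for a dimension of length shape0 is [-shape0, shape0 - 1].
    return min_idx < -shape0 or max_idx >= shape0


assert not indices_may_be_invalid(min_idx=-3, max_idx=2, shape0=3)  # all in bounds
assert indices_may_be_invalid(min_idx=-4, max_idx=2, shape0=3)      # -4 is out of bounds
assert indices_may_be_invalid(min_idx=0, max_idx=3, shape0=3)       # 3 is out of bounds
```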
Force-pushed from eef5a69 to 83506d9
There were some wrongly promised types in LUFactor/Pivots that showed up in this PR @jessegrabowski (first commit).
The requested green check, with some questions/comments.
As usual, my ability to contribute on these C-code PRs is limited, but I tried my best.
@@ -83,7 +83,7 @@ def cholesky(a):
 @numba_funcify.register(PivotToPermutations)
 def pivot_to_permutation(op, node, **kwargs):
     inverse = op.inverse
-    dtype = node.inputs[0].dtype
+    dtype = node.outputs[0].dtype
What a dumb mistake, I wonder which idiot wrote this code
pytensor/tensor/basic.py (outdated)
@@ -1659,6 +1659,11 @@ def c_code(self, node, name, inp, out, sub):
         o_static_shape = node.outputs[0].type.shape
         v_ndim = len(v_static_shape)
         o_ndim = len(o_static_shape)
+        is_zero = (
+            all(node.inputs[0].type.broadcastable)
Will this catch a scalar zero?
Yes, all(empty) is True; a scalar's broadcastable pattern is the empty tuple.
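A quick standalone check of that claim (pt.scalar is used here just for illustration):

```python
import pytensor.tensor as pt

x = pt.scalar("x")                # 0-d tensor
print(x.type.broadcastable)       # () -- a scalar has an empty broadcastable pattern
print(all(x.type.broadcastable))  # True, since all() of an empty iterable is True
```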
@@ -604,7 +604,7 @@ def make_node(self, pivots):

     def perform(self, node, inputs, outputs):
         [pivots] = inputs
-        p_inv = np.arange(len(pivots), dtype=pivots.dtype)
+        p_inv = np.arange(len(pivots), dtype="int64")
We always have to be careful about single vs. double precision for floats (use floatX), but not for ints -- why is that?
float32 was very much geared towards GPU concerns; integers are not as problematic [citation needed]
@@ -639,7 +639,7 @@ def make_node(self, A):
         )

         LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
-        pivots = vector(shape=(A.type.shape[0],), dtype="int64")
+        pivots = vector(shape=(A.type.shape[0],), dtype="int32")
Like why 64 above but 32 here?
This is what the scipy function returns. The reason I went with int64 above is that np.argsort returns that regardless of the input, so I kept the outputs the same type regardless of return_inverse (or whatever the property on that other Op is).
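Both dtype claims are easy to check standalone (assuming a 64-bit platform and a standard 32-bit-integer LAPACK build):

```python
import numpy as np
from scipy.linalg import lu_factor

# np.argsort returns the platform index type regardless of the input dtype:
pivots = np.array([2, 0, 1], dtype=np.int32)
print(np.argsort(pivots).dtype)  # int64 on a 64-bit platform

# scipy's LU factorization reports pivots as 32-bit integers with standard LAPACK:
lu, piv = lu_factor(np.random.default_rng(0).random((3, 3)))
print(piv.dtype)  # int32
```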
Force-pushed from 83506d9 to 78f1cf8
Also add checks for runtime broadcast
Force-pushed from 78f1cf8 to a0abe86
These are some of the biggest drags in the C backend. This PR makes some tweaks that increase performance substantially.
AdvancedSubtensor1 benchmark (before/after plots)
AdvancedIncSubtensor1 benchmark (before/after plots)
I added a long-missing check for runtime broadcasting to the Python/C/PyTorch implementations (Numba would require a bit more code changes), which moves towards #1348.
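As a rough sketch of the kind of case such a check targets (a hypothetical example, not code from the PR): incrementing x[idx] with a values vector whose runtime length is 1 would previously broadcast silently; with a runtime-broadcast check it is expected to raise instead.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")          # length only known at runtime
idx = pt.lvector("idx")

f = pytensor.function([x, y, idx], pt.inc_subtensor(x[idx], y))

f(np.zeros(5), np.ones(3), np.array([0, 2, 4], dtype="int64"))    # shapes match: fine
# f(np.zeros(5), np.ones(1), np.array([0, 2, 4], dtype="int64"))  # length-1 y: expected to error now
```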
Provides a restricted case of #1325
Alloc of zeros is also about twice as fast now, which is benchmarked indirectly in the AdvancedIncSubtensor1 tests.
📚 Documentation preview 📚: https://pytensor--1346.org.readthedocs.build/en/1346/