Skip to content

Commit e7b7eaf

Browse files
committed
Specialize AdvancedSubtensor1 mode for compile time valid indices
1 parent 3c66aa6 commit e7b7eaf

File tree

2 files changed

+89
-27
lines changed

2 files changed

+89
-27
lines changed

pytensor/tensor/subtensor.py

+64-27
Original file line numberDiff line numberDiff line change
@@ -2120,16 +2120,12 @@ def make_node(self, x, ilist):
21202120
out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
21212121
return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])
21222122

2123-
def perform(self, node, inp, out_):
2123+
def perform(self, node, inp, output_storage):
21242124
x, i = inp
2125-
(out,) = out_
2126-
# Copy always implied by numpy advanced indexing semantic.
2127-
if out[0] is not None and out[0].shape == (len(i),) + x.shape[1:]:
2128-
o = out[0]
2129-
else:
2130-
o = None
21312125

2132-
out[0] = x.take(i, axis=0, out=o)
2126+
# Numpy take is always slower when out is provided
2127+
# https://github.com/numpy/numpy/issues/28636
2128+
output_storage[0][0] = x.take(i, axis=0, out=None)
21332129

21342130
def connection_pattern(self, node):
21352131
rval = [[True], *([False] for _ in node.inputs[1:])]
@@ -2174,42 +2170,83 @@ def c_code(self, node, name, input_names, output_names, sub):
21742170
"c_code defined for AdvancedSubtensor1, not for child class",
21752171
type(self),
21762172
)
2173+
x, idxs = node.inputs
2174+
if self._idx_may_be_invalid(x, idxs):
2175+
mode = "NPY_RAISE"
2176+
else:
2177+
# We can know ahead of time that all indices are valid, so we can use a faster mode
2178+
mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP
2179+
21772180
a_name, i_name = input_names[0], input_names[1]
21782181
output_name = output_names[0]
21792182
fail = sub["fail"]
2180-
return f"""
2181-
if ({output_name} != NULL) {{
2182-
npy_intp nd, i, *shape;
2183-
nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
2184-
if (PyArray_NDIM({output_name}) != nd) {{
2183+
if mode == "NPY_RAISE":
2184+
# numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
2185+
# We can remove this special case after https://github.com/numpy/numpy/issues/28636
2186+
manage_pre_allocated_out = f"""
2187+
if ({output_name} != NULL) {{
2188+
// Numpy TakeFrom is always slower when copying
2189+
// https://github.com/numpy/numpy/issues/28636
21852190
Py_CLEAR({output_name});
21862191
}}
2187-
else {{
2188-
shape = PyArray_DIMS({output_name});
2189-
for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
2190-
if (shape[i] != PyArray_DIMS({i_name})[i]) {{
2191-
Py_CLEAR({output_name});
2192-
break;
2193-
}}
2192+
"""
2193+
else:
2194+
manage_pre_allocated_out = f"""
2195+
if ({output_name} != NULL) {{
2196+
npy_intp nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
2197+
if (PyArray_NDIM({output_name}) != nd) {{
2198+
Py_CLEAR({output_name});
21942199
}}
2195-
if ({output_name} != NULL) {{
2196-
for (; i < nd; i++) {{
2197-
if (shape[i] != PyArray_DIMS({a_name})[
2198-
i-PyArray_NDIM({i_name})+1]) {{
2200+
else {{
2201+
int i;
2202+
npy_intp* shape = PyArray_DIMS({output_name});
2203+
for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
2204+
if (shape[i] != PyArray_DIMS({i_name})[i]) {{
21992205
Py_CLEAR({output_name});
22002206
break;
22012207
}}
22022208
}}
2209+
if ({output_name} != NULL) {{
2210+
for (; i < nd; i++) {{
2211+
if (shape[i] != PyArray_DIMS({a_name})[i-PyArray_NDIM({i_name})+1]) {{
2212+
Py_CLEAR({output_name});
2213+
break;
2214+
}}
2215+
}}
2216+
}}
22032217
}}
22042218
}}
2205-
}}
2219+
"""
2220+
2221+
return f"""
2222+
{manage_pre_allocated_out}
22062223
{output_name} = (PyArrayObject*)PyArray_TakeFrom(
2207-
{a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
2224+
{a_name}, (PyObject*){i_name}, 0, {output_name}, {mode});
22082225
if ({output_name} == NULL) {fail};
22092226
"""
22102227

22112228
def c_code_cache_version(self):
2212-
return (4,)
2229+
return (5,)
2230+
2231+
@staticmethod
def _idx_may_be_invalid(x, idx) -> bool:
    """Return whether ``idx`` may hold an out-of-bounds index for axis 0 of ``x``.

    The answer is conservative: ``True`` is returned whenever validity
    cannot be established from static information alone.

    Parameters
    ----------
    x
        Variable being indexed; only ``x.type.shape[0]`` is inspected.
    idx
        Index variable; ``idx.type.shape[0]`` is inspected and, when ``idx``
        is a ``Constant``, the extrema of ``idx.data``.

    Returns
    -------
    bool
        ``False`` only when every index is known to be valid.
    """
    if idx.type.shape[0] == 0:
        # Empty index is always valid
        return False

    if x.type.shape[0] is None:
        # We can't know if an index is valid if we don't know the length of x
        return True

    if not isinstance(idx, Constant):
        # This is conservative, but we don't try to infer lower/upper bound symbolically
        return True

    shape0 = x.type.shape[0]
    min_idx, max_idx = idx.data.min(), idx.data.max()
    # Valid indices lie in the half-open range [-shape0, shape0).
    # BOTH bounds must hold for the indices to be valid, so violating either
    # one means the indices may be invalid. (The previous expression,
    # `not (lower_ok) and (upper_ok)`, bound `not` to the lower-bound check
    # only and therefore reported too-large indices as valid.)
    return bool(min_idx < -shape0 or max_idx >= shape0)
22132250

22142251

22152252
advanced_subtensor1 = AdvancedSubtensor1()

tests/tensor/test_subtensor.py

+25
Original file line numberDiff line numberDiff line change
@@ -3003,3 +3003,28 @@ def test_flip(size: tuple[int]):
30033003
z = flip(x_pt, axis=list(axes))
30043004
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
30053005
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
3006+
3007+
3008+
class TestBenchmarks:
    """Benchmarks for subtensor operations."""

    @pytest.mark.parametrize(
        "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
    )
    @pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
    def test_advanced_subtensor1(self, static_shape, gc, benchmark):
        """Benchmark indexing a length-85 vector with a constant integer index array.

        ``static_shape`` toggles whether the length of ``x`` is part of its
        type; ``gc`` is assigned to ``fn.vm.allow_gc`` before the benchmark
        runs. ``benchmark`` is the callable fixture that drives the timing
        (presumably provided by pytest-benchmark -- confirm against the
        project's test requirements).
        """
        x = vector("x", shape=(85 if static_shape else None,))

        x_values = np.random.normal(size=(85,))
        idxs_values = np.arange(85).repeat(11)

        # With static shape and constant indices we know all idxs are valid
        # And can use faster mode in numpy.take
        out = x[idxs_values]

        fn = pytensor.function(
            [x],
            pytensor.Out(out, borrow=True),
            on_unused_input="ignore",
            trust_input=True,
        )
        fn.vm.allow_gc = gc
        # `x` is the only declared input; the indices are embedded in the
        # graph as a constant, so they must not be passed as a call argument.
        benchmark(fn, x_values)

0 commit comments

Comments
 (0)