diff --git a/pytensor/link/jax/dispatch/subtensor.py b/pytensor/link/jax/dispatch/subtensor.py
index 90467daec5..1c659be29b 100644
--- a/pytensor/link/jax/dispatch/subtensor.py
+++ b/pytensor/link/jax/dispatch/subtensor.py
@@ -67,6 +67,9 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
         if len(indices) == 1:
             indices = indices[0]
 
+        if isinstance(op, AdvancedIncSubtensor1):
+            op._check_runtime_broadcasting(node, x, y, indices)
+
         return jax_fn(x, indices, y)
 
     return incsubtensor
diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index 7e1f6ded56..a0361738ae 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -83,7 +83,7 @@ def cholesky(a):
 @numba_funcify.register(PivotToPermutations)
 def pivot_to_permutation(op, node, **kwargs):
     inverse = op.inverse
-    dtype = node.inputs[0].dtype
+    dtype = node.outputs[0].dtype
 
     @numba_njit
     def numba_pivot_to_permutation(piv):
diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index ee9e183d16..3328ea349c 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
     x, vals, idxs = node.inputs
-    # TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
-    broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
+    broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
+    # TODO: Add runtime_broadcast check
 
     if set_instead_of_inc:
-        if broadcast:
+        if broadcast_with_index:
 
             @numba_njit(boundscheck=True)
             def advancedincsubtensor1_inplace(x, val, idxs):
@@ -318,7 +318,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
                     x[idx] = val
                 return x
     else:
-        if broadcast:
+        if broadcast_with_index:
 
             @numba_njit(boundscheck=True)
             def advancedincsubtensor1_inplace(x, val, idxs):
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 34358797fb..5dfa7dfa36 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
 
     def adv_set_subtensor(x, y, *indices):
         check_negative_steps(indices)
+        if isinstance(op, AdvancedIncSubtensor1):
+            op._check_runtime_broadcasting(node, x, y, indices)
         if not inplace:
             x = x.clone()
         x[indices] = y.type_as(x)
@@ -120,6 +122,8 @@ def adv_set_subtensor(x, y, *indices):
 
     def adv_inc_subtensor_no_duplicates(x, y, *indices):
         check_negative_steps(indices)
+        if isinstance(op, AdvancedIncSubtensor1):
+            op._check_runtime_broadcasting(node, x, y, indices)
         if not inplace:
             x = x.clone()
         x[indices] += y.type_as(x)
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index e0752f14ea..4786b71778 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -1634,6 +1634,14 @@ def _check_runtime_broadcast(node, value, shape):
             if v_static_dim is None and value_dim == 1 and out_dim != 1:
                 raise ValueError(Alloc._runtime_broadcast_error_msg)
 
+    @staticmethod
+    def value_is_scalar_zero(x: TensorVariable) -> bool:
+        return (
+            all(x.type.broadcastable)
+            and isinstance(x, Constant)
+            and (x.unique_value == 0)
+        )
+
     def perform(self, node, inputs, out_):
         (out,) = out_
         v = inputs[0]
@@ -1659,6 +1667,7 @@ def c_code(self, node, name, inp, out, sub):
         o_static_shape = node.outputs[0].type.shape
         v_ndim = len(v_static_shape)
         o_ndim = len(o_static_shape)
+        is_zero = self.value_is_scalar_zero(node.inputs[0])
         assert o_ndim == len(inp[1:])
 
         # Declare variables
@@ -1699,16 +1708,18 @@ def c_code(self, node, name, inp, out, sub):
                 {fail}
             }}
         }}
-
+        if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
+            PyArray_FILLWBYTE({zz}, 0);
+        }}
         // This function takes care of broadcasting
-        if (PyArray_CopyInto({zz}, {vv}) == -1)
+        else if (PyArray_CopyInto({zz}, {vv}) == -1)
           {fail}
         """

         return code

     def c_code_cache_version(self):
-        return (4,)
+        return (5,)

     def infer_shape(self, fgraph, node, input_shapes):
         return [node.inputs[1:]]
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index defb72bfbc..93ed4cec8a 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -1295,12 +1295,26 @@ def local_inplace_setsubtensor(fgraph, node):
 
 @node_rewriter([AdvancedIncSubtensor1], inplace=True)
 def local_inplace_AdvancedIncSubtensor1(fgraph, node):
-    if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
-        new_op = node.op.clone_inplace()
-        new_node = new_op(*node.inputs)
-        copy_stack_trace(node.outputs, new_node)
-        return [new_node]
-    return False
+    if node.op.inplace:
+        return
+
+    x, y, idx = node.inputs
+    if fgraph.has_destroyers([x]):
+        # In this case we can't operate inplace, but if x is just an alloc of zeros
+        # We're better off duplicating it and then acting on it inplace.
+        if (
+            x.owner is not None
+            and isinstance(x.owner.op, Alloc)
+            and x.owner.op.value_is_scalar_zero(x.owner.inputs[0])
+        ):
+            x = x.owner.clone().outputs[0]
+        else:
+            return None  # Inplace isn't valid
+
+    new_op = node.op.clone_inplace()
+    new_node = new_op(x, y, idx)
+    copy_stack_trace(node.outputs, new_node)
+    return [new_node]
 
 
 compile.optdb.register(
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 713e42b0a9..14a6d91a7d 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -604,7 +604,7 @@ def make_node(self, pivots):
 
     def perform(self, node, inputs, outputs):
         [pivots] = inputs
-        p_inv = np.arange(len(pivots), dtype=pivots.dtype)
+        p_inv = np.arange(len(pivots), dtype="int64")
 
         for i in range(len(pivots)):
             p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
@@ -639,7 +639,7 @@ def make_node(self, A):
         )
 
         LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
-        pivots = vector(shape=(A.type.shape[0],), dtype="int64")
+        pivots = vector(shape=(A.type.shape[0],), dtype="int32")
 
         return Apply(self, [A], [LU, pivots])
 
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 9c14f31e1d..278d1e8da6 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -2120,16 +2120,12 @@ def make_node(self, x, ilist):
         out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
         return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])
 
-    def perform(self, node, inp, out_):
+    def perform(self, node, inp, output_storage):
         x, i = inp
-        (out,) = out_
-        # Copy always implied by numpy advanced indexing semantic.
- if out[0] is not None and out[0].shape == (len(i),) + x.shape[1:]: - o = out[0] - else: - o = None - out[0] = x.take(i, axis=0, out=o) + # Numpy take is always slower when out is provided + # https://github.com/numpy/numpy/issues/28636 + output_storage[0][0] = x.take(i, axis=0, out=None) def connection_pattern(self, node): rval = [[True], *([False] for _ in node.inputs[1:])] @@ -2174,42 +2170,83 @@ def c_code(self, node, name, input_names, output_names, sub): "c_code defined for AdvancedSubtensor1, not for child class", type(self), ) + x, idxs = node.inputs + if self._idx_may_be_invalid(x, idxs): + mode = "NPY_RAISE" + else: + # We can know ahead of time that all indices are valid, so we can use a faster mode + mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP + a_name, i_name = input_names[0], input_names[1] output_name = output_names[0] fail = sub["fail"] - return f""" - if ({output_name} != NULL) {{ - npy_intp nd, i, *shape; - nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1; - if (PyArray_NDIM({output_name}) != nd) {{ + if mode == "NPY_RAISE": + # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer + # We can remove this special case after https://github.com/numpy/numpy/issues/28636 + manage_pre_allocated_out = f""" + if ({output_name} != NULL) {{ + // Numpy TakeFrom is always slower when copying + // https://github.com/numpy/numpy/issues/28636 Py_CLEAR({output_name}); }} - else {{ - shape = PyArray_DIMS({output_name}); - for (i = 0; i < PyArray_NDIM({i_name}); i++) {{ - if (shape[i] != PyArray_DIMS({i_name})[i]) {{ - Py_CLEAR({output_name}); - break; - }} + """ + else: + manage_pre_allocated_out = f""" + if ({output_name} != NULL) {{ + npy_intp nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1; + if (PyArray_NDIM({output_name}) != nd) {{ + Py_CLEAR({output_name}); }} - if ({output_name} != NULL) {{ - for (; i < nd; i++) {{ - if (shape[i] != PyArray_DIMS({a_name})[ - i-PyArray_NDIM({i_name})+1]) {{ + else {{ + int i; + npy_intp* shape = PyArray_DIMS({output_name}); + for (i = 0; i < PyArray_NDIM({i_name}); i++) {{ + if (shape[i] != PyArray_DIMS({i_name})[i]) {{ Py_CLEAR({output_name}); break; }} }} + if ({output_name} != NULL) {{ + for (; i < nd; i++) {{ + if (shape[i] != PyArray_DIMS({a_name})[i-PyArray_NDIM({i_name})+1]) {{ + Py_CLEAR({output_name}); + break; + }} + }} + }} }} }} - }} + """ + + return f""" + {manage_pre_allocated_out} {output_name} = (PyArrayObject*)PyArray_TakeFrom( - {a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE); + {a_name}, (PyObject*){i_name}, 0, {output_name}, {mode}); if ({output_name} == NULL) {fail}; """ def c_code_cache_version(self): - return (4,) + return (5,) + + @staticmethod + def _idx_may_be_invalid(x, idx) -> bool: + if idx.type.shape[0] == 0: + # Empty index is always valid + return False + + if x.type.shape[0] is None: + # We can't know if in index is valid if we don't know the length of x + return True + + if not isinstance(idx, Constant): + # This is conservative, but we don't try to infer lower/upper bound symbolically + return True + + shape0 = x.type.shape[0] + min_idx, max_idx = idx.data.min(), idx.data.max() + return not (min_idx >= 0 or min_idx >= -shape0) and ( + max_idx < 0 or max_idx < shape0 + ) advanced_subtensor1 = AdvancedSubtensor1() @@ -2225,6 +2262,12 @@ class AdvancedIncSubtensor1(COp): check_input = False params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool) + _runtime_broadcast_error_msg = ( + "Runtime broadcasting not 
allowed. " + "AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. " + "If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)." + ) + def __init__(self, inplace=False, set_instead_of_inc=False): self.inplace = bool(inplace) self.set_instead_of_inc = bool(set_instead_of_inc) @@ -2296,6 +2339,9 @@ def copy_of_x(self, x): NPY_ARRAY_ENSURECOPY, NULL)""" def c_support_code(self, **kwargs): + if numpy_version < "1.8.0" or using_numpy_2: + return None + types = [ "npy_" + t for t in [ @@ -2486,15 +2532,117 @@ def gen_num(typen): return code def c_code(self, node, name, input_names, output_names, sub): - if numpy_version < "1.8.0" or using_numpy_2: - raise NotImplementedError - x, y, idx = input_names - out = output_names[0] + [out] = output_names copy_of_x = self.copy_of_x(x) params = sub["params"] fail = sub["fail"] + x_, y_, idx_ = node.inputs + y_cdtype = y_.type.dtype_specs()[1] + idx_cdtype = idx_.type.dtype_specs()[1] + out_cdtype = node.outputs[0].type.dtype_specs()[1] + y_bcast = y_.type.broadcastable != idx_.type.broadcastable + if ( + x_.type.ndim == 1 + and y_.type.ndim == 1 + and not y_bcast + and x_.type.dtype not in complex_dtypes + and y_.type.dtype not in complex_dtypes + ): + # Simple implementation for vector x, y cases + idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0) + idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_) + shape0 = x_.type.shape[0] + # This is used to make sure that when we trust the indices to be valid + # we are not fooled by a wrong static shape + # We mention x to the user in error messages but we work (and make checks) on out, + # which should be x or a copy of it + unexpected_shape0 = ( + f"PyArray_SHAPE({out})[0] != {shape0}" if shape0 is not None else "0" + ) + + op = "=" if self.set_instead_of_inc else "+=" + code = f""" + if ({params}->inplace) + {{ + if ({x} != {out}) + {{ + Py_XDECREF({out}); + Py_INCREF({x}); + {out} = {x}; + }} + }} + else + {{ + Py_XDECREF({out}); + {out} = {copy_of_x}; + if (!{out}) {{ + // Exception already set + {fail} + }} + }} + + if (PyArray_NDIM({out}) != 1) {{ + PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) ndim should be 1, got %d", PyArray_NDIM({out})); + {fail} + }} + if ({unexpected_shape0}) {{ + PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) shape should be {shape0}, got %d", PyArray_SHAPE({out})[0]); + {fail} + }} + if (PyArray_NDIM({idx}) != 1) {{ + PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim should be 1, got %d", PyArray_NDIM({idx})); + {fail} + }} + if (PyArray_NDIM({y}) != 1) {{ + PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: second input (y) ndim should be 1, got %d", PyArray_NDIM({y})); + {fail} + }} + if (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0]) {{ + if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{ + PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}"); + }} else {{ + PyErr_Format(PyExc_ValueError, + "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match: %d, %d", + PyArray_SHAPE({y})[0], PyArray_SHAPE({idx})[0]); + }} + {fail} + }} + + {{ + npy_intp out_shape0 = PyArray_SHAPE({out})[0]; + {out_cdtype}* out_data = ({out_cdtype}*)PyArray_DATA({out}); + {y_cdtype}* y_data = ({y_cdtype}*)PyArray_DATA({y}); + {idx_cdtype}* idx_data = ({idx_cdtype}*)PyArray_DATA({idx}); + npy_intp n = PyArray_SHAPE({idx})[0]; + npy_intp out_jump 
+                npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
+                npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
+
+                for(int i = 0; i < n; i++){{
+                    {idx_cdtype} idx = idx_data[i * idx_jump];
+                    if ({int(idx_may_be_neg)}){{
+                        if (idx < 0) {{
+                            idx += out_shape0;
+                        }}
+                    }}
+                    if ({int(idx_may_be_invalid)}){{
+                        if ((idx < 0) || (idx >= out_shape0)) {{
+                            PyErr_Format(PyExc_IndexError,"index %d out of bounds for array with shape %d", idx_data[i * idx_jump], out_shape0);
+                            {fail}
+                        }}
+                    }}
+                    out_data[idx * out_jump] {op} y_data[i * y_jump];
+                }}
+
+            }}
+            """
+            return code
+
+        if numpy_version < "1.8.0" or using_numpy_2:
+            raise NotImplementedError
+
         return f"""
         PyObject* rval = NULL;
         if ({params}->inplace)
@@ -2522,14 +2670,37 @@ def c_code(self, node, name, input_names, output_names, sub):
         """
 
     def c_code_cache_version(self):
-        return (8,)
+        return (9,)
+
+    def _check_runtime_broadcasting(
+        self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray
+    ) -> None:
+        if y.ndim > 0:
+            y_pt_bcast = node.inputs[1].broadcastable  # type: ignore
+
+            if not y_pt_bcast[0] and y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
+                # Attempting to broadcast with index
+                raise ValueError(self._runtime_broadcast_error_msg)
+            if any(
+                not y_bcast and y_dim == 1 and y_dim != x_dim
+                for y_bcast, y_dim, x_dim in zip(
+                    reversed(y_pt_bcast),
+                    reversed(y.shape),
+                    reversed(x.shape),
+                    strict=False,
+                )
+            ):
+                # Attempting to broadcast with buffer
+                raise ValueError(self._runtime_broadcast_error_msg)
+
+    def perform(self, node, inputs, output_storage):
+        x, y, idx = inputs
 
-    def perform(self, node, inp, out_):
-        x, y, idx = inp
-        (out,) = out_
         if not self.inplace:
             x = x.copy()
+        self._check_runtime_broadcasting(node, x, y, idx)
+
         if self.set_instead_of_inc:
             x[idx] = y
         else:
@@ -2537,7 +2708,7 @@ def c_code(self, node, name, input_names, output_names, sub):
             # many times: it does it only once.
            np.add.at(x, idx, y)
 
-        out[0] = x
+        output_storage[0][0] = x
 
     def infer_shape(self, fgraph, node, ishapes):
         x, y, ilist = ishapes
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index 78ec97eff3..d10bb1dd2e 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pytest
 from numpy.testing import assert_array_equal
+from packaging import version
 
 import pytensor
 import pytensor.scalar as scal
@@ -26,7 +27,7 @@
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import exp, isinf, lt, switch
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import specify_shape
+from pytensor.tensor.shape import specify_broadcastable, specify_shape
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -1101,9 +1102,9 @@ def grad_list_(self, idxs, data):
         n = self.shared(data)
 
         for idx in idxs:
-            # Should stay on the cpu.
-            idx_ = shared(np.asarray(idx))
-            t = n[idx_]
+            idx_np = np.asarray(idx)
+            idx_pt = shared(idx_np, shape=(1 if idx_np.shape[0] == 1 else None,))
+            t = n[idx_pt]
             gn = pytensor.grad(pt_sum(exp(t)), n)
             f = self.function([], [gn, gn.shape], op=AdvancedIncSubtensor1)
             topo = f.maker.fgraph.toposort()
@@ -1126,13 +1127,13 @@ def grad_list_(self, idxs, data):
             assert np.allclose(gshape, data.shape)
 
             def fct(t):
-                return pt_sum(t[idx_])
+                return pt_sum(t[idx_pt])
 
             utt.verify_grad(fct, [data], mode=self.mode)
 
             # Test the grad of the grad (e.i. AdvancedIncSubtensor1.grad)
             def fct2(t):
-                return pytensor.grad(pt_sum(t[idx_]), t)
+                return pytensor.grad(pt_sum(t[idx_pt]), t)
 
             utt.verify_grad(fct2, [data], mode=self.mode)
 
@@ -1143,7 +1144,9 @@ def fct2(t):
             ops = subtensor_ops
             if idx is idxs[0]:
                 # TODO FIXME: This is a very poorly specified test.
-                f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=0)
+                f = self.function(
+                    [], [gn.shape, n[idx_pt].shape], op=ops, N=0, N_fast=0
+                )
                 f()
 
     def test_wrong_exception_regression(self):
@@ -1231,10 +1234,7 @@ def test_advanced1_inc_and_set(self):
         data_num_init = np.arange(data_size, dtype=self.dtype)
         data_num_init = data_num_init.reshape(data_shape)
         inc_shapes = [data_shape[i:] for i in range(0, len(data_shape) + 1)]
-        # Test broadcasting of y.
-        inc_shapes += [(1,) + inc_shapes[-1][1:]]
         for inc_shape in inc_shapes:
-            inc_n_dims = len(inc_shape)
             # We copy the numeric value to be 100% sure there is no
             # risk of accidentally sharing it.
             data_num = data_num_init.copy()
@@ -1263,10 +1263,7 @@ def test_advanced1_inc_and_set(self):
                     replace=(not set_instead_of_inc),
                 )
                 idx_num = idx_num.astype("int64")
-                # Symbolic variable with increment value.
-                inc_var = TensorType(
-                    shape=(None,) * inc_n_dims, dtype=self.dtype
-                )()
+
                 # Trick for the case where `inc_shape` is the same as
                 # `data_shape`: what we actually want is the first
                 # shape element to be equal to the number of rows to
@@ -1275,6 +1272,15 @@
                     len(inc_shapes) == 0 or inc_shape[0] != 1
                 ):
                     inc_shape = (n_to_inc,) + inc_shape[1:]
+
+                # Symbolic variable with increment value.
+                inc_var_static_shape = tuple(
+                    1 if dim_length == 1 else None for dim_length in inc_shape
+                )
+                inc_var = TensorType(
+                    shape=inc_var_static_shape, dtype=self.dtype
+                )()
+
                 # The param dtype is needed when inc_shape is empty.
                 # By default, it would return a float and rng.uniform
                 # with NumPy 1.10 will raise a Deprecation warning.
@@ -1341,6 +1347,31 @@ def test_advanced1_inc_and_set(self):
             # you enable the debug code above.
             assert np.allclose(f_out, output_num), (params, f_out, output_num)
 
+    @pytest.mark.skipif(
+        version.parse(np.__version__) < version.parse("2.0"),
+        reason="Legacy C-implementation did not check for runtime broadcast",
+    )
+    @pytest.mark.parametrize("func", (advanced_inc_subtensor1, advanced_set_subtensor1))
+    def test_advanced1_inc_runtime_broadcast(self, func):
+        y = matrix("y", dtype="float64", shape=(None, None))
+
+        x = ptb.zeros((10, 5))
+        idxs = np.repeat(np.arange(10), 2)
+        out = func(x, y, idxs)
+
+        f = function([y], out)
+        f(np.ones((20, 5)))  # Fine
+        with pytest.raises(
+            ValueError,
+            match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
+        ):
+            f(np.ones((1, 5)))
+        with pytest.raises(
+            ValueError,
+            match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
AdvancedIncSubtensor1 was asked", + ): + f(np.ones((20, 1))) + def test_adv_constant_arg(self): # Test case provided (and bug detected, gh-607) by John Salvatier m = matrix("m") @@ -2398,7 +2429,11 @@ def test_AdvancedIncSubtensor1(self): aivec_val = [2, 3] self._compile_and_check( [admat, bdmat], - [advanced_set_subtensor1(admat, bdmat, aivec_val)], + [ + advanced_set_subtensor1( + admat, specify_broadcastable(bdmat, 0), aivec_val + ) + ], [admat_val, [[1, 2, 3, 4]]], AdvancedIncSubtensor1, ) @@ -2425,7 +2460,11 @@ def test_AdvancedIncSubtensor1(self): aivec_val = [2, 3] self._compile_and_check( [adtens4, bdtens4], - [advanced_set_subtensor1(adtens4, bdtens4, aivec_val)], + [ + advanced_set_subtensor1( + adtens4, specify_broadcastable(bdtens4, 0, 1, 2), aivec_val + ) + ], [adtens4_val, [[[[1, 2, 3, 4, 5]]]]], AdvancedIncSubtensor1, warn=False, @@ -2476,7 +2515,11 @@ def test_AdvancedIncSubtensor1(self): aivec_val = [2, 3] self._compile_and_check( [adtens4, bdtens4], - [advanced_set_subtensor1(adtens4, bdtens4, aivec_val)], + [ + advanced_set_subtensor1( + adtens4, specify_broadcastable(bdtens4, 1, 2), aivec_val + ) + ], [adtens4_val, [[[[1, 2, 3, 4, 5]]], [[[6, 7, 8, 9, 10]]]]], AdvancedIncSubtensor1, warn=False, @@ -3003,3 +3046,54 @@ def test_flip(size: tuple[int]): z = flip(x_pt, axis=list(axes)) f = pytensor.function([x_pt], z, mode="FAST_COMPILE") np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + +class TestBenchmarks: + @pytest.mark.parametrize( + "static_shape", (False, True), ids=lambda x: f"static_shape={x}" + ) + @pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}") + def test_advanced_subtensor1(self, static_shape, gc, benchmark): + x = vector("x", shape=(85 if static_shape else None,)) + + x_values = np.random.normal(size=(85,)) + idxs_values = np.arange(85).repeat(11) + + # With static shape and constant indices we know all idxs are valid + # And can use faster mode in numpy.take + out = x[idxs_values] + + fn = pytensor.function( + [x], + pytensor.Out(out, borrow=True), + on_unused_input="ignore", + trust_input=True, + ) + fn.vm.allow_gc = gc + benchmark(fn, x_values, idxs_values) + + @pytest.mark.parametrize( + "static_shape", (False, True), ids=lambda x: f"static_shape={x}" + ) + @pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}") + @pytest.mark.parametrize("func", (inc_subtensor, set_subtensor)) + def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark): + x = vector("x", shape=(85 if static_shape else None,)) + x_values = np.zeros((85,)) + buffer = ptb.zeros_like(x) + y_values = np.random.normal(size=(85 * 11,)) + idxs_values = np.arange(85).repeat(11) + + # With static shape and constant indices we know all idxs are valid + # Reuse same buffer of zeros, to check we rather allocate twice than copy inside IncSubtensor + out1 = func(buffer[idxs_values], y_values) + out2 = func(buffer[idxs_values[::-1]], y_values) + + fn = pytensor.function( + [x], + [pytensor.Out(out1, borrow=True), pytensor.Out(out2, borrow=True)], + on_unused_input="ignore", + trust_input=True, + ) + fn.vm.allow_gc = gc + benchmark(fn, x_values)