Commit 9e47a30

Specialized C-impl for vector AdvancedIncSubtensor1
Also add checks for runtime broadcast
1 parent e7b7eaf commit 9e47a30
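
In user terms, the runtime broadcast check means that the values being set or incremented (y) are only stretched along a dimension when that dimension was declared broadcastable in the graph. A minimal sketch of the intended behavior, assuming the usual pytensor.tensor API (pt.vector, pt.lvector, pt.inc_subtensor, pt.specify_broadcastable); shapes and names are illustrative, not taken from the commit:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.vector("y")
    idx = pt.lvector("idx")

    # y is not statically broadcastable, so a runtime y of length 1 that would
    # have to be stretched to match idx should now raise instead of broadcasting.
    fn = pytensor.function([x, y, idx], pt.inc_subtensor(x[idx], y))
    fn(np.zeros(5), np.ones(3), np.array([0, 2, 4]))    # fine: shapes match
    # fn(np.zeros(5), np.ones(1), np.array([0, 2, 4]))  # expected to raise ValueError

    # If broadcasting y is intended, declare it explicitly:
    y_b = pt.specify_broadcastable(y, 0)
    fn_b = pytensor.function([x, y, idx], pt.inc_subtensor(x[idx], y_b))
    fn_b(np.zeros(5), np.ones(1), np.array([0, 2, 4]))  # broadcasts as requested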

5 files changed, +225 −30 lines changed

pytensor/link/jax/dispatch/subtensor.py (+3)

@@ -67,6 +67,9 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
         if len(indices) == 1:
             indices = indices[0]

+        if isinstance(op, AdvancedIncSubtensor1):
+            op._check_runtime_broadcasting(x, y, indices)
+
         return jax_fn(x, indices, y)

     return incsubtensor
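
For the JAX backend the check runs in the Python wrapper that JAX later traces; array shapes are concrete at trace time, so a plain Python shape check like _check_runtime_broadcasting can raise under jax.jit. A stand-alone illustration of that mechanism (this is not the generated dispatch code):

    import jax
    import jax.numpy as jnp

    def inc_subtensor1(x, y, idx):
        # Shapes are static during tracing, so this check fires when the jitted
        # function is traced for a given shape signature.
        if y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
            raise ValueError("Runtime broadcasting not allowed.")
        return x.at[idx].add(y)

    fn = jax.jit(inc_subtensor1)
    fn(jnp.zeros(5), jnp.ones(3), jnp.array([0, 2, 4]))    # ok
    # fn(jnp.zeros(5), jnp.ones(1), jnp.array([0, 2, 4]))  # raises at trace time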

pytensor/link/numba/dispatch/subtensor.py (+4 −4)

@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
     x, vals, idxs = node.inputs
-    # TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
-    broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
+    broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
+    # TODO: Add runtime_broadcast check

     if set_instead_of_inc:
-        if broadcast:
+        if broadcast_with_index:

             @numba_njit(boundscheck=True)
             def advancedincsubtensor1_inplace(x, val, idxs):

@@ -318,7 +318,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
                     x[idx] = val
                 return x
     else:
-        if broadcast:
+        if broadcast_with_index:

             @numba_njit(boundscheck=True)
             def advancedincsubtensor1_inplace(x, val, idxs):
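
The "TODO: Add runtime_broadcast check" above is left open in this commit. A hypothetical sketch of what such a check could look like in the non-broadcasting branch (placement and wording are assumptions, not part of the commit); Numba supports raising exceptions with literal messages inside jitted code:

    import numpy as np
    from numba import njit

    @njit(boundscheck=True)
    def advancedincsubtensor1_inplace(x, vals, idxs):
        # Hypothetical runtime_broadcast check: a length-1 vals must not be
        # silently stretched to the length of idxs.
        if vals.shape[0] == 1 and vals.shape[0] != idxs.shape[0]:
            raise ValueError("Runtime broadcasting not allowed.")
        for i in range(len(idxs)):
            x[idxs[i]] += vals[i]
        return x

    print(advancedincsubtensor1_inplace(np.zeros(5), np.ones(3), np.array([0, 2, 4])))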

pytensor/link/pytorch/dispatch/subtensor.py (+4)

@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):

     def adv_set_subtensor(x, y, *indices):
         check_negative_steps(indices)
+        if isinstance(op, AdvancedIncSubtensor1):
+            op._check_runtime_broadcasting(x, y, indices)
         if not inplace:
             x = x.clone()
         x[indices] = y.type_as(x)

@@ -120,6 +122,8 @@ def adv_set_subtensor(x, y, *indices):

     def adv_inc_subtensor_no_duplicates(x, y, *indices):
         check_negative_steps(indices)
+        if isinstance(op, AdvancedIncSubtensor1):
+            op._check_runtime_broadcasting(x, y, indices)
         if not inplace:
             x = x.clone()
         x[indices] += y.type_as(x)
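
As the name adv_inc_subtensor_no_duplicates suggests, the in-place "x[indices] += y" form above is only correct when indices are unique: with duplicates, plain advanced-index assignment applies each position once. An illustrative comparison in plain PyTorch (not part of the dispatch code):

    import torch

    x = torch.zeros(3)
    idx = torch.tensor([0, 0, 2])
    y = torch.ones(3)

    x_plain = x.clone()
    x_plain[idx] += y                             # index 0 is incremented only once
    x_acc = x.clone()
    x_acc.index_put_((idx,), y, accumulate=True)  # duplicates accumulate

    print(x_plain)  # tensor([1., 0., 1.])
    print(x_acc)    # tensor([2., 0., 1.])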

pytensor/tensor/subtensor.py (+128 −9)
@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
     check_input = False
     params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)

+    _runtime_broadcast_error_msg = (
+        "Runtime broadcasting not allowed. "
+        "AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
+        "If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
+    )
+
     def __init__(self, inplace=False, set_instead_of_inc=False):
         self.inplace = bool(inplace)
         self.set_instead_of_inc = bool(set_instead_of_inc)
@@ -2333,6 +2339,9 @@ def copy_of_x(self, x):
                 NPY_ARRAY_ENSURECOPY, NULL)"""

     def c_support_code(self, **kwargs):
+        if numpy_version < "1.8.0" or using_numpy_2:
+            return None
+
         types = [
             "npy_" + t
             for t in [
@@ -2523,15 +2532,104 @@ def gen_num(typen):
         return code

     def c_code(self, node, name, input_names, output_names, sub):
-        if numpy_version < "1.8.0" or using_numpy_2:
-            raise NotImplementedError
-
         x, y, idx = input_names
-        out = output_names[0]
+        [out] = output_names
         copy_of_x = self.copy_of_x(x)
         params = sub["params"]
         fail = sub["fail"]

+        x_, y_, idx_ = node.inputs
+        y_dtype = y_.type.dtype_specs()[1]
+        idx_dtype = idx_.type.dtype_specs()[1]
+        out_dtype = node.outputs[0].type.dtype_specs()[1]
+        y_bcast = y_.type.broadcastable != idx_.type.broadcastable
+        if (
+            x_.type.ndim == 1
+            and x_.type.dtype not in complex_dtypes
+            and not y_bcast
+            and y_.type.dtype not in complex_dtypes
+        ):
+            # Simple implementation for vector x, y cases
+            idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
+            idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
+            shape0 = x_.type.shape[0]
+            # This is used to make sure that when we trust the indices to be valid
+            # we are not fooled by a wrong static shape
+            unexpected_shape0 = (
+                f"PyArray_SHAPE({x})[0] != {shape0}" if shape0 is not None else "0"
+            )
+
+            op = "=" if self.set_instead_of_inc else "+="
+            code = f"""
+            if ({params}->inplace)
+            {{
+                if ({x} != {out})
+                {{
+                    Py_XDECREF({out});
+                    Py_INCREF({x});
+                    {out} = {x};
+                }}
+            }}
+            else
+            {{
+                Py_XDECREF({out});
+                {out} = {copy_of_x};
+                if (!{out}) {{
+                    // Exception already set
+                    {fail}
+                }}
+            }}
+
+            if ((PyArray_NDIM({out}) != 1) || ({unexpected_shape0})) {{
+                PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) does not have the right shape or ndim");
+                {fail}
+            }}
+            if (PyArray_NDIM({idx}) != 1) {{
+                PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim != 1");
+                {fail}
+            }}
+            if ((PyArray_NDIM({y}) != 1) || (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0])) {{
+                if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{
+                    PyErr_SetString(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
+                }} else {{
+                    PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match");
+                }}
+                {fail}
+            }}
+
+            {{
+                npy_intp out_shape0 = PyArray_SHAPE({out})[0];
+                {out_dtype}* out_data = ({out_dtype}*)PyArray_DATA({out});
+                {y_dtype}* y_data = ({y_dtype}*)PyArray_DATA({y});
+                {idx_dtype}* idx_data = ({idx_dtype}*)PyArray_DATA({idx});
+                npy_intp n = PyArray_SHAPE({idx})[0];
+                npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
+                npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
+                npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
+
+                for(int i = 0; i < n; i++){{
+                    {idx_dtype} idx = idx_data[i * idx_jump];
+                    if ({int(idx_may_be_neg)}){{
+                        if (idx < 0) {{
+                            idx += out_shape0;
+                        }}
+                    }}
+                    if ({int(idx_may_be_invalid)}){{
+                        if ((idx < 0) || (idx >= out_shape0)) {{
+                            PyErr_Format(PyExc_IndexError,"index out of bounds");
+                            {fail}
+                        }}
+                    }}
+                    out_data[idx * out_jump] {op} y_data[i * y_jump];
+                }}
+
+            }}
+            """
+            return code
+
+        if numpy_version < "1.8.0" or using_numpy_2:
+            raise NotImplementedError
+
         return f"""
         PyObject* rval = NULL;
         if ({params}->inplace)
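
In plain Python terms, the specialized branch emits one strided C loop for the 1-D, non-complex, non-broadcasting case: alias or copy x, validate the shapes, wrap negative indices, optionally bounds-check, then set or accumulate element by element. A rough NumPy reference model of what the generated C computes (illustrative only, not code from the commit):

    import numpy as np

    def vector_inc_subtensor1_ref(x, y, idx, inplace=False, set_instead_of_inc=False):
        out = x if inplace else x.copy()          # mirrors the params->inplace branch
        if out.ndim != 1 or idx.ndim != 1:
            raise ValueError("expected 1-d x and indices")
        if y.ndim != 1 or y.shape[0] != idx.shape[0]:
            if y.ndim == 1 and y.shape[0] == 1:
                raise ValueError("Runtime broadcasting not allowed.")
            raise ValueError("Shapes of second input (y) and indices do not match")
        n = out.shape[0]
        for i in range(idx.shape[0]):
            j = idx[i]
            if j < 0:                             # wrap negative indices
                j += n
            if j < 0 or j >= n:                   # bounds check
                raise IndexError("index out of bounds")
            if set_instead_of_inc:
                out[j] = y[i]
            else:
                out[j] += y[i]                    # duplicate indices accumulate
        return out

    print(vector_inc_subtensor1_ref(np.zeros(5), np.ones(3), np.array([0, -1, 0])))
    # [2. 0. 0. 0. 1.]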
@@ -2559,22 +2657,43 @@ def c_code(self, node, name, input_names, output_names, sub):
         """

     def c_code_cache_version(self):
-        return (8,)
+        return (9,)
+
+    def _check_runtime_broadcasting(self, node, x, y, idx):
+        if y.ndim > 0:
+            y_pt_bcast = node.inputs[1].broadcastable
+
+            if not y_pt_bcast[0] and y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
+                # Attempting to broadcast with index
+                raise ValueError(self._runtime_broadcast_error_msg)
+            if any(
+                not y_bcast and y_dim == 1 and y_dim != x_dim
+                for y_bcast, y_dim, x_dim in zip(
+                    reversed(y_pt_bcast),
+                    reversed(y.shape),
+                    reversed(x.shape),
+                    strict=False,
+                )
+            ):
+                # Attempting to broadcast with buffer
+                raise ValueError(self._runtime_broadcast_error_msg)
+
+    def perform(self, node, inputs, output_storage):
+        x, y, idx = inputs

-    def perform(self, node, inp, out_):
-        x, y, idx = inp
-        (out,) = out_
         if not self.inplace:
             x = x.copy()

+        self._check_runtime_broadcasting(node, x, y, idx)
+
         if self.set_instead_of_inc:
             x[idx] = y
         else:
             # In Numpy, `x[idx] += y` doesn't work if the same index is present
             # many times: it does it only once.
             np.add.at(x, idx, y)

-        out[0] = x
+        output_storage[0][0] = x

     def infer_shape(self, fgraph, node, ishapes):
         x, y, ilist = ishapes
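
The two branches of _check_runtime_broadcasting distinguish broadcasting against the index dimension from broadcasting against the buffer's trailing dimensions. Illustrative shapes only (assuming y was not declared broadcastable in those dimensions):

    import numpy as np

    x = np.zeros((2, 3))
    idx = np.array([0, 1, 1, 0, 1])   # five updates into x's first dimension

    y_ok = np.ones((5, 3))            # matches len(idx) and x's row shape: accepted
    y_index = np.ones((1, 3))         # "broadcast with index": dim 0 is 1, not 5
    y_buffer = np.ones((5, 1))        # "broadcast with buffer": trailing 1 vs x's 3

    # perform() uses np.add.at, so repeated indices accumulate; with the new check,
    # the y_index and y_buffer shapes are rejected instead of being broadcast.
    np.add.at(x, idx, y_ok)
    print(x)   # rows become [2. 2. 2.] and [3. 3. 3.]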
