@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
2262
2262
# NOTE(review): presumably opts this Op out of generic C input validation;
# confirm against the COp base class.
check_input = False

# Flags handed to the generated C code, which reads `params->inplace`.
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)

# Single source of truth for the runtime-broadcast error text, so that the
# Python and C implementations report the exact same message.
_runtime_broadcast_error_msg = (
    "Runtime broadcasting not allowed. "
    "AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
    "If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
)
2270
+
2265
2271
def __init__ (self , inplace = False , set_instead_of_inc = False ):
2266
2272
self .inplace = bool (inplace )
2267
2273
self .set_instead_of_inc = bool (set_instead_of_inc )
@@ -2333,6 +2339,9 @@ def copy_of_x(self, x):
2333
2339
NPY_ARRAY_ENSURECOPY, NULL)"""
2334
2340
2335
2341
def c_support_code (self , ** kwargs ):
2342
+ if numpy_version < "1.8.0" or using_numpy_2 :
2343
+ return None
2344
+
2336
2345
types = [
2337
2346
"npy_" + t
2338
2347
for t in [
@@ -2523,15 +2532,104 @@ def gen_num(typen):
2523
2532
return code
2524
2533
2525
2534
def c_code (self , node , name , input_names , output_names , sub ):
2526
- if numpy_version < "1.8.0" or using_numpy_2 :
2527
- raise NotImplementedError
2528
-
2529
2535
x , y , idx = input_names
2530
- out = output_names [ 0 ]
2536
+ [ out ] = output_names
2531
2537
copy_of_x = self .copy_of_x (x )
2532
2538
params = sub ["params" ]
2533
2539
fail = sub ["fail" ]
2534
2540
2541
+ x_ , y_ , idx_ = node .inputs
2542
+ y_dtype = y_ .type .dtype_specs ()[1 ]
2543
+ idx_dtype = idx_ .type .dtype_specs ()[1 ]
2544
+ out_dtype = node .outputs [0 ].type .dtype_specs ()[1 ]
2545
+ y_bcast = y_ .type .broadcastable != idx_ .type .broadcastable
2546
+ if (
2547
+ x_ .type .ndim == 1
2548
+ and x_ .type .dtype not in complex_dtypes
2549
+ and not y_bcast
2550
+ and y_ .type .dtype not in complex_dtypes
2551
+ ):
2552
+ # Simple implementation for vector x, y cases
2553
+ idx_may_be_neg = not (isinstance (idx_ , Constant ) and idx_ .data .min () >= 0 )
2554
+ idx_may_be_invalid = AdvancedSubtensor1 ._idx_may_be_invalid (x_ , idx_ )
2555
+ shape0 = x_ .type .shape [0 ]
2556
+ # This is used to make sure that when we trust the indices to be valid
2557
+ # we are not fooled by a wrong static shape
2558
+ unexpected_shape0 = (
2559
+ f"PyArray_SHAPE({ x } )[0] != { shape0 } " if shape0 is not None else "0"
2560
+ )
2561
+
2562
+ op = "=" if self .set_instead_of_inc else "+="
2563
+ code = f"""
2564
+ if ({ params } ->inplace)
2565
+ {{
2566
+ if ({ x } != { out } )
2567
+ {{
2568
+ Py_XDECREF({ out } );
2569
+ Py_INCREF({ x } );
2570
+ { out } = { x } ;
2571
+ }}
2572
+ }}
2573
+ else
2574
+ {{
2575
+ Py_XDECREF({ out } );
2576
+ { out } = { copy_of_x } ;
2577
+ if (!{ out } ) {{
2578
+ // Exception already set
2579
+ { fail }
2580
+ }}
2581
+ }}
2582
+
2583
+ if ((PyArray_NDIM({ out } ) != 1) || ({ unexpected_shape0 } )) {{
2584
+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");
2585
+ { fail }
2586
+ }}
2587
+ if (PyArray_NDIM({ idx } ) != 1) {{
2588
+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim != 1");
2589
+ { fail }
2590
+ }}
2591
+ if ((PyArray_NDIM({ y } ) != 1) || (PyArray_SHAPE({ y } )[0] != PyArray_SHAPE({ idx } )[0])) {{
2592
+ if ((PyArray_NDIM({ y } ) == 1) && (PyArray_SHAPE({ y } )[0] == 1)){{
2593
+ PyErr_SetString(PyExc_ValueError, "{ self ._runtime_broadcast_error_msg } ");
2594
+ }} else {{
2595
+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match");
2596
+ }}
2597
+ { fail }
2598
+ }}
2599
+
2600
+ {{
2601
+ npy_intp out_shape0 = PyArray_SHAPE({ out } )[0];
2602
+ { out_dtype } * out_data = ({ out_dtype } *)PyArray_DATA({ out } );
2603
+ { y_dtype } * y_data = ({ y_dtype } *)PyArray_DATA({ y } );
2604
+ { idx_dtype } * idx_data = ({ idx_dtype } *)PyArray_DATA({ idx } );
2605
+ npy_intp n = PyArray_SHAPE({ idx } )[0];
2606
+ npy_intp out_jump = PyArray_STRIDES({ out } )[0] / PyArray_ITEMSIZE({ out } );
2607
+ npy_intp y_jump = PyArray_STRIDES({ y } )[0] / PyArray_ITEMSIZE({ y } );
2608
+ npy_intp idx_jump = PyArray_STRIDES({ idx } )[0] / PyArray_ITEMSIZE({ idx } );
2609
+
2610
+ for(int i = 0; i < n; i++){{
2611
+ { idx_dtype } idx = idx_data[i * idx_jump];
2612
+ if ({ int (idx_may_be_neg )} ){{
2613
+ if (idx < 0) {{
2614
+ idx += out_shape0;
2615
+ }}
2616
+ }}
2617
+ if ({ int (idx_may_be_invalid )} ){{
2618
+ if ((idx < 0) || (idx >= out_shape0)) {{
2619
+ PyErr_Format(PyExc_IndexError,"index out of bounds");
2620
+ { fail }
2621
+ }}
2622
+ }}
2623
+ out_data[idx * out_jump] { op } y_data[i * y_jump];
2624
+ }}
2625
+
2626
+ }}
2627
+ """
2628
+ return code
2629
+
2630
+ if numpy_version < "1.8.0" or using_numpy_2 :
2631
+ raise NotImplementedError
2632
+
2535
2633
return f"""
2536
2634
PyObject* rval = NULL;
2537
2635
if ({ params } ->inplace)
@@ -2559,22 +2657,43 @@ def c_code(self, node, name, input_names, output_names, sub):
2559
2657
"""
2560
2658
2561
2659
def c_code_cache_version(self):
    """Return the version tag of the generated C code.

    The tag is a 1-tuple; it must be bumped whenever the text produced by
    ``c_code`` changes so that previously compiled modules are not reused.
    """
    cache_version = (9,)
    return cache_version
2661
+
2662
+ def _check_runtime_broadcasting (self , node , x , y , idx ):
2663
+ if y .ndim > 0 :
2664
+ y_pt_bcast = node .inputs [1 ].broadcastable
2665
+
2666
+ if not y_pt_bcast [0 ] and y .shape [0 ] == 1 and y .shape [0 ] != idx .shape [0 ]:
2667
+ # Attempting to broadcast with index
2668
+ raise ValueError (self ._runtime_broadcast_error_msg )
2669
+ if any (
2670
+ not y_bcast and y_dim == 1 and y_dim != x_dim
2671
+ for y_bcast , y_dim , x_dim in zip (
2672
+ reversed (y_pt_bcast ),
2673
+ reversed (y .shape ),
2674
+ reversed (x .shape ),
2675
+ strict = False ,
2676
+ )
2677
+ ):
2678
+ # Attempting to broadcast with buffer
2679
+ raise ValueError (self ._runtime_broadcast_error_msg )
2680
+
2681
def perform(self, node, inputs, output_storage):
    """Python implementation: set or accumulate ``y`` into ``x`` at ``idx``.

    Parameters
    ----------
    node
        The node being evaluated; forwarded to the runtime-broadcast check.
    inputs
        The triple ``(x, y, idx)``.
    output_storage
        Storage cells of the outputs; the result is stored in
        ``output_storage[0][0]``.

    When ``self.inplace`` is false the update is applied to a copy of ``x``,
    otherwise ``x`` itself is modified and returned.
    """
    buffer, values, indices = inputs

    result = buffer if self.inplace else buffer.copy()

    self._check_runtime_broadcasting(node, result, values, indices)

    if not self.set_instead_of_inc:
        # `result[indices] += values` would apply a repeated index only
        # once; `np.add.at` accumulates every occurrence.
        np.add.at(result, indices, values)
    else:
        result[indices] = values

    output_storage[0][0] = result
2578
2697
2579
2698
def infer_shape (self , fgraph , node , ishapes ):
2580
2699
x , y , ilist = ishapes
0 commit comments