@@ -2120,16 +2120,12 @@ def make_node(self, x, ilist):
2120
2120
out_shape = (ilist_ .type .shape [0 ], * x_ .type .shape [1 :])
2121
2121
return Apply (self , [x_ , ilist_ ], [TensorType (dtype = x .dtype , shape = out_shape )()])
2122
2122
2123
def perform(self, node, inp, output_storage):
    """Python implementation: select entries of ``x`` along axis 0.

    Parameters
    ----------
    node
        The Apply node being evaluated (not used here).
    inp
        Pair ``(x, i)``: the array to index and the index vector.
    output_storage
        Output storage; the result is written to ``output_storage[0][0]``.
    """
    data, indices = inp

    # Whatever buffer is already sitting in the storage cell is deliberately
    # ignored: handing NumPy an ``out`` array makes ``take`` slower, see
    # https://github.com/numpy/numpy/issues/28636
    result_cell = output_storage[0]
    result_cell[0] = data.take(indices, axis=0)
2133
2129
2134
2130
def connection_pattern (self , node ):
2135
2131
rval = [[True ], * ([False ] for _ in node .inputs [1 :])]
@@ -2174,42 +2170,83 @@ def c_code(self, node, name, input_names, output_names, sub):
2174
2170
"c_code defined for AdvancedSubtensor1, not for child class" ,
2175
2171
type (self ),
2176
2172
)
2173
+ x , idxs = node .inputs
2174
+ if self ._idx_may_be_invalid (x , idxs ):
2175
+ mode = "NPY_RAISE"
2176
+ else :
2177
+ # We can know ahead of time that all indices are valid, so we can use a faster mode
2178
+ mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP
2179
+
2177
2180
a_name , i_name = input_names [0 ], input_names [1 ]
2178
2181
output_name = output_names [0 ]
2179
2182
fail = sub ["fail" ]
2180
- return f"""
2181
- if ({ output_name } != NULL) {{
2182
- npy_intp nd, i, *shape;
2183
- nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM({ i_name } ) - 1;
2184
- if (PyArray_NDIM({ output_name } ) != nd) {{
2183
+ if mode == "NPY_RAISE" :
2184
+ # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
2185
+ # We can remove this special case after https://github.com/numpy/numpy/issues/28636
2186
+ manage_pre_allocated_out = f"""
2187
+ if ({ output_name } != NULL) {{
2188
+ // Numpy TakeFrom is always slower when copying
2189
+ // https://github.com/numpy/numpy/issues/28636
2185
2190
Py_CLEAR({ output_name } );
2186
2191
}}
2187
- else {{
2188
- shape = PyArray_DIMS( { output_name } );
2189
- for (i = 0; i < PyArray_NDIM( { i_name } ); i++) {{
2190
- if (shape[i] != PyArray_DIMS( { i_name } )[i] ) {{
2191
- Py_CLEAR( { output_name } ) ;
2192
- break;
2193
- }}
2192
+ """
2193
+ else :
2194
+ manage_pre_allocated_out = f"""
2195
+ if ({ output_name } != NULL ) {{
2196
+ npy_intp nd = PyArray_NDIM( { a_name } ) + PyArray_NDIM( { i_name } ) - 1 ;
2197
+ if (PyArray_NDIM( { output_name } ) != nd) {{
2198
+ Py_CLEAR( { output_name } );
2194
2199
}}
2195
- if ({ output_name } != NULL) {{
2196
- for (; i < nd; i++) {{
2197
- if (shape[i] != PyArray_DIMS({ a_name } )[
2198
- i-PyArray_NDIM({ i_name } )+1]) {{
2200
+ else {{
2201
+ int i;
2202
+ npy_intp* shape = PyArray_DIMS({ output_name } );
2203
+ for (i = 0; i < PyArray_NDIM({ i_name } ); i++) {{
2204
+ if (shape[i] != PyArray_DIMS({ i_name } )[i]) {{
2199
2205
Py_CLEAR({ output_name } );
2200
2206
break;
2201
2207
}}
2202
2208
}}
2209
+ if ({ output_name } != NULL) {{
2210
+ for (; i < nd; i++) {{
2211
+ if (shape[i] != PyArray_DIMS({ a_name } )[i-PyArray_NDIM({ i_name } )+1]) {{
2212
+ Py_CLEAR({ output_name } );
2213
+ break;
2214
+ }}
2215
+ }}
2216
+ }}
2203
2217
}}
2204
2218
}}
2205
- }}
2219
+ """
2220
+
2221
+ return f"""
2222
+ { manage_pre_allocated_out }
2206
2223
{ output_name } = (PyArrayObject*)PyArray_TakeFrom(
2207
- { a_name } , (PyObject*){ i_name } , 0, { output_name } , NPY_RAISE );
2224
+ { a_name } , (PyObject*){ i_name } , 0, { output_name } , { mode } );
2208
2225
if ({ output_name } == NULL) { fail } ;
2209
2226
"""
2210
2227
2211
2228
def c_code_cache_version(self):
    """Return the version tuple associated with this op's generated C code."""
    version = (5,)
    return version
2230
+
2231
@staticmethod
def _idx_may_be_invalid(x, idx) -> bool:
    """Tell whether ``idx`` could contain an out-of-bounds index for ``x``.

    The answer is conservative: ``False`` is returned only when every index
    is provably valid from static information. ``c_code`` relies on this to
    choose between the bounds-checking ``NPY_RAISE`` mode and the faster
    ``NPY_WRAP`` mode, so a wrong ``False`` would turn an out-of-bounds
    access into a silent wrap-around.

    Parameters
    ----------
    x
        Variable being indexed; only ``x.type.shape[0]`` is inspected.
    idx
        Index variable; ``idx.type.shape[0]`` is inspected and, when ``idx``
        is a ``Constant``, the min/max of ``idx.data``.

    Returns
    -------
    bool
        ``True`` if an invalid index cannot be ruled out, ``False`` if all
        indices are known to be within bounds.
    """
    if idx.type.shape[0] == 0:
        # Empty index is always valid
        return False

    if x.type.shape[0] is None:
        # We can't know if an index is valid if we don't know the length of x
        return True

    if not isinstance(idx, Constant):
        # This is conservative, but we don't try to infer lower/upper bound symbolically
        return True

    shape0 = x.type.shape[0]
    min_idx, max_idx = idx.data.min(), idx.data.max()
    # Valid indices satisfy -shape0 <= idx < shape0. BOTH ends must be checked:
    # the previous expression `not (lower_ok) and (upper_ok)` parsed as
    # `(not lower_ok) and upper_ok`, so an index with max_idx >= shape0 was
    # wrongly reported as always valid.
    return bool(min_idx < -shape0 or max_idx >= shape0)
2213
2250
2214
2251
2215
2252
# Module-level instance of AdvancedSubtensor1, created with no constructor arguments.
advanced_subtensor1 = AdvancedSubtensor1 ()
0 commit comments