@@ -206,11 +206,11 @@ def test_cross(x1_x2_kw):
206
206
207
207
def exact_cross (a , b ):
208
208
assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
209
- return asarray ([
209
+ return asarray (xp . stack ( [
210
210
a [1 ]* b [2 ] - a [2 ]* b [1 ],
211
211
a [2 ]* b [0 ] - a [0 ]* b [2 ],
212
212
a [0 ]* b [1 ] - a [1 ]* b [0 ],
213
- ], dtype = res .dtype )
213
+ ]) , dtype = res .dtype )
214
214
215
215
# We don't want to pass in **kw here because that would pass axis to
216
216
# cross() on a single stack, but the axis is not meaningful on unstacked
@@ -267,7 +267,7 @@ def true_diag(x_stack, offset=0):
267
267
x_stack_diag = [x_stack [i , i + offset ] for i in range (diag_size )]
268
268
else :
269
269
x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
270
- return asarray (x_stack_diag , dtype = x .dtype )
270
+ return asarray (xp . stack ( x_stack_diag ) if x_stack_diag else [] , dtype = x .dtype )
271
271
272
272
_test_stacks (linalg .diagonal , x , ** kw , res = res , dims = 1 , true_val = true_diag )
273
273
@@ -901,7 +901,9 @@ def true_trace(x_stack, offset=0):
901
901
x_stack_diag = [x_stack [i , i + offset ] for i in range (diag_size )]
902
902
else :
903
903
x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
904
- return _array_module .sum (asarray (x_stack_diag , dtype = x .dtype ))
904
+ result = xp .asarray (xp .stack (x_stack_diag ) if x_stack_diag else [], dtype = x .dtype )
905
+ return _array_module .sum (result )
906
+
905
907
906
908
_test_stacks (linalg .trace , x , ** kw , res = res , dims = 0 , true_val = true_trace )
907
909
0 commit comments