Skip to content

Commit c1e5cda

Browse files
authoredApr 9, 2025··
MAINT: test_linalg: use xp.stack not xp.asarray for nested arrays (#363)
1 parent 1f041a3 commit c1e5cda

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed
 

Diff for: ‎array_api_tests/test_linalg.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,11 @@ def test_cross(x1_x2_kw):
206206

207207
def exact_cross(a, b):
208208
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
209-
return asarray([
209+
return asarray(xp.stack([
210210
a[1]*b[2] - a[2]*b[1],
211211
a[2]*b[0] - a[0]*b[2],
212212
a[0]*b[1] - a[1]*b[0],
213-
], dtype=res.dtype)
213+
]), dtype=res.dtype)
214214

215215
# We don't want to pass in **kw here because that would pass axis to
216216
# cross() on a single stack, but the axis is not meaningful on unstacked
@@ -267,7 +267,7 @@ def true_diag(x_stack, offset=0):
267267
x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
268268
else:
269269
x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
270-
return asarray(x_stack_diag, dtype=x.dtype)
270+
return asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype)
271271

272272
_test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
273273

@@ -901,7 +901,9 @@ def true_trace(x_stack, offset=0):
901901
x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
902902
else:
903903
x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
904-
return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype))
904+
result = xp.asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype)
905+
return _array_module.sum(result)
906+
905907

906908
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
907909

0 commit comments

Comments
 (0)