Skip to content

Commit bd42e84

Browse files
committed
ENH: add "repro snippets" to test_sorting_functions.py
1 parent e807ffe commit bd42e84

File tree

1 file changed

+67
-58
lines changed

1 file changed

+67
-58
lines changed

array_api_tests/test_sorting_functions.py

Lines changed: 67 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -51,43 +51,47 @@ def test_argsort(x, data):
5151
label="kw",
5252
)
5353

54-
out = xp.argsort(x, **kw)
55-
56-
ph.assert_default_index("argsort", out.dtype)
57-
ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw)
58-
axis = kw.get("axis", -1)
59-
axes = sh.normalize_axis(axis, x.ndim)
60-
scalar_type = dh.get_scalar_type(x.dtype)
61-
for indices in sh.axes_ndindex(x.shape, axes):
62-
elements = [scalar_type(x[idx]) for idx in indices]
63-
orders = list(range(len(elements)))
64-
sorders = sorted(
65-
orders, key=elements.__getitem__, reverse=kw.get("descending", False)
66-
)
67-
if kw.get("stable", True):
68-
for idx, o in zip(indices, sorders):
69-
ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw)
70-
else:
71-
idx_elements = dict(zip(indices, elements))
72-
idx_orders = dict(zip(indices, orders))
73-
element_orders = {}
74-
for e in set(elements):
75-
element_orders[e] = [
76-
idx_orders[idx] for idx in indices if idx_elements[idx] == e
77-
]
78-
selements = [elements[o] for o in sorders]
79-
for idx, e in zip(indices, selements):
80-
expected_orders = element_orders[e]
81-
out_o = int(out[idx])
82-
if len(expected_orders) == 1:
83-
ph.assert_scalar_equals(
84-
"argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw
85-
)
86-
else:
87-
assert_scalar_in_set(
88-
"argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw
89-
)
54+
repro_snippet = ph.format_snippet(f"xp.argsort({x!r}, **kw) with {kw = }")
55+
try:
56+
out = xp.argsort(x, **kw)
9057

58+
ph.assert_default_index("argsort", out.dtype)
59+
ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw)
60+
axis = kw.get("axis", -1)
61+
axes = sh.normalize_axis(axis, x.ndim)
62+
scalar_type = dh.get_scalar_type(x.dtype)
63+
for indices in sh.axes_ndindex(x.shape, axes):
64+
elements = [scalar_type(x[idx]) for idx in indices]
65+
orders = list(range(len(elements)))
66+
sorders = sorted(
67+
orders, key=elements.__getitem__, reverse=kw.get("descending", False)
68+
)
69+
if kw.get("stable", True):
70+
for idx, o in zip(indices, sorders):
71+
ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw)
72+
else:
73+
idx_elements = dict(zip(indices, elements))
74+
idx_orders = dict(zip(indices, orders))
75+
element_orders = {}
76+
for e in set(elements):
77+
element_orders[e] = [
78+
idx_orders[idx] for idx in indices if idx_elements[idx] == e
79+
]
80+
selements = [elements[o] for o in sorders]
81+
for idx, e in zip(indices, selements):
82+
expected_orders = element_orders[e]
83+
out_o = int(out[idx])
84+
if len(expected_orders) == 1:
85+
ph.assert_scalar_equals(
86+
"argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw
87+
)
88+
else:
89+
assert_scalar_in_set(
90+
"argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw
91+
)
92+
except Exception as exc:
93+
exc.add_note(repro_snippet)
94+
raise
9195

9296
@pytest.mark.unvectorized
9397
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -112,27 +116,32 @@ def test_sort(x, data):
112116
label="kw",
113117
)
114118

115-
out = xp.sort(x, **kw)
119+
repro_snippet = ph.format_snippet(f"xp.sort({x!r}, **kw) with {kw = }")
120+
try:
121+
out = xp.sort(x, **kw)
116122

117-
ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype)
118-
ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw)
119-
axis = kw.get("axis", -1)
120-
axes = sh.normalize_axis(axis, x.ndim)
121-
scalar_type = dh.get_scalar_type(x.dtype)
122-
for indices in sh.axes_ndindex(x.shape, axes):
123-
elements = [scalar_type(x[idx]) for idx in indices]
124-
size = len(elements)
125-
orders = sorted(
126-
range(size), key=elements.__getitem__, reverse=kw.get("descending", False)
127-
)
128-
for out_idx, o in zip(indices, orders):
129-
x_idx = indices[o]
130-
# TODO: error message when unstable should not imply just one idx
131-
ph.assert_0d_equals(
132-
"sort",
133-
x_repr=f"x[{x_idx}]",
134-
x_val=x[x_idx],
135-
out_repr=f"out[{out_idx}]",
136-
out_val=out[out_idx],
137-
kw=kw,
123+
ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype)
124+
ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw)
125+
axis = kw.get("axis", -1)
126+
axes = sh.normalize_axis(axis, x.ndim)
127+
scalar_type = dh.get_scalar_type(x.dtype)
128+
for indices in sh.axes_ndindex(x.shape, axes):
129+
elements = [scalar_type(x[idx]) for idx in indices]
130+
size = len(elements)
131+
orders = sorted(
132+
range(size), key=elements.__getitem__, reverse=kw.get("descending", False)
138133
)
134+
for out_idx, o in zip(indices, orders):
135+
x_idx = indices[o]
136+
# TODO: error message when unstable should not imply just one idx
137+
ph.assert_0d_equals(
138+
"sort",
139+
x_repr=f"x[{x_idx}]",
140+
x_val=x[x_idx],
141+
out_repr=f"out[{out_idx}]",
142+
out_val=out[out_idx],
143+
kw=kw,
144+
)
145+
except Exception as exc:
146+
exc.add_note(repro_snippet)
147+
raise

0 commit comments

Comments
 (0)