@@ -51,43 +51,47 @@ def test_argsort(x, data):
5151 label = "kw" ,
5252 )
5353
54- out = xp .argsort (x , ** kw )
55-
56- ph .assert_default_index ("argsort" , out .dtype )
57- ph .assert_shape ("argsort" , out_shape = out .shape , expected = x .shape , kw = kw )
58- axis = kw .get ("axis" , - 1 )
59- axes = sh .normalize_axis (axis , x .ndim )
60- scalar_type = dh .get_scalar_type (x .dtype )
61- for indices in sh .axes_ndindex (x .shape , axes ):
62- elements = [scalar_type (x [idx ]) for idx in indices ]
63- orders = list (range (len (elements )))
64- sorders = sorted (
65- orders , key = elements .__getitem__ , reverse = kw .get ("descending" , False )
66- )
67- if kw .get ("stable" , True ):
68- for idx , o in zip (indices , sorders ):
69- ph .assert_scalar_equals ("argsort" , type_ = int , idx = idx , out = int (out [idx ]), expected = o , kw = kw )
70- else :
71- idx_elements = dict (zip (indices , elements ))
72- idx_orders = dict (zip (indices , orders ))
73- element_orders = {}
74- for e in set (elements ):
75- element_orders [e ] = [
76- idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
77- ]
78- selements = [elements [o ] for o in sorders ]
79- for idx , e in zip (indices , selements ):
80- expected_orders = element_orders [e ]
81- out_o = int (out [idx ])
82- if len (expected_orders ) == 1 :
83- ph .assert_scalar_equals (
84- "argsort" , type_ = int , idx = idx , out = out_o , expected = expected_orders [0 ], kw = kw
85- )
86- else :
87- assert_scalar_in_set (
88- "argsort" , idx = idx , out = out_o , set_ = set (expected_orders ), kw = kw
89- )
54+ repro_snippet = ph .format_snippet (f"xp.argsort({ x !r} , **kw) with { kw = } " )
55+ try :
56+ out = xp .argsort (x , ** kw )
9057
58+ ph .assert_default_index ("argsort" , out .dtype )
59+ ph .assert_shape ("argsort" , out_shape = out .shape , expected = x .shape , kw = kw )
60+ axis = kw .get ("axis" , - 1 )
61+ axes = sh .normalize_axis (axis , x .ndim )
62+ scalar_type = dh .get_scalar_type (x .dtype )
63+ for indices in sh .axes_ndindex (x .shape , axes ):
64+ elements = [scalar_type (x [idx ]) for idx in indices ]
65+ orders = list (range (len (elements )))
66+ sorders = sorted (
67+ orders , key = elements .__getitem__ , reverse = kw .get ("descending" , False )
68+ )
69+ if kw .get ("stable" , True ):
70+ for idx , o in zip (indices , sorders ):
71+ ph .assert_scalar_equals ("argsort" , type_ = int , idx = idx , out = int (out [idx ]), expected = o , kw = kw )
72+ else :
73+ idx_elements = dict (zip (indices , elements ))
74+ idx_orders = dict (zip (indices , orders ))
75+ element_orders = {}
76+ for e in set (elements ):
77+ element_orders [e ] = [
78+ idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
79+ ]
80+ selements = [elements [o ] for o in sorders ]
81+ for idx , e in zip (indices , selements ):
82+ expected_orders = element_orders [e ]
83+ out_o = int (out [idx ])
84+ if len (expected_orders ) == 1 :
85+ ph .assert_scalar_equals (
86+ "argsort" , type_ = int , idx = idx , out = out_o , expected = expected_orders [0 ], kw = kw
87+ )
88+ else :
89+ assert_scalar_in_set (
90+ "argsort" , idx = idx , out = out_o , set_ = set (expected_orders ), kw = kw
91+ )
92+ except Exception as exc :
93+ exc .add_note (repro_snippet )
94+ raise
9195
9296@pytest .mark .unvectorized
9397# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -112,27 +116,32 @@ def test_sort(x, data):
112116 label = "kw" ,
113117 )
114118
115- out = xp .sort (x , ** kw )
119+ repro_snippet = ph .format_snippet (f"xp.sort({ x !r} , **kw) with { kw = } " )
120+ try :
121+ out = xp .sort (x , ** kw )
116122
117- ph .assert_dtype ("sort" , out_dtype = out .dtype , in_dtype = x .dtype )
118- ph .assert_shape ("sort" , out_shape = out .shape , expected = x .shape , kw = kw )
119- axis = kw .get ("axis" , - 1 )
120- axes = sh .normalize_axis (axis , x .ndim )
121- scalar_type = dh .get_scalar_type (x .dtype )
122- for indices in sh .axes_ndindex (x .shape , axes ):
123- elements = [scalar_type (x [idx ]) for idx in indices ]
124- size = len (elements )
125- orders = sorted (
126- range (size ), key = elements .__getitem__ , reverse = kw .get ("descending" , False )
127- )
128- for out_idx , o in zip (indices , orders ):
129- x_idx = indices [o ]
130- # TODO: error message when unstable should not imply just one idx
131- ph .assert_0d_equals (
132- "sort" ,
133- x_repr = f"x[{ x_idx } ]" ,
134- x_val = x [x_idx ],
135- out_repr = f"out[{ out_idx } ]" ,
136- out_val = out [out_idx ],
137- kw = kw ,
123+ ph .assert_dtype ("sort" , out_dtype = out .dtype , in_dtype = x .dtype )
124+ ph .assert_shape ("sort" , out_shape = out .shape , expected = x .shape , kw = kw )
125+ axis = kw .get ("axis" , - 1 )
126+ axes = sh .normalize_axis (axis , x .ndim )
127+ scalar_type = dh .get_scalar_type (x .dtype )
128+ for indices in sh .axes_ndindex (x .shape , axes ):
129+ elements = [scalar_type (x [idx ]) for idx in indices ]
130+ size = len (elements )
131+ orders = sorted (
132+ range (size ), key = elements .__getitem__ , reverse = kw .get ("descending" , False )
138133 )
134+ for out_idx , o in zip (indices , orders ):
135+ x_idx = indices [o ]
136+ # TODO: error message when unstable should not imply just one idx
137+ ph .assert_0d_equals (
138+ "sort" ,
139+ x_repr = f"x[{ x_idx } ]" ,
140+ x_val = x [x_idx ],
141+ out_repr = f"out[{ out_idx } ]" ,
142+ out_val = out [out_idx ],
143+ kw = kw ,
144+ )
145+ except Exception as exc :
146+ exc .add_note (repro_snippet )
147+ raise
0 commit comments