diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index f6b7ae25..98d16f27 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -599,3 +599,7 @@ def assert_array_elements( at_expected = expected[idx] msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) assert at_out == at_expected, msg + + +def format_snippet(s: str): + return f"\n{'='*10} FAILING CODE SNIPPET:\n{s}\n{'='*20}\n" diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 4d4af350..007b3179 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -85,24 +85,29 @@ def test_getitem(shape, dtype, data): note(f"{x=}") key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") - out = x[key] - - ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) - _key = normalize_key(key, shape) - axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape) - ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) - out_zero_sided = any(side == 0 for side in expected_shape) - if not zero_sided and not out_zero_sided: - out_obj = [] - for idx in product(*axes_indices): - val = obj - for i in idx: - val = val[i] - out_obj.append(val) - out_obj = sh.reshape(out_obj, expected_shape) - expected = xp.asarray(out_obj, dtype=dtype) - ph.assert_array_elements("__getitem__", out=out, expected=expected) - + repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]") + + try: + out = x[key] + + ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) + _key = normalize_key(key, shape) + axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape) + ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) + out_zero_sided = any(side == 0 for side in expected_shape) + if not zero_sided and not out_zero_sided: + out_obj = [] + for idx in product(*axes_indices): + val = obj + for i in idx: + val = val[i] + out_obj.append(val) + out_obj = sh.reshape(out_obj, expected_shape) + expected = xp.asarray(out_obj, dtype=dtype) + ph.assert_array_elements("__getitem__", out=out, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @given( diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 84e6f34c..e50b621e 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -160,13 +160,17 @@ def test_finfo(dtype): # np.float64 and np.asarray(1, dtype=np.float64).dtype are different xp.asarray(1, dtype=dtype).dtype, ): - out = xp.finfo(arg) - assert isinstance(out.bits, int) - assert isinstance(out.eps, float) - assert isinstance(out.max, float) - assert isinstance(out.min, float) - assert isinstance(out.smallest_normal, float) - + repro_snippet = ph.format_snippet(f"xp.finfo({arg})") + try: + out = xp.finfo(arg) + assert isinstance(out.bits, int) + assert isinstance(out.eps, float) + assert isinstance(out.max, float) + assert isinstance(out.min, float) + assert isinstance(out.smallest_normal, float) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 754b507d..bd0fd351 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -71,50 +71,55 @@ def test_concat(dtypes, base_shape, data): x = data.draw(hh.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") arrays.append(x) - out = xp.concat(arrays, **kw) + repro_snippet = ph.format_snippet(f"xp.concat({arrays!r}, **kw) with {kw = }") + try: + out = xp.concat(arrays, **kw) - ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype) + ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype) - shapes = tuple(x.shape for x in arrays) - if _axis is None: - size = sum(math.prod(s) for s in shapes) - shape = (size,) - else: - shape = list(shapes[0]) - for other_shape in shapes[1:]: - shape[_axis] += other_shape[_axis] - shape = tuple(shape) - ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw) - - if _axis is None: - out_indices = (i for i in range(math.prod(out.shape))) - for x_num, x in enumerate(arrays, 1): - for x_idx in sh.ndindex(x.shape): - out_i = next(out_indices) - ph.assert_0d_equals( - "concat", - x_repr=f"x{x_num}[{x_idx}]", - x_val=x[x_idx], - out_repr=f"out[{out_i}]", - out_val=out[out_i], - kw=kw, - ) - else: - out_indices = sh.ndindex(out.shape) - for idx in sh.axis_ndindex(shapes[0], _axis): - f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + shapes = tuple(x.shape for x in arrays) + if _axis is None: + size = sum(math.prod(s) for s in shapes) + shape = (size,) + else: + shape = list(shapes[0]) + for other_shape in shapes[1:]: + shape[_axis] += other_shape[_axis] + shape = tuple(shape) + ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw) + + if _axis is None: + out_indices = (i for i in range(math.prod(out.shape))) for x_num, x in enumerate(arrays, 1): - indexed_x = x[idx] - for x_idx in sh.ndindex(indexed_x.shape): - out_idx = next(out_indices) + for x_idx in sh.ndindex(x.shape): + out_i = next(out_indices) ph.assert_0d_equals( "concat", - x_repr=f"x{x_num}[{f_idx}][{x_idx}]", - x_val=indexed_x[x_idx], - out_repr=f"out[{out_idx}]", - out_val=out[out_idx], + x_repr=f"x{x_num}[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_i}]", + out_val=out[out_i], kw=kw, ) + else: + out_indices = sh.ndindex(out.shape) + for idx in sh.axis_ndindex(shapes[0], _axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "concat", + x_repr=f"x{x_num}[{f_idx}][{x_idx}]", + x_val=indexed_x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -131,19 +136,24 @@ def test_expand_dims(x, axis): xp.expand_dims(x, axis=axis) return - out = xp.expand_dims(x, axis=axis) + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})") + try: + out = xp.expand_dims(x, axis=axis) - ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) - shape = [side for side in x.shape] - index = axis if axis >= 0 else x.ndim + axis + 1 - shape.insert(index, 1) - shape = tuple(shape) - ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) - assert_array_ndindex( - "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) - ) + assert_array_ndindex( + "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2023.12") @@ -166,30 +176,35 @@ def test_moveaxis(x, data): label="destination" ) - out = xp.moveaxis(x, source, destination) + repro_snippet = ph.format_snippet(f"xp.moveaxis({x!r}, {source!r}, {destination!r})") + try: + out = xp.moveaxis(x, source, destination) - ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype) + _source = sh.normalize_axis(source, x.ndim) + _destination = sh.normalize_axis(destination, x.ndim) - _source = sh.normalize_axis(source, x.ndim) - _destination = sh.normalize_axis(destination, x.ndim) + new_axes = [n for n in range(x.ndim) if n not in _source] - new_axes = [n for n in range(x.ndim) if n not in _source] + for dest, src in sorted(zip(_destination, _source)): + new_axes.insert(dest, src) - for dest, src in sorted(zip(_destination, _source)): - new_axes.insert(dest, src) + expected_shape = tuple(x.shape[i] for i in new_axes) - expected_shape = tuple(x.shape[i] for i in new_axes) + ph.assert_result_shape("moveaxis", in_shapes=[x.shape], + out_shape=out.shape, expected=expected_shape, + kw={"source": source, "destination": destination}) - ph.assert_result_shape("moveaxis", in_shapes=[x.shape], - out_shape=out.shape, expected=expected_shape, - kw={"source": source, "destination": destination}) + indices = list(sh.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in new_axes) for idx in indices] + assert_array_ndindex( + "moveaxis", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=permuted_indices + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise - indices = list(sh.ndindex(x.shape)) - permuted_indices = [tuple(idx[axis] for axis in new_axes) for idx in indices] - assert_array_ndindex( - "moveaxis", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=permuted_indices - ) @pytest.mark.unvectorized @given( @@ -215,18 +230,23 @@ def test_squeeze(x, data): xp.squeeze(x, axis) return - out = xp.squeeze(x, axis) + repro_snippet = ph.format_snippet(f"xp.squeeze({x!r}, {axis!r})") + try: + out = xp.squeeze(x, axis) - ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype) - shape = [] - for i, side in enumerate(x.shape): - if i not in axes: - shape.append(side) - shape = tuple(shape) - ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis)) + shape = [] + for i, side in enumerate(x.shape): + if i not in axes: + shape.append(side) + shape = tuple(shape) + ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis)) - assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) + assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -243,15 +263,20 @@ def test_flip(x, data): ) kw = data.draw(hh.kwargs(axis=axis_strat), label="kw") - out = xp.flip(x, **kw) + repro_snippet = ph.format_snippet(f"xp.flip({x!r}, **kw) with {kw=}") + try: + out = xp.flip(x, **kw) - ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - for indices in sh.axes_ndindex(x.shape, _axes): - reverse_indices = indices[::-1] - assert_array_ndindex("flip", x, x_indices=indices, out=out, - out_indices=reverse_indices, kw=kw) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + for indices in sh.axes_ndindex(x.shape, _axes): + reverse_indices = indices[::-1] + assert_array_ndindex("flip", x, x_indices=indices, out=out, + out_indices=reverse_indices, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -267,22 +292,26 @@ def test_flip(x, data): ), ) def test_permute_dims(x, axes): - out = xp.permute_dims(x, axes) - - ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype) + repro_snippet = ph.format_snippet(f"xp.permute_dims({x!r},{axes!r})") + try: + out = xp.permute_dims(x, axes) - shape = [None for _ in range(len(axes))] - for i, dim in enumerate(axes): - side = x.shape[dim] - shape[i] = side - shape = tuple(shape) - ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes)) + ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype) - indices = list(sh.ndindex(x.shape)) - permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] - assert_array_ndindex("permute_dims", x, x_indices=indices, out=out, - out_indices=permuted_indices) + shape = [None for _ in range(len(axes))] + for i, dim in enumerate(axes): + side = x.shape[dim] + shape[i] = side + shape = tuple(shape) + ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes)) + indices = list(sh.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] + assert_array_ndindex("permute_dims", x, x_indices=indices, out=out, + out_indices=permuted_indices) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2023.12") @given( @@ -313,37 +342,45 @@ def test_repeat(x, kw, data): assume(n_repititions <= hh.SQRT_MAX_ARRAY_SIZE) - out = xp.repeat(x, repeats, **kw) - ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) - if axis is None: - expected_shape = (n_repititions,) - else: - expected_shape = list(shape) - expected_shape[axis] = n_repititions - expected_shape = tuple(expected_shape) - ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) + repro_snippet = ph.format_snippet(f"xp.repeat({x!r},{repeats!r}, **kw) with {kw=}") + try: + out = xp.repeat(x, repeats, **kw) - # Test values + ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) + if axis is None: + expected_shape = (n_repititions,) + else: + expected_shape = list(shape) + expected_shape[axis] = n_repititions + expected_shape = tuple(expected_shape) + ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) + + # Test values + + if isinstance(repeats, int): + repeats_array = xp.full(size, repeats, dtype=xp.int32) + else: + repeats_array = repeats + + if kw.get("axis") is None: + x = xp.reshape(x, (-1,)) + axis = 0 + + for idx, in sh.iter_indices(x.shape, skip_axes=axis): + x_slice = x[idx] + out_slice = out[idx] + start = 0 + for i, count in enumerate(repeats_array): + end = start + count + ph.assert_array_elements("repeat", out=out_slice[start:end], + expected=xp.full((count,), x_slice[i], dtype=x.dtype), + kw=kw) + start = end + + except Exception as exc: + exc.add_note(repro_snippet) + raise - if isinstance(repeats, int): - repeats_array = xp.full(size, repeats, dtype=xp.int32) - else: - repeats_array = repeats - - if kw.get("axis") is None: - x = xp.reshape(x, (-1,)) - axis = 0 - - for idx, in sh.iter_indices(x.shape, skip_axes=axis): - x_slice = x[idx] - out_slice = out[idx] - start = 0 - for i, count in enumerate(repeats_array): - end = start + count - ph.assert_array_elements("repeat", out=out_slice[start:end], - expected=xp.full((count,), x_slice[i], dtype=x.dtype), - kw=kw) - start = end reshape_shape = st.shared(hh.shapes(), key="reshape_shape") @@ -353,19 +390,24 @@ def test_repeat(x, kw, data): shape=hh.reshape_shapes(reshape_shape), ) def test_reshape(x, shape): - out = xp.reshape(x, shape) + repro_snippet = ph.format_snippet(f"xp.reshape({x!r},{shape!r})") + try: + out = xp.reshape(x, shape) - ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) - _shape = list(shape) - if any(side == -1 for side in shape): - size = math.prod(x.shape) - rsize = math.prod(shape) * -1 - _shape[shape.index(-1)] = size / rsize - _shape = tuple(_shape) - ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape)) + _shape = list(shape) + if any(side == -1 for side in shape): + size = math.prod(x.shape) + rsize = math.prod(shape) * -1 + _shape[shape.index(-1)] = size / rsize + _shape = tuple(_shape) + ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape)) - assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) + assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) + except Exception as exc: + exc.add_note(repro_snippet) + raise def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]: @@ -396,25 +438,30 @@ def test_roll(x, data): kw_strat = hh.kwargs(axis=axis_strat) kw = data.draw(kw_strat, label="kw") - out = xp.roll(x, shift, **kw) + repro_snippet = ph.format_snippet(f"xp.roll({x!r},{shift!r}, **kw) with {kw=}") + try: + out = xp.roll(x, shift, **kw) - kw = {"shift": shift, **kw} # for error messages + kw = {"shift": shift, **kw} # for error messages - ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw) + ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw) - if kw.get("axis", None) is None: - assert isinstance(shift, int) # sanity check - indices = list(sh.ndindex(x.shape)) - shifted_indices = deque(indices) - shifted_indices.rotate(-shift) - assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw) - else: - shifts = (shift,) if isinstance(shift, int) else shift - axes = sh.normalize_axis(kw["axis"], x.ndim) - shifted_indices = roll_ndindex(x.shape, shifts, axes) - assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw) + if kw.get("axis", None) is None: + assert isinstance(shift, int) # sanity check + indices = list(sh.ndindex(x.shape)) + shifted_indices = deque(indices) + shifted_indices.rotate(-shift) + assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw) + else: + shifts = (shift,) if isinstance(shift, int) else shift + axes = sh.normalize_axis(kw["axis"], x.ndim) + shifted_indices = roll_ndindex(x.shape, shifts, axes) + assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -434,34 +481,39 @@ def test_stack(shape, dtypes, kw, data): x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - out = xp.stack(arrays, **kw) + repro_snippet = ph.format_snippet(f"xp.stack({arrays!r}, **kw) with {kw=}") + try: + out = xp.stack(arrays, **kw) - ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype) + ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype) - axis = kw.get("axis", 0) - _axis = axis if axis >= 0 else len(shape) + axis + 1 - _shape = list(shape) - _shape.insert(_axis, len(arrays)) - _shape = tuple(_shape) - ph.assert_result_shape( - "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw - ) + axis = kw.get("axis", 0) + _axis = axis if axis >= 0 else len(shape) + axis + 1 + _shape = list(shape) + _shape.insert(_axis, len(arrays)) + _shape = tuple(_shape) + ph.assert_result_shape( + "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw + ) - out_indices = sh.ndindex(out.shape) - for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): - f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) - for x_num, x in enumerate(arrays, 1): - indexed_x = x[idx] - for x_idx in sh.ndindex(indexed_x.shape): - out_idx = next(out_indices) - ph.assert_0d_equals( - "stack", - x_repr=f"x{x_num}[{f_idx}][{x_idx}]", - x_val=indexed_x[x_idx], - out_repr=f"out[{out_idx}]", - out_val=out[out_idx], - kw=kw, - ) + out_indices = sh.ndindex(out.shape) + for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "stack", + x_repr=f"x{x_num}[{f_idx}][{x_idx}]", + x_val=indexed_x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2023.12") @@ -471,43 +523,53 @@ def test_tile(x, data): st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), label="repetitions" ) - out = xp.tile(x, repetitions) - ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype) - # TODO: values testing - - # shape check; the notation is from the Array API docs - N, M = len(x.shape), len(repetitions) - if N > M: - S = x.shape - R = (1,)*(N - M) + repetitions - else: - S = (1,)*(M - N) + x.shape - R = repetitions - - assert out.shape == tuple(r*s for r, s in zip(R, S)) + repro_snippet = ph.format_snippet(f"xp.tile({x!r}, {repetitions!r})") + try: + out = xp.tile(x, repetitions) + ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: values testing + + # shape check; the notation is from the Array API docs + N, M = len(x.shape), len(repetitions) + if N > M: + S = x.shape + R = (1,)*(N - M) + repetitions + else: + S = (1,)*(M - N) + x.shape + R = repetitions + assert out.shape == tuple(r*s for r, s in zip(R, S)) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data()) def test_unstack(x, data): axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis") kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw") - out = xp.unstack(x, **kw) - - assert isinstance(out, tuple) - assert len(out) == x.shape[axis] - expected_shape = list(x.shape) - expected_shape.pop(axis) - expected_shape = tuple(expected_shape) - for i in range(x.shape[axis]): - arr = out[i] - ph.assert_result_shape("unstack", in_shapes=[x.shape], - out_shape=arr.shape, expected=expected_shape, - kw=kw, repr_name=f"out[{i}].shape") - - ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype, - repr_name=f"out[{i}].dtype") - - idx = [slice(None)] * x.ndim - idx[axis] = i - ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]") + + repro_snippet = ph.format_snippet(f"xp.unstack({x!r}, **kw) with {kw=}") + try: + out = xp.unstack(x, **kw) + + assert isinstance(out, tuple) + assert len(out) == x.shape[axis] + expected_shape = list(x.shape) + expected_shape.pop(axis) + expected_shape = tuple(expected_shape) + for i in range(x.shape[axis]): + arr = out[i] + ph.assert_result_shape("unstack", in_shapes=[x.shape], + out_shape=arr.shape, expected=expected_shape, + kw=kw, repr_name=f"out[{i}].shape") + + ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype, + repr_name=f"out[{i}].dtype") + + idx = [slice(None)] * x.ndim + idx[axis] = i + ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]") + except Exception as exc: + exc.add_note(repro_snippet) + raise