Skip to content

Commit b1cec12

Browse files
authored
Tighten tests on iinfo/finfo (#358)
Reviewed at #358
1 parent c48410f commit b1cec12

File tree

2 files changed

+55
-26
lines changed

2 files changed

+55
-26
lines changed

Diff for: array-api-strict-skips.txt

+5
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,8 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
2727
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
2828
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
2929

30+
# FIXME needs array-api-strict >=2.3.2
31+
array_api_tests/test_data_type_functions.py::test_finfo
32+
array_api_tests/test_data_type_functions.py::test_finfo_dtype
33+
array_api_tests/test_data_type_functions.py::test_iinfo
34+
array_api_tests/test_data_type_functions.py::test_iinfo_dtype

Diff for: array_api_tests/test_data_type_functions.py

+50-26
Original file line numberDiff line numberDiff line change
@@ -147,38 +147,62 @@ def test_can_cast(_from, to):
147147
assert out == expected, f"{out=}, but should be {expected} {f_func}"
148148

149149

150-
@pytest.mark.parametrize("dtype", dh.real_float_dtypes)
150+
@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
151151
def test_finfo(dtype):
152+
for arg in (
153+
dtype,
154+
xp.asarray(1, dtype=dtype),
155+
# np.float64 and np.asarray(1, dtype=np.float64).dtype are different
156+
xp.asarray(1, dtype=dtype).dtype,
157+
):
158+
out = xp.finfo(arg)
159+
assert isinstance(out.bits, int)
160+
assert isinstance(out.eps, float)
161+
assert isinstance(out.max, float)
162+
assert isinstance(out.min, float)
163+
assert isinstance(out.smallest_normal, float)
164+
165+
166+
@pytest.mark.min_version("2022.12")
167+
@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
168+
def test_finfo_dtype(dtype):
152169
out = xp.finfo(dtype)
153-
f_func = f"[finfo({dh.dtype_to_name[dtype]})]"
154-
for attr, stype in [
155-
("bits", int),
156-
("eps", float),
157-
("max", float),
158-
("min", float),
159-
("smallest_normal", float),
160-
]:
161-
assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}"
162-
value = getattr(out, attr)
163-
assert isinstance(
164-
value, stype
165-
), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}"
166-
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
167-
# TODO: test values
168170

171+
if dtype == xp.complex64:
172+
assert out.dtype == xp.float32
173+
elif dtype == xp.complex128:
174+
assert out.dtype == xp.float64
175+
else:
176+
assert out.dtype == dtype
177+
178+
# Guard vs. numpy.dtype.__eq__ lax comparison
179+
assert not isinstance(out.dtype, str)
180+
assert out.dtype is not float
181+
assert out.dtype is not complex
169182

170-
@pytest.mark.parametrize("dtype", dh.int_dtypes)
183+
184+
@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
171185
def test_iinfo(dtype):
186+
for arg in (
187+
dtype,
188+
xp.asarray(1, dtype=dtype),
189+
# np.int64 and np.asarray(1, dtype=np.int64).dtype are different
190+
xp.asarray(1, dtype=dtype).dtype,
191+
):
192+
out = xp.iinfo(arg)
193+
assert isinstance(out.bits, int)
194+
assert isinstance(out.max, int)
195+
assert isinstance(out.min, int)
196+
197+
198+
@pytest.mark.min_version("2022.12")
199+
@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
200+
def test_iinfo_dtype(dtype):
172201
out = xp.iinfo(dtype)
173-
f_func = f"[iinfo({dh.dtype_to_name[dtype]})]"
174-
for attr in ["bits", "max", "min"]:
175-
assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}"
176-
value = getattr(out, attr)
177-
assert isinstance(
178-
value, int
179-
), f"type(out.{attr})={type(value)!r}, but should be int {f_func}"
180-
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
181-
# TODO: test values
202+
assert out.dtype == dtype
203+
# Guard vs. numpy.dtype.__eq__ lax comparison
204+
assert not isinstance(out.dtype, str)
205+
assert out.dtype is not int
182206

183207

184208
def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:

0 commit comments

Comments
 (0)