diff --git a/tests/test_constants.py b/tests/test_constants.py index 855d94a..1835a9c 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -174,7 +174,7 @@ def test_constant_dtype(dtype_index: int) -> None: ) def test_constant_complex_dtype(dtype_index: int) -> None: """Test if constant_complex creates an array with the correct dtype.""" - if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()): + if not pytest.skip() dtype = dtypes.c_api_value_to_dtype(dtype_index) @@ -219,3 +219,4 @@ def test_constant_ulong_dtype() -> None: assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() + \ No newline at end of file diff --git a/tests/test_utilities.py b/tests/test_utilities.py new file mode 100644 index 0000000..3e45f78 --- /dev/null +++ b/tests/test_utilities.py @@ -0,0 +1,13 @@ +import pytest + +import arrayfire_wrapper.lib as wrapper +from arrayfire_wrapper.dtypes import Dtype, c64, f16, f64 + + +def check_type_supported(dtype: Dtype) -> None: + """Checks to see if the specified type is supported by the current system""" + if dtype in [f64, c64] and not wrapper.get_dbl_support(): + pytest.skip("Device does not support double types") + + if dtype == f16 and not wrapper.get_half_support(): + pytest.skip("Device does not support half types.")