From 25fbfd011e283c07b77e89751a5de2cd155a2999 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Wed, 13 Mar 2024 10:35:41 -0400 Subject: [PATCH 1/4] utility functions for tests --- tests/utility_functions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/utility_functions.py diff --git a/tests/utility_functions.py b/tests/utility_functions.py new file mode 100644 index 0000000..3e45f78 --- /dev/null +++ b/tests/utility_functions.py @@ -0,0 +1,13 @@ +import pytest + +import arrayfire_wrapper.lib as wrapper +from arrayfire_wrapper.dtypes import Dtype, c64, f16, f64 + + +def check_type_supported(dtype: Dtype) -> None: + """Checks to see if the specified type is supported by the current system""" + if dtype in [f64, c64] and not wrapper.get_dbl_support(): + pytest.skip("Device does not support double types") + + if dtype == f16 and not wrapper.get_half_support(): + pytest.skip("Device does not support half types.") From 2fbaace5e3775703e0c97471017cd0bba5df1a77 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Tue, 5 Mar 2024 09:50:46 -0500 Subject: [PATCH 2/4] added unit tests for range function --- tests/test_range.py | 61 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/test_range.py diff --git a/tests/test_range.py b/tests/test_range.py new file mode 100644 index 0000000..1571698 --- /dev/null +++ b/tests/test_range.py @@ -0,0 +1,61 @@ +import random + +import pytest + +import arrayfire_wrapper.dtypes as dtypes +import arrayfire_wrapper.lib as wrapper + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_range_shape(shape: tuple) -> None: + """Test if the range function output an AFArray with the correct shape""" + dim = 2 + dtype = dtypes.s16 + + result = wrapper.range(shape, dim, dtype) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +def test_range_invalid_shape() -> None: + """Test if range function correctly handles an invalid shape""" + with pytest.raises(TypeError): + shape = ( + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + ) + dim = 2 + dtype = dtypes.s16 + + wrapper.range(shape, dim, dtype) + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_range_invalid_dim(shape: tuple) -> None: + """Test if the range function can properly handle and invalid dimension given""" + with pytest.raises(RuntimeError): + dim = random.randint(4, 10) + dtype = dtypes.s16 + + wrapper.range(shape, dim, dtype) From 1bc0f3a84cf3733040a0a0f6a2bc6ce790dc933e Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Tue, 12 Mar 2024 10:48:35 -0400 Subject: [PATCH 3/4] Readability changes to cosntants tests --- tests/test_constants.py | 69 +++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/tests/test_constants.py b/tests/test_constants.py index 855d94a..929c8ed 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -2,8 +2,23 @@ import pytest -import arrayfire_wrapper.dtypes as dtypes import arrayfire_wrapper.lib as wrapper +from arrayfire_wrapper.dtypes import ( + Dtype, + c32, + c64, + c_api_value_to_dtype, + f16, + f32, + f64, + s16, + s32, + s64, + u8, + u16, + u32, + u64, +) invalid_shape = ( random.randint(1, 10), @@ -14,6 +29,9 @@ ) +types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] + + @pytest.mark.parametrize( "shape", [ @@ -27,7 +45,7 @@ def test_constant_shape(shape: tuple) -> None: """Test if constant creates an array with the correct shape.""" number = 5.0 - dtype = dtypes.s16 + dtype = s16 result = wrapper.constant(number, shape, dtype) @@ -46,9 +64,9 @@ def test_constant_shape(shape: tuple) -> None: ) def test_constant_complex_shape(shape: tuple) -> None: """Test if constant_complex creates an array with the correct shape.""" - dtype = dtypes.c32 + dtype = c32 - dtype = dtypes.c32 + dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -71,7 +89,7 @@ def test_constant_complex_shape(shape: tuple) -> None: ) def test_constant_long_shape(shape: tuple) -> None: """Test if constant_long creates an array with the correct shape.""" - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -93,7 +111,7 @@ def test_constant_long_shape(shape: tuple) -> None: ) def test_constant_ulong_shape(shape: tuple) -> None: """Test if constant_ulong creates an array with the correct shape.""" - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -109,7 +127,7 @@ def test_constant_shape_invalid() -> None: """Test if constant handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): number = 5.0 - dtype = dtypes.s16 + dtype = s16 wrapper.constant(number, invalid_shape, dtype) @@ -117,7 +135,7 @@ def test_constant_shape_invalid() -> None: def test_constant_complex_shape_invalid() -> None: """Test if constant_complex handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.c32 + dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -128,7 +146,7 @@ def test_constant_complex_shape_invalid() -> None: def test_constant_long_shape_invalid() -> None: """Test if constant_long handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -139,7 +157,7 @@ def test_constant_long_shape_invalid() -> None: def test_constant_ulong_shape_invalid() -> None: """Test if constant_ulong handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -148,50 +166,47 @@ def test_constant_ulong_shape_invalid() -> None: @pytest.mark.parametrize( - "dtype_index", - [i for i in range(13)], + "dtype", + types, ) -def test_constant_dtype(dtype_index: int) -> None: +def test_constant_dtype(dtype: Dtype) -> None: """Test if constant creates an array with the correct dtype.""" - if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()): + if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()): pytest.skip() - dtype = dtypes.c_api_value_to_dtype(dtype_index) - rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) shape = (2, 2) if isinstance(value, (int, float)): result = wrapper.constant(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() @pytest.mark.parametrize( - "dtype_index", - [i for i in range(13)], + "dtype", + types, ) -def test_constant_complex_dtype(dtype_index: int) -> None: +def test_constant_complex_dtype(dtype: Dtype) -> None: """Test if constant_complex creates an array with the correct dtype.""" - if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()): + if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()): pytest.skip() - dtype = dtypes.c_api_value_to_dtype(dtype_index) rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) shape = (2, 2) if isinstance(value, (int, float, complex)): result = wrapper.constant_complex(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() def test_constant_long_dtype() -> None: """Test if constant_long creates an array with the correct dtype.""" - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) @@ -200,14 +215,14 @@ def test_constant_long_dtype() -> None: if isinstance(value, (int, float)): result = wrapper.constant_long(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() def test_constant_ulong_dtype() -> None: """Test if constant_ulong creates an array with the correct dtype.""" - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) @@ -216,6 +231,6 @@ def test_constant_ulong_dtype() -> None: if isinstance(value, (int, float)): result = wrapper.constant_ulong(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() From bc96cbc734e589a127d3e64428d810bde3628241 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Tue, 12 Mar 2024 11:42:09 -0400 Subject: [PATCH 4/4] readability changes pt.2 --- tests/test_constants.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/test_constants.py b/tests/test_constants.py index 929c8ed..e513a0b 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -29,7 +29,7 @@ ) -types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] +all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] @pytest.mark.parametrize( @@ -66,7 +66,6 @@ def test_constant_complex_shape(shape: tuple) -> None: """Test if constant_complex creates an array with the correct shape.""" dtype = c32 - dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -167,11 +166,11 @@ def test_constant_ulong_shape_invalid() -> None: @pytest.mark.parametrize( "dtype", - types, + all_types, ) def test_constant_dtype(dtype: Dtype) -> None: """Test if constant creates an array with the correct dtype.""" - if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()): + if is_cmplx_type(dtype) or not is_system_supported(dtype): pytest.skip() rand_array = wrapper.randu((1, 1), dtype) @@ -186,11 +185,11 @@ def test_constant_dtype(dtype: Dtype) -> None: @pytest.mark.parametrize( "dtype", - types, + all_types, ) def test_constant_complex_dtype(dtype: Dtype) -> None: """Test if constant_complex creates an array with the correct dtype.""" - if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()): + if not is_cmplx_type(dtype) or not is_system_supported(dtype): pytest.skip() rand_array = wrapper.randu((1, 1), dtype) @@ -234,3 +233,14 @@ def test_constant_ulong_dtype() -> None: assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() + + +def is_cmplx_type(dtype: Dtype) -> bool: + return dtype == c32 or dtype == c64 + + +def is_system_supported(dtype: Dtype) -> bool: + if dtype in [f64, c64] and not wrapper.get_dbl_support(): + return False + + return True