From 5a48ff55b95630f620b14f694d214b2ce158c3af Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Mon, 4 Mar 2024 11:21:21 -0500 Subject: [PATCH 1/3] added random tests --- tests/test_random.py | 119 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 tests/test_random.py diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 0000000..f1c4063 --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,119 @@ +import random + +import pytest + +import arrayfire_wrapper.dtypes as dtypes +import arrayfire_wrapper.lib as wrapper + +invalid_shape = ( + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), +) + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_randu_shape(shape: tuple) -> None: + """Test if randu function creates an array with the correct shape.""" + dtype = dtypes.s16 + + result = wrapper.randu(shape, dtype) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_random_uniform_shape(shape: tuple) -> None: + """Test if rand uniform function creates an array with the correct shape.""" + dtype = dtypes.s16 + engine = wrapper.create_random_engine(100, 10) + + result = wrapper.random_uniform(shape, dtype, engine) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_randn_shape(shape: tuple) -> None: + """Test if randn function creates an array with the correct shape.""" + dtype = dtypes.f32 + + result = wrapper.randn(shape, dtype) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_random_normal_shape(shape: tuple) -> None: + """Test if random normal function creates an array with the correct shape.""" + dtype = dtypes.f32 + engine = wrapper.create_random_engine(100, 10) + + result = wrapper.random_normal(shape, dtype, engine) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "engine_index", + [100, 200, 300], +) +def test_create_random_engine(engine_index: int) -> None: + engine = wrapper.create_random_engine(engine_index, 10) + + engine_type = wrapper.random_engine_get_type(engine) + + assert engine_type == engine_index + + +@pytest.mark.parametrize( + "invalid_index", + [random.randint(301, 600), random.randint(301, 600), random.randint(301, 600)], +) +def test_invalid_random_engine(invalid_index: int) -> None: + "Test if invalid engine types are properly handled" + with pytest.raises(RuntimeError): + + invalid_engine = wrapper.create_random_engine(invalid_index, 10) + + engine_type = wrapper.random_engine_get_type(invalid_engine) + + assert engine_type == invalid_engine From 643583252e5ca91df236e37c0f93d16a4bc2d506 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Mon, 4 Mar 2024 11:26:48 -0500 Subject: [PATCH 2/3] modified random tests --- tests/test_random.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_random.py b/tests/test_random.py index f1c4063..d46cd8e 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -5,14 +5,6 @@ import arrayfire_wrapper.dtypes as dtypes import arrayfire_wrapper.lib as wrapper -invalid_shape = ( - random.randint(1, 10), - random.randint(1, 10), - random.randint(1, 10), - random.randint(1, 10), - random.randint(1, 10), -) - @pytest.mark.parametrize( "shape", From eef89c05401a9beb19188bcadd0658ee7f9d699c Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Tue, 5 Mar 2024 10:16:48 -0500 Subject: [PATCH 3/3] reformatted manage_array file --- arrayfire_wrapper/lib/create_and_modify_array/manage_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py b/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py index 52d4cba..34ad512 100644 --- a/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py +++ b/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py @@ -166,7 +166,7 @@ def get_scalar(arr: AFArray, dtype: Dtype, /) -> int | float | complex | bool | out = dtype.c_type() call_from_clib(get_scalar.__name__, ctypes.pointer(out), arr) if dtype == c32 or dtype == c64: - return complex(out[0], out[1]) # type: ignore + return complex(out[0], out[1]) # type: ignore else: return cast(int | float | complex | bool | None, out.value)