Skip to content

Commit 8711de9

Browse files
sakchalroaffix
authored andcommitted
unit tests for iotafunction
1 parent d8caaeb commit 8711de9

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

Diff for: tests/test_iota.py

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import random
2+
3+
import numpy as np
4+
import pytest
5+
6+
import arrayfire_wrapper.dtypes as dtypes
7+
from arrayfire_wrapper.lib.create_and_modify_array.create_array.iota import iota
8+
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_dims, get_type
9+
from arrayfire_wrapper.lib.create_and_modify_array.manage_device import get_dbl_support
10+
11+
12+
@pytest.mark.parametrize(
13+
"shape",
14+
[
15+
(),
16+
(random.randint(1, 10), 1),
17+
(random.randint(1, 10), random.randint(1, 10)),
18+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
19+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
20+
],
21+
)
22+
def test_iota_shape(shape: tuple) -> None:
23+
"""Test if identity creates an array with the correct shape"""
24+
dtype = dtypes.s16
25+
t_shape = (1, 1)
26+
27+
result = iota(shape, t_shape, dtype)
28+
29+
assert get_dims(result)[0:len(shape)] == shape
30+
31+
32+
def test_iota_invalid_shape() -> None:
33+
"""Test if iota handles a shape with greater than 4 dimensions"""
34+
with pytest.raises(TypeError) as excinfo:
35+
invalid_shape = (
36+
random.randint(1, 10),
37+
random.randint(1, 10),
38+
random.randint(1, 10),
39+
random.randint(1, 10),
40+
random.randint(1, 10),
41+
)
42+
dtype = dtypes.s16
43+
t_shape = ()
44+
45+
iota(invalid_shape, t_shape, dtype)
46+
47+
assert f"CShape.__init__() takes from 1 to 5 positional arguments but {len(invalid_shape) + 1} were given" in str(
48+
excinfo.value
49+
)
50+
51+
52+
@pytest.mark.parametrize(
53+
"t_shape",
54+
[
55+
(1,),
56+
(random.randint(1, 10), 1),
57+
(random.randint(1, 10), random.randint(1, 10)),
58+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
59+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
60+
],
61+
)
62+
def test_iota_tshape(t_shape: tuple) -> None:
63+
"""Test if iota properly uses t_shape to change the size of the array and result in correct dimensions"""
64+
shape = np.array([2, 2])
65+
dtype = dtypes.s64
66+
67+
if len(shape.shape) < len(t_shape):
68+
shape = np.append(shape, np.ones(len(t_shape) - len(shape), dtype=int))
69+
70+
result_shape = shape * t_shape
71+
72+
result = iota(tuple(shape), t_shape, dtype)
73+
74+
result_dims = tuple(int(value) for value in get_dims(result))
75+
76+
assert (result_dims[0:len(result_shape)] == result_shape).all()
77+
78+
79+
@pytest.mark.parametrize(
80+
"t_shape",
81+
[
82+
(0,),
83+
(-1, -1),
84+
],
85+
)
86+
def test_iota_tshape_zero(t_shape: tuple) -> None:
87+
"""Test it iota properly handles negative or zero t_shapes"""
88+
with pytest.raises(RuntimeError):
89+
shape = (2, 2)
90+
91+
dtype = dtypes.s16
92+
93+
iota(shape, t_shape, dtype)
94+
95+
96+
def test_iota_tshape_float() -> None:
97+
"""Test it iota properly handles float t_shapes"""
98+
with pytest.raises(TypeError):
99+
shape = (2, 2)
100+
t_shape = (1.5, 1.5)
101+
102+
dtype = dtypes.s16
103+
104+
iota(shape, t_shape, dtype)
105+
106+
107+
def test_iota_tshape_invalid() -> None:
108+
"""Test it iota properly handles a tshape with greater than 4 dimensions"""
109+
with pytest.raises(TypeError):
110+
shape = (2, 2)
111+
invalid_tshape = (
112+
random.randint(1, 10),
113+
random.randint(1, 10),
114+
random.randint(1, 10),
115+
random.randint(1, 10),
116+
random.randint(1, 10),
117+
)
118+
dtype = dtypes.s16
119+
120+
iota(shape, invalid_tshape, dtype)
121+
122+
123+
@pytest.mark.parametrize(
124+
"dtype_index",
125+
[i for i in range(13)],
126+
)
127+
def test_iota_dtype(dtype_index: int) -> None:
128+
"""Test if iota creates an array with the correct dtype"""
129+
if (dtype_index in [1, 4]) or (dtype_index in [2, 3] and not get_dbl_support()):
130+
pytest.skip()
131+
132+
shape = (5, 5)
133+
t_shape = (2, 2)
134+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
135+
136+
result = iota(shape, t_shape, dtype)
137+
138+
assert dtypes.c_api_value_to_dtype(get_type(result)) == dtype

0 commit comments

Comments
 (0)