forked from arrayfire/arrayfire-binary-python-wrapper
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_complex.py
135 lines (119 loc) · 4.6 KB
/
test_complex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import random
import pytest
import arrayfire_wrapper.dtypes as dtype
import arrayfire_wrapper.lib as wrapper
from tests.utility_functions import check_type_supported, get_all_types, get_float_types, get_real_types
@pytest.mark.parametrize(
    "shape",
    # One shape per rank 0..4; every extent is drawn from [1, 10] at collection time.
    [tuple(random.randint(1, 10) for _ in range(rank)) for rank in range(5)],
)
@pytest.mark.parametrize("dtype_name", get_float_types())
def test_complex_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
    """Check that cplx() turns a uniformly random real float array into a complex array.

    check_type_supported runs first (presumably skipping dtypes the active
    backend lacks — confirm in tests.utility_functions); f16 is always skipped.
    """
    check_type_supported(dtype_name)
    if dtype_name != dtype.f16:
        real_input = wrapper.randu(shape, dtype_name)
        complex_out = wrapper.cplx(real_input)
        assert wrapper.is_complex(complex_out), f"Failed for dtype: {dtype_name}"
    else:
        pytest.skip()
@pytest.mark.parametrize(
    "invdtypes",
    [
        dtype.int32,
        dtype.complex32,
    ],
)
def test_complex_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
    """Test that cplx() raises RuntimeError for int32 and complex32 inputs.

    The input array is built outside ``pytest.raises`` so that an error while
    constructing it is reported as a test error instead of silently satisfying
    the expected-exception check; only the cplx() call is expected to fail.
    """
    shape = (5, 5)
    out = wrapper.randu(shape, invdtypes)
    with pytest.raises(RuntimeError):
        wrapper.cplx(out)
@pytest.mark.parametrize(
    "shape",
    # One shape per rank 0..4; every extent is drawn from [1, 10] at collection time.
    [tuple(random.randint(1, 10) for _ in range(rank)) for rank in range(5)],
)
@pytest.mark.parametrize("dtype_name", get_real_types())
def test_complex2_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
    """Check that cplx2() combines two same-shaped real arrays into a complex array."""
    check_type_supported(dtype_name)
    first, second = (wrapper.randu(shape, dtype_name) for _ in range(2))
    combined = wrapper.cplx2(first, second)
    assert wrapper.is_complex(combined), f"Failed for dtype: {dtype_name}"
@pytest.mark.parametrize(
    "invdtypes",
    [
        dtype.c32,
    ],
)
def test_complex2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
    """Test that cplx2() raises RuntimeError when given complex (c32) operands.

    Both operands are built outside ``pytest.raises`` so that a failure while
    constructing them is reported as a test error instead of silently
    satisfying the expected-exception check; only cplx2() is expected to fail.
    """
    shape = (5, 5)
    lhs = wrapper.randu(shape, invdtypes)
    rhs = wrapper.randu(shape, invdtypes)
    with pytest.raises(RuntimeError):
        wrapper.cplx2(lhs, rhs)
@pytest.mark.parametrize(
    "shape",
    # One shape per rank 0..4; every extent is drawn from [1, 10] at collection time.
    [tuple(random.randint(1, 10) for _ in range(rank)) for rank in range(5)],
)
@pytest.mark.parametrize("dtypes", get_all_types())
def test_conj_supported_dtypes(shape: tuple, dtypes: dtype.Dtype) -> None:
    """Check that conjg() on an array filled with 7 keeps the input's leading dimensions."""
    check_type_supported(dtypes)
    conjugated = wrapper.conjg(wrapper.constant(7, shape, dtypes))
    # Compare only as many dims as the requested shape has.
    leading_dims = wrapper.get_dims(conjugated)[: len(shape)]
    assert leading_dims == shape, f"Failed for shape: {shape}, and dtype: {dtypes}"  # noqa
@pytest.mark.parametrize(
    "shape",
    [
        (),
        (random.randint(1, 10),),
        (random.randint(1, 10), random.randint(1, 10)),
        (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
        (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
    ],
)
@pytest.mark.parametrize("dtypes", get_all_types())
def test_imag_supported_dtypes(shape: tuple, dtypes: dtype.Dtype) -> None:
    """Test that imag() yields a real-typed array for every supported data type.

    Previously this test called wrapper.real(), duplicating
    test_real_supported_dtypes and never exercising wrapper.imag().
    """
    check_type_supported(dtypes)
    arr = wrapper.randu(shape, dtypes)
    imag = wrapper.imag(arr)
    assert wrapper.is_real(imag), f"Failed for shape: {shape}"
@pytest.mark.parametrize(
    "shape",
    # One shape per rank 0..4; every extent is drawn from [1, 10] at collection time.
    [tuple(random.randint(1, 10) for _ in range(rank)) for rank in range(5)],
)
@pytest.mark.parametrize("dtypes", get_all_types())
def test_real_supported_dtypes(shape: tuple, dtypes: dtype.Dtype) -> None:
    """Check that real() yields an array that is_real() accepts, for every supported dtype."""
    check_type_supported(dtypes)
    source = wrapper.randu(shape, dtypes)
    assert wrapper.is_real(wrapper.real(source)), f"Failed for shape: {shape}"