Skip to content

Commit 886f411

Browse files
Chaluvadiroaffix
Chaluvadi
authored andcommitted
added unit tests for the upper function
1 parent fab31a7 commit 886f411

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

Diff for: tests/test_upper.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
3+
import arrayfire_wrapper.dtypes as dtypes
4+
from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant
5+
from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract
6+
from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper
7+
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar
8+
9+
10+
@pytest.mark.parametrize(
11+
"shape",
12+
[
13+
(3, 3),
14+
(3, 3, 3),
15+
(3, 3, 3, 3),
16+
],
17+
)
18+
def test_diag_is_unit(shape: tuple) -> None:
19+
"""Test if when is_unit_diag in lower returns an array with a unit diagonal"""
20+
dtype = dtypes.s64
21+
constant_array = constant(3, shape, dtype)
22+
23+
lower_array = upper(constant_array, True)
24+
diagonal = diag_extract(lower_array, 0)
25+
diagonal_value = get_scalar(diagonal, dtype)
26+
27+
assert diagonal_value == 1
28+
29+
30+
@pytest.mark.parametrize(
31+
"shape",
32+
[
33+
(3, 3),
34+
(3, 3, 3),
35+
(3, 3, 3, 3),
36+
],
37+
)
38+
def test_is_original(shape: tuple) -> None:
39+
"""Test if is_original keeps the diagonal the same as the original array"""
40+
dtype = dtypes.s64
41+
constant_array = constant(3, shape, dtype)
42+
original_value = get_scalar(constant_array, dtype)
43+
44+
lower_array = upper(constant_array, False)
45+
diagonal = diag_extract(lower_array, 0)
46+
diagonal_value = get_scalar(diagonal, dtype)
47+
48+
assert original_value == diagonal_value

0 commit comments

Comments
 (0)