11import pytest
22
33import arrayfire_wrapper .dtypes as dtypes
4- from arrayfire_wrapper .lib .create_and_modify_array .create_array .constant import constant
5- from arrayfire_wrapper .lib .create_and_modify_array .create_array .diag import diag_extract
6- from arrayfire_wrapper .lib .create_and_modify_array .create_array .upper import upper
7- from arrayfire_wrapper .lib .create_and_modify_array .manage_array import get_scalar
4+ import arrayfire_wrapper .lib as wrapper
85
96
107@pytest .mark .parametrize (
1815def test_diag_is_unit (shape : tuple ) -> None :
1916 """Test if when is_unit_diag in lower returns an array with a unit diagonal"""
2017 dtype = dtypes .s64
21- constant_array = constant (3 , shape , dtype )
18+ constant_array = wrapper . constant (3 , shape , dtype )
2219
23- lower_array = upper (constant_array , True )
24- diagonal = diag_extract (lower_array , 0 )
25- diagonal_value = get_scalar (diagonal , dtype )
20+ lower_array = wrapper . upper (constant_array , True )
21+ diagonal = wrapper . diag_extract (lower_array , 0 )
22+ diagonal_value = wrapper . get_scalar (diagonal , dtype )
2623
2724 assert diagonal_value == 1
2825
@@ -38,11 +35,11 @@ def test_diag_is_unit(shape: tuple) -> None:
3835def test_is_original (shape : tuple ) -> None :
3936 """Test if is_original keeps the diagonal the same as the original array"""
4037 dtype = dtypes .s64
41- constant_array = constant (3 , shape , dtype )
42- original_value = get_scalar (constant_array , dtype )
38+ constant_array = wrapper . constant (3 , shape , dtype )
39+ original_value = wrapper . get_scalar (constant_array , dtype )
4340
44- lower_array = upper (constant_array , False )
45- diagonal = diag_extract (lower_array , 0 )
46- diagonal_value = get_scalar (diagonal , dtype )
41+ lower_array = wrapper . upper (constant_array , False )
42+ diagonal = wrapper . diag_extract (lower_array , 0 )
43+ diagonal_value = wrapper . get_scalar (diagonal , dtype )
4744
4845 assert original_value == diagonal_value
0 commit comments