|
| 1 | +import pytest |
| 2 | + |
| 3 | +import arrayfire_wrapper.dtypes as dtypes |
| 4 | +from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant |
| 5 | +from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract |
| 6 | +from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper |
| 7 | +from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar |
| 8 | + |
| 9 | + |
| 10 | +@pytest.mark.parametrize( |
| 11 | + "shape", |
| 12 | + [ |
| 13 | + (3, 3), |
| 14 | + (3, 3, 3), |
| 15 | + (3, 3, 3, 3), |
| 16 | + ], |
| 17 | +) |
| 18 | +def test_diag_is_unit(shape: tuple) -> None: |
| 19 | + """Test if when is_unit_diag in lower returns an array with a unit diagonal""" |
| 20 | + dtype = dtypes.s64 |
| 21 | + constant_array = constant(3, shape, dtype) |
| 22 | + |
| 23 | + lower_array = upper(constant_array, True) |
| 24 | + diagonal = diag_extract(lower_array, 0) |
| 25 | + diagonal_value = get_scalar(diagonal, dtype) |
| 26 | + |
| 27 | + assert diagonal_value == 1 |
| 28 | + |
| 29 | + |
| 30 | +@pytest.mark.parametrize( |
| 31 | + "shape", |
| 32 | + [ |
| 33 | + (3, 3), |
| 34 | + (3, 3, 3), |
| 35 | + (3, 3, 3, 3), |
| 36 | + ], |
| 37 | +) |
| 38 | +def test_is_original(shape: tuple) -> None: |
| 39 | + """Test if is_original keeps the diagonal the same as the original array""" |
| 40 | + dtype = dtypes.s64 |
| 41 | + constant_array = constant(3, shape, dtype) |
| 42 | + original_value = get_scalar(constant_array, dtype) |
| 43 | + |
| 44 | + lower_array = upper(constant_array, False) |
| 45 | + diagonal = diag_extract(lower_array, 0) |
| 46 | + diagonal_value = get_scalar(diagonal, dtype) |
| 47 | + |
| 48 | + assert original_value == diagonal_value |
0 commit comments