Skip to content

Commit 6a3bae1

Browse files
Add sycl_queue and usm_type tests
1 parent 6541bf0 commit 6a3bae1

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from numpy.testing import assert_array_equal, assert_raises
1010

1111
import dpnp
12+
import dpnp.linalg
1213
from dpnp.dpnp_array import dpnp_array
1314
from dpnp.dpnp_utils import get_usm_allocations
1415

@@ -1582,6 +1583,20 @@ def test_lu_factor(self, data, device):
15821583
param_queue = param.sycl_queue
15831584
assert_sycl_queue_equal(param_queue, a.sycl_queue)
15841585

1586+
@pytest.mark.parametrize(
1587+
"data",
1588+
[[1.0, 2.0], numpy.empty((2, 0))],
1589+
)
1590+
def test_lu_solve(self, data, device):
1591+
a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
1592+
lu, piv = dpnp.linalg.lu_factor(a)
1593+
b = dpnp.array(data, device=device)
1594+
1595+
result = dpnp.linalg.lu_solve((lu, piv), b)
1596+
1597+
assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
1598+
assert_sycl_queue_equal(result.sycl_queue, b.sycl_queue)
1599+
15851600
@pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
15861601
def test_matrix_power(self, n, device):
15871602
x = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)

dpnp/tests/test_usm_type.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,24 @@ def test_lu_factor(self, data, usm_type):
14611461
for param in result:
14621462
assert param.usm_type == a.usm_type
14631463

1464+
@pytest.mark.parametrize("usm_type_rhs", list_of_usm_types)
1465+
@pytest.mark.parametrize(
1466+
"data",
1467+
[[1.0, 2.0], numpy.empty((2, 0))],
1468+
)
1469+
def test_lu_solve(self, data, usm_type, usm_type_rhs):
1470+
a = dpnp.array(data, usm_type=usm_type)
1471+
lu, piv = dpnp.linalg.lu_factor(a)
1472+
b = dpnp.array(data, usm_type=usm_type_rhs)
1473+
1474+
result = dpnp.linalg.lu_solve((lu, piv), b)
1475+
1476+
assert lu.usm_type == usm_type
1477+
assert b.usm_type == usm_type_rhs
1478+
assert result.usm_type == du.get_coerced_usm_type(
1479+
[usm_type, usm_type_rhs]
1480+
)
1481+
14641482
@pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
14651483
def test_matrix_power(self, n, usm_type):
14661484
a = dpnp.array([[1, 2], [3, 5]], usm_type=usm_type)

0 commit comments

Comments
 (0)