Skip to content

Commit 879a629

Browse files
Optimize _batched_lu_factor for dpnp.linalg.det/slogdet (#2572)
This PR optimizes and simplifies the `_batched_lu_factor` logic by replacing per-iteration allocations with a single preallocated buffer and batch-axis views, which improves the performance of `dpnp.linalg.det` and `dpnp.linalg.slogdet`.
1 parent d289709 commit 879a629

File tree

2 files changed

+34
-70
lines changed

2 files changed

+34
-70
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4242
* Improved performance of `dpnp.isclose` function by implementing a dedicated kernel for scalar `rtol` and `atol` arguments [#2540](https://github.com/IntelPython/dpnp/pull/2540)
4343
* Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535)
4444
* Redesigned `dpnp.erf` function through pybind11 extension of OneMKL call or dedicated kernel in `ufunc` namespace [#2551](https://github.com/IntelPython/dpnp/pull/2551)
45+
* Improved performance of batched implementation of `dpnp.linalg.det` and `dpnp.linalg.slogdet` [#2572](https://github.com/IntelPython/dpnp/pull/2572)
4546

4647
### Deprecated
4748

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 33 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -297,26 +297,27 @@ def _batched_lu_factor(a, res_type):
297297
batch_size = a.shape[0]
298298
a_usm_arr = dpnp.get_usm_ndarray(a)
299299

300+
# `a` must be copied because getrf/getrf_batch destroys the input matrix
301+
a_h = dpnp.empty_like(a, order="C", dtype=res_type)
302+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
303+
src=a_usm_arr,
304+
dst=a_h.get_array(),
305+
sycl_queue=a_sycl_queue,
306+
depends=_manager.submitted_events,
307+
)
308+
_manager.add_event_pair(ht_ev, copy_ev)
309+
310+
ipiv_h = dpnp.empty(
311+
(batch_size, n),
312+
dtype=dpnp.int64,
313+
order="C",
314+
usm_type=a_usm_type,
315+
sycl_queue=a_sycl_queue,
316+
)
317+
300318
if use_batch:
301-
# `a` must be copied because getrf_batch destroys the input matrix
302-
a_h = dpnp.empty_like(a, order="C", dtype=res_type)
303-
ipiv_h = dpnp.empty(
304-
(batch_size, n),
305-
dtype=dpnp.int64,
306-
order="C",
307-
usm_type=a_usm_type,
308-
sycl_queue=a_sycl_queue,
309-
)
310319
dev_info_h = [0] * batch_size
311320

312-
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
313-
src=a_usm_arr,
314-
dst=a_h.get_array(),
315-
sycl_queue=a_sycl_queue,
316-
depends=_manager.submitted_events,
317-
)
318-
_manager.add_event_pair(ht_ev, copy_ev)
319-
320321
ipiv_stride = n
321322
a_stride = a_h.strides[0]
322323

@@ -336,63 +337,25 @@ def _batched_lu_factor(a, res_type):
336337
)
337338
_manager.add_event_pair(ht_ev, getrf_ev)
338339

339-
dev_info_array = dpnp.array(
340-
dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue
341-
)
342-
343-
# Reshape the results back to their original shape
344-
a_h = a_h.reshape(orig_shape)
345-
ipiv_h = ipiv_h.reshape(orig_shape[:-1])
346-
dev_info_array = dev_info_array.reshape(orig_shape[:-2])
347-
348-
return (a_h, ipiv_h, dev_info_array)
349-
350-
# Initialize lists for storing arrays and events for each batch
351-
a_vecs = [None] * batch_size
352-
ipiv_vecs = [None] * batch_size
353-
dev_info_vecs = [None] * batch_size
354-
355-
dep_evs = _manager.submitted_events
356-
357-
# Process each batch
358-
for i in range(batch_size):
359-
# Copy each 2D slice to a new array because getrf will destroy
360-
# the input matrix
361-
a_vecs[i] = dpnp.empty_like(a[i], order="C", dtype=res_type)
362-
363-
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
364-
src=a_usm_arr[i],
365-
dst=a_vecs[i].get_array(),
366-
sycl_queue=a_sycl_queue,
367-
depends=dep_evs,
368-
)
369-
_manager.add_event_pair(ht_ev, copy_ev)
370-
371-
ipiv_vecs[i] = dpnp.empty(
372-
(n,),
373-
dtype=dpnp.int64,
374-
order="C",
375-
usm_type=a_usm_type,
376-
sycl_queue=a_sycl_queue,
377-
)
378-
dev_info_vecs[i] = [0]
340+
else:
341+
dev_info_h = [[0] for _ in range(batch_size)]
379342

380-
# Call the LAPACK extension function _getrf
381-
# to perform LU decomposition on each batch in 'a_vecs[i]'
382-
ht_ev, getrf_ev = li._getrf(
383-
a_sycl_queue,
384-
a_vecs[i].get_array(),
385-
ipiv_vecs[i].get_array(),
386-
dev_info_vecs[i],
387-
depends=[copy_ev],
388-
)
389-
_manager.add_event_pair(ht_ev, getrf_ev)
343+
# Sequential LU factorization using getrf per slice
344+
for i in range(batch_size):
345+
ht_ev, getrf_ev = li._getrf(
346+
a_sycl_queue,
347+
a_h[i].get_array(),
348+
ipiv_h[i].get_array(),
349+
dev_info_h[i],
350+
depends=[copy_ev],
351+
)
352+
_manager.add_event_pair(ht_ev, getrf_ev)
390353

391354
# Reshape the results back to their original shape
392-
out_a = dpnp.array(a_vecs, order="C").reshape(orig_shape)
393-
out_ipiv = dpnp.array(ipiv_vecs).reshape(orig_shape[:-1])
355+
out_a = a_h.reshape(orig_shape)
356+
out_ipiv = ipiv_h.reshape(orig_shape[:-1])
394357
out_dev_info = dpnp.array(
395-
dev_info_vecs, usm_type=a_usm_type, sycl_queue=a_sycl_queue
358+
dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue
396359
).reshape(orig_shape[:-2])
397360

398361
return (out_a, out_ipiv, out_dev_info)

0 commit comments

Comments
 (0)