Skip to content

Commit ea2c627

Browse files
committed
Merge branch 'add-deth'
2 parents 2702823 + ac8c116 commit ea2c627

File tree

5 files changed

+254
-23
lines changed

5 files changed

+254
-23
lines changed

src/cholesky.rs

+19-19
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ where
9494
}
9595
}
9696

97-
impl<A, S> CholeskyDeterminant for CholeskyFactorized<S>
97+
impl<A, S> DeterminantC for CholeskyFactorized<S>
9898
where
9999
A: Absolute,
100100
S: Data<Elem = A>,
@@ -111,7 +111,7 @@ where
111111
}
112112
}
113113

114-
impl<A, S> CholeskyDeterminantInto for CholeskyFactorized<S>
114+
impl<A, S> DeterminantCInto for CholeskyFactorized<S>
115115
where
116116
A: Absolute,
117117
S: Data<Elem = A>,
@@ -123,7 +123,7 @@ where
123123
}
124124
}
125125

126-
impl<A, S> CholeskyInverse for CholeskyFactorized<S>
126+
impl<A, S> InverseC for CholeskyFactorized<S>
127127
where
128128
A: Scalar,
129129
S: Data<Elem = A>,
@@ -139,7 +139,7 @@ where
139139
}
140140
}
141141

142-
impl<A, S> CholeskyInverseInto for CholeskyFactorized<S>
142+
impl<A, S> InverseCInto for CholeskyFactorized<S>
143143
where
144144
A: Scalar,
145145
S: DataMut<Elem = A>,
@@ -154,7 +154,7 @@ where
154154
}
155155
}
156156

157-
impl<A, S> CholeskySolve<A> for CholeskyFactorized<S>
157+
impl<A, S> SolveC<A> for CholeskyFactorized<S>
158158
where
159159
A: Scalar,
160160
S: Data<Elem = A>,
@@ -255,7 +255,7 @@ where
255255
}
256256

257257
/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix reference
258-
pub trait CholeskyFactorize<S: Data> {
258+
pub trait FactorizeC<S: Data> {
259259
/// Computes the Cholesky decomposition of the Hermitian (or real
260260
/// symmetric) positive definite matrix.
261261
///
@@ -268,7 +268,7 @@ pub trait CholeskyFactorize<S: Data> {
268268
}
269269

270270
/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix
271-
pub trait CholeskyFactorizeInto<S: Data> {
271+
pub trait FactorizeCInto<S: Data> {
272272
/// Computes the Cholesky decomposition of the Hermitian (or real
273273
/// symmetric) positive definite matrix.
274274
///
@@ -280,7 +280,7 @@ pub trait CholeskyFactorizeInto<S: Data> {
280280
fn factorizec_into(self, UPLO) -> Result<CholeskyFactorized<S>>;
281281
}
282282

283-
impl<A, S> CholeskyFactorizeInto<S> for ArrayBase<S, Ix2>
283+
impl<A, S> FactorizeCInto<S> for ArrayBase<S, Ix2>
284284
where
285285
A: Scalar,
286286
S: DataMut<Elem = A>,
@@ -293,7 +293,7 @@ where
293293
}
294294
}
295295

296-
impl<A, Si> CholeskyFactorize<OwnedRepr<A>> for ArrayBase<Si, Ix2>
296+
impl<A, Si> FactorizeC<OwnedRepr<A>> for ArrayBase<Si, Ix2>
297297
where
298298
A: Scalar,
299299
Si: Data<Elem = A>,
@@ -308,7 +308,7 @@ where
308308

309309
/// Solve systems of linear equations with Hermitian (or real symmetric)
310310
/// positive definite coefficient matrices
311-
pub trait CholeskySolve<A: Scalar> {
311+
pub trait SolveC<A: Scalar> {
312312
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
313313
/// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
314314
/// the argument, and `x` is the successful result.
@@ -331,7 +331,7 @@ pub trait CholeskySolve<A: Scalar> {
331331
fn solvec_inplace<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
332332
}
333333

334-
impl<A, S> CholeskySolve<A> for ArrayBase<S, Ix2>
334+
impl<A, S> SolveC<A> for ArrayBase<S, Ix2>
335335
where
336336
A: Scalar,
337337
S: Data<Elem = A>,
@@ -345,22 +345,22 @@ where
345345
}
346346

347347
/// Inverse of Hermitian (or real symmetric) positive definite matrix ref
348-
pub trait CholeskyInverse {
348+
pub trait InverseC {
349349
type Output;
350350
/// Computes the inverse of the Hermitian (or real symmetric) positive
351351
/// definite matrix.
352352
fn invc(&self) -> Result<Self::Output>;
353353
}
354354

355355
/// Inverse of Hermitian (or real symmetric) positive definite matrix
356-
pub trait CholeskyInverseInto {
356+
pub trait InverseCInto {
357357
type Output;
358358
/// Computes the inverse of the Hermitian (or real symmetric) positive
359359
/// definite matrix.
360360
fn invc_into(self) -> Result<Self::Output>;
361361
}
362362

363-
impl<A, S> CholeskyInverse for ArrayBase<S, Ix2>
363+
impl<A, S> InverseC for ArrayBase<S, Ix2>
364364
where
365365
A: Scalar,
366366
S: Data<Elem = A>,
@@ -372,7 +372,7 @@ where
372372
}
373373
}
374374

375-
impl<A, S> CholeskyInverseInto for ArrayBase<S, Ix2>
375+
impl<A, S> InverseCInto for ArrayBase<S, Ix2>
376376
where
377377
A: Scalar,
378378
S: DataMut<Elem = A>,
@@ -385,7 +385,7 @@ where
385385
}
386386

387387
/// Determinant of Hermitian (or real symmetric) positive definite matrix ref
388-
pub trait CholeskyDeterminant {
388+
pub trait DeterminantC {
389389
type Output;
390390

391391
/// Computes the determinant of the Hermitian (or real symmetric) positive
@@ -395,15 +395,15 @@ pub trait CholeskyDeterminant {
395395

396396

397397
/// Determinant of Hermitian (or real symmetric) positive definite matrix
398-
pub trait CholeskyDeterminantInto {
398+
pub trait DeterminantCInto {
399399
type Output;
400400

401401
/// Computes the determinant of the Hermitian (or real symmetric) positive
402402
/// definite matrix.
403403
fn detc_into(self) -> Self::Output;
404404
}
405405

406-
impl<A, S> CholeskyDeterminant for ArrayBase<S, Ix2>
406+
impl<A, S> DeterminantC for ArrayBase<S, Ix2>
407407
where
408408
A: Scalar,
409409
S: Data<Elem = A>,
@@ -415,7 +415,7 @@ where
415415
}
416416
}
417417

418-
impl<A, S> CholeskyDeterminantInto for ArrayBase<S, Ix2>
418+
impl<A, S> DeterminantCInto for ArrayBase<S, Ix2>
419419
where
420420
A: Scalar,
421421
S: DataMut<Elem = A>,

src/lapack_traits/solveh.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ impl Solveh_ for $scalar {
2626
unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
2727
let (n, _) = l.size();
2828
let mut ipiv = vec![0; n as usize];
29-
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv);
30-
into_result(info, ipiv)
29+
if n == 0 {
30+
// Work around bug in LAPACKE functions.
31+
Ok(ipiv)
32+
} else {
33+
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv);
34+
into_result(info, ipiv)
35+
}
3136
}
3237

3338
unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {

src/solveh.rs

+121-2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
//! ```
5151
5252
use ndarray::*;
53+
use num_traits::{Float, One, Zero};
5354

5455
use super::convert::*;
5556
use super::error::*;
@@ -153,7 +154,7 @@ where
153154
S: DataMut<Elem = A>,
154155
{
155156
fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
156-
let ipiv = unsafe { A::bk(self.layout()?, UPLO::Upper, self.as_allocated_mut()?)? };
157+
let ipiv = unsafe { A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)? };
157158
Ok(BKFactorized {
158159
a: self,
159160
ipiv: ipiv,
@@ -168,7 +169,7 @@ where
168169
{
169170
fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
170171
let mut a: Array2<A> = replicate(self);
171-
let ipiv = unsafe { A::bk(a.layout()?, UPLO::Upper, a.as_allocated_mut()?)? };
172+
let ipiv = unsafe { A::bk(a.square_layout()?, UPLO::Upper, a.as_allocated_mut()?)? };
172173
Ok(BKFactorized { a: a, ipiv: ipiv })
173174
}
174175
}
@@ -249,3 +250,121 @@ where
249250
f.invh_into()
250251
}
251252
}
253+
254+
/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
255+
pub trait DeterminantH {
256+
type Output;
257+
258+
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
259+
fn deth(&self) -> Self::Output;
260+
}
261+
262+
/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
263+
pub trait DeterminantHInto {
264+
type Output;
265+
266+
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
267+
fn deth_into(self) -> Self::Output;
268+
}
269+
270+
fn bk_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> A::Real
271+
where
272+
P: Iterator<Item = i32>,
273+
S: Data<Elem = A>,
274+
A: Scalar,
275+
{
276+
let mut sign = A::Real::one();
277+
let mut ln_det = A::Real::zero();
278+
let mut ipiv_enum = ipiv_iter.enumerate();
279+
while let Some((k, ipiv_k)) = ipiv_enum.next() {
280+
debug_assert!(k < a.rows() && k < a.cols());
281+
if ipiv_k > 0 {
282+
// 1x1 block at k, must be real.
283+
let elem = unsafe { a.uget((k, k)) }.real();
284+
debug_assert_eq!(elem.imag(), Zero::zero());
285+
sign = sign * elem.signum();
286+
ln_det = ln_det + elem.abs().ln();
287+
} else {
288+
// 2x2 block at k..k+2.
289+
290+
// Upper left diagonal elem, must be real.
291+
let upper_diag = unsafe { a.uget((k, k)) }.real();
292+
debug_assert_eq!(upper_diag.imag(), Zero::zero());
293+
294+
// Lower right diagonal elem, must be real.
295+
let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.real();
296+
debug_assert_eq!(lower_diag.imag(), Zero::zero());
297+
298+
// Off-diagonal elements, can be complex.
299+
let off_diag = match uplo {
300+
UPLO::Upper => unsafe { a.uget((k, k + 1)) },
301+
UPLO::Lower => unsafe { a.uget((k + 1, k)) },
302+
};
303+
304+
// Determinant of 2x2 block.
305+
let block_det = upper_diag * lower_diag - off_diag.abs_sqr();
306+
sign = sign * block_det.signum();
307+
ln_det = ln_det + block_det.abs().ln();
308+
309+
// Skip the k+1 ipiv value.
310+
ipiv_enum.next();
311+
}
312+
}
313+
sign * ln_det.exp()
314+
}
315+
316+
impl<A, S> DeterminantH for BKFactorized<S>
317+
where
318+
A: Scalar,
319+
S: Data<Elem = A>,
320+
{
321+
type Output = A::Real;
322+
323+
fn deth(&self) -> A::Real {
324+
bk_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
325+
}
326+
}
327+
328+
impl<A, S> DeterminantHInto for BKFactorized<S>
329+
where
330+
A: Scalar,
331+
S: Data<Elem = A>,
332+
{
333+
type Output = A::Real;
334+
335+
fn deth_into(self) -> A::Real {
336+
bk_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
337+
}
338+
}
339+
340+
impl<A, S> DeterminantH for ArrayBase<S, Ix2>
341+
where
342+
A: Scalar,
343+
S: Data<Elem = A>,
344+
{
345+
type Output = Result<A::Real>;
346+
347+
fn deth(&self) -> Result<A::Real> {
348+
match self.factorizeh() {
349+
Ok(fac) => Ok(fac.deth()),
350+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
351+
Err(err) => Err(err),
352+
}
353+
}
354+
}
355+
356+
impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
357+
where
358+
A: Scalar,
359+
S: DataMut<Elem = A>,
360+
{
361+
type Output = Result<A::Real>;
362+
363+
fn deth_into(self) -> Result<A::Real> {
364+
match self.factorizeh_into() {
365+
Ok(fac) => Ok(fac.deth_into()),
366+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
367+
Err(err) => Err(err),
368+
}
369+
}
370+
}

src/types.rs

+8
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ pub fn into_scalar<T: Scalar>(f: f64) -> T {
8787
pub trait AssociatedReal: Sized {
8888
type Real: RealScalar;
8989
fn inject(Self::Real) -> Self;
90+
/// Returns the real part of `self`.
91+
fn real(self) -> Self::Real;
92+
/// Returns the imaginary part of `self`.
93+
fn imag(self) -> Self::Real;
9094
fn add_real(self, Self::Real) -> Self;
9195
fn sub_real(self, Self::Real) -> Self;
9296
fn mul_real(self, Self::Real) -> Self;
@@ -141,6 +145,8 @@ macro_rules! impl_traits {
141145
impl AssociatedReal for $real {
142146
type Real = $real;
143147
fn inject(r: Self::Real) -> Self { r }
148+
fn real(self) -> Self::Real { self }
149+
fn imag(self) -> Self::Real { 0. }
144150
fn add_real(self, r: Self::Real) -> Self { self + r }
145151
fn sub_real(self, r: Self::Real) -> Self { self - r }
146152
fn mul_real(self, r: Self::Real) -> Self { self * r }
@@ -150,6 +156,8 @@ impl AssociatedReal for $real {
150156
impl AssociatedReal for $complex {
151157
type Real = $real;
152158
fn inject(r: Self::Real) -> Self { Self::new(r, 0.0) }
159+
fn real(self) -> Self::Real { self.re }
160+
fn imag(self) -> Self::Real { self.im }
153161
fn add_real(self, r: Self::Real) -> Self { self + r }
154162
fn sub_real(self, r: Self::Real) -> Self { self - r }
155163
fn mul_real(self, r: Self::Real) -> Self { self * r }

0 commit comments

Comments
 (0)