Skip to content

Commit 03fe1ca

Browse files
authored
Merge pull request #150 from rust-ndarray/krylov
Orthogonalizer trait
2 parents 55d0e8c + 1db5de4 commit 03fe1ca

File tree

5 files changed

+244
-189
lines changed

5 files changed

+244
-189
lines changed

src/krylov/mgs.rs

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
//! Modified Gram-Schmit orthogonalizer
2+
3+
use super::*;
4+
use crate::{generate::*, inner::*, norm::Norm};
5+
6+
/// Iterative orthogonalizer using modified Gram-Schmit procedure
7+
///
8+
/// ```rust
9+
/// # use ndarray::*;
10+
/// # use ndarray_linalg::{krylov::*, *};
11+
/// let mut mgs = MGS::new(3);
12+
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
13+
/// close_l2(&coef, &array![1.0], 1e-9);
14+
///
15+
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
16+
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
17+
///
18+
/// // Fail if the vector is linearly dependent
19+
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
20+
///
21+
/// // You can get coefficients of dependent vector
22+
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
23+
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
24+
/// }
25+
/// ```
26+
#[derive(Debug, Clone)]
27+
pub struct MGS<A> {
28+
/// Dimension of base space
29+
dimension: usize,
30+
/// Basis of spanned space
31+
q: Vec<Array1<A>>,
32+
}
33+
34+
impl<A: Scalar> MGS<A> {
35+
/// Create an empty orthogonalizer
36+
pub fn new(dimension: usize) -> Self {
37+
Self {
38+
dimension,
39+
q: Vec::new(),
40+
}
41+
}
42+
}
43+
44+
impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
45+
type Elem = A;
46+
47+
fn dim(&self) -> usize {
48+
self.dimension
49+
}
50+
51+
fn len(&self) -> usize {
52+
self.q.len()
53+
}
54+
55+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
56+
where
57+
A: Lapack,
58+
S: DataMut<Elem = A>,
59+
{
60+
assert_eq!(a.len(), self.dim());
61+
let mut coef = Array1::zeros(self.len() + 1);
62+
for i in 0..self.len() {
63+
let q = &self.q[i];
64+
let c = q.inner(&a);
65+
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
66+
coef[i] = c;
67+
}
68+
let nrm = a.norm_l2();
69+
coef[self.len()] = A::from_real(nrm);
70+
coef
71+
}
72+
73+
fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
74+
where
75+
A: Lapack,
76+
S: Data<Elem = A>,
77+
{
78+
let mut a = a.into_owned();
79+
let coef = self.orthogonalize(&mut a);
80+
let nrm = coef[coef.len() - 1].re();
81+
if nrm < rtol {
82+
// Linearly dependent
83+
return Err(coef);
84+
}
85+
azip!(mut a in { *a = *a / A::from_real(nrm) });
86+
self.q.push(a);
87+
Ok(coef)
88+
}
89+
90+
fn get_q(&self) -> Q<A> {
91+
hstack(&self.q).unwrap()
92+
}
93+
}
94+
95+
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
96+
pub fn mgs<A, S>(
97+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
98+
dim: usize,
99+
rtol: A::Real,
100+
strategy: Strategy,
101+
) -> (Q<A>, R<A>)
102+
where
103+
A: Scalar + Lapack,
104+
S: Data<Elem = A>,
105+
{
106+
let mgs = MGS::new(dim);
107+
qr(iter, mgs, rtol, strategy)
108+
}

src/krylov/mod.rs

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
//! Krylov subspace
2+
3+
use crate::types::*;
4+
use ndarray::*;
5+
6+
mod mgs;
7+
8+
pub use mgs::{mgs, MGS};
9+
10+
/// Q-matrix
11+
///
12+
/// - Maybe **NOT** square
13+
/// - Unitary for existing columns
14+
///
15+
pub type Q<A> = Array2<A>;
16+
17+
/// R-matrix
18+
///
19+
/// - Maybe **NOT** square
20+
/// - Upper triangle
21+
///
22+
pub type R<A> = Array2<A>;
23+
24+
/// Trait for creating orthogonal basis from iterator of arrays
25+
pub trait Orthogonalizer {
26+
type Elem: Scalar;
27+
28+
/// Dimension of input array
29+
fn dim(&self) -> usize;
30+
31+
/// Number of cached basis
32+
fn len(&self) -> usize;
33+
34+
/// check if the basis spans entire space
35+
fn is_full(&self) -> bool {
36+
self.len() == self.dim()
37+
}
38+
39+
fn is_empty(&self) -> bool {
40+
self.len() == 0
41+
}
42+
43+
/// Orthogonalize given vector using current basis
44+
///
45+
/// Panic
46+
/// -------
47+
/// - if the size of the input array mismatches to the dimension
48+
///
49+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<Self::Elem>
50+
where
51+
S: DataMut<Elem = Self::Elem>;
52+
53+
/// Add new vector if the residual is larger than relative tolerance
54+
///
55+
/// Returns
56+
/// --------
57+
/// Coefficients to the `i`-th Q-vector
58+
///
59+
/// - The size of array must be `self.len() + 1`
60+
/// - The last element is the residual norm of input vector
61+
///
62+
/// Panic
63+
/// -------
64+
/// - if the size of the input array mismatches to the dimension
65+
///
66+
fn append<S>(
67+
&mut self,
68+
a: ArrayBase<S, Ix1>,
69+
rtol: <Self::Elem as Scalar>::Real,
70+
) -> Result<Array1<Self::Elem>, Array1<Self::Elem>>
71+
where
72+
S: DataMut<Elem = Self::Elem>;
73+
74+
/// Get Q-matrix of generated basis
75+
fn get_q(&self) -> Q<Self::Elem>;
76+
}
77+
78+
/// Strategy for linearly dependent vectors appearing in iterative QR decomposition
79+
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
80+
pub enum Strategy {
81+
/// Terminate iteration if dependent vector comes
82+
Terminate,
83+
84+
/// Skip dependent vector
85+
Skip,
86+
87+
/// Orthogonalize dependent vector without adding to Q,
88+
/// i.e. R must be non-square like following:
89+
///
90+
/// ```text
91+
/// x x x x x
92+
/// 0 x x x x
93+
/// 0 0 0 x x
94+
/// 0 0 0 0 x
95+
/// ```
96+
Full,
97+
}
98+
99+
/// Online QR decomposition using arbitrary orthogonalizer
100+
pub fn qr<A, S>(
101+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
102+
mut ortho: impl Orthogonalizer<Elem = A>,
103+
rtol: A::Real,
104+
strategy: Strategy,
105+
) -> (Q<A>, R<A>)
106+
where
107+
A: Scalar + Lapack,
108+
S: Data<Elem = A>,
109+
{
110+
assert_eq!(ortho.len(), 0);
111+
112+
let mut coefs = Vec::new();
113+
for a in iter {
114+
match ortho.append(a.into_owned(), rtol) {
115+
Ok(coef) => coefs.push(coef),
116+
Err(coef) => match strategy {
117+
Strategy::Terminate => break,
118+
Strategy::Skip => continue,
119+
Strategy::Full => coefs.push(coef),
120+
},
121+
}
122+
}
123+
let n = ortho.len();
124+
let m = coefs.len();
125+
let mut r = Array2::zeros((n, m).f());
126+
for j in 0..m {
127+
for i in 0..n {
128+
if i < coefs[j].len() {
129+
r[(i, j)] = coefs[j][i];
130+
}
131+
}
132+
}
133+
(ortho.get_q(), r)
134+
}

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ pub mod eigh;
4444
pub mod error;
4545
pub mod generate;
4646
pub mod inner;
47+
pub mod krylov;
4748
pub mod lapack;
4849
pub mod layout;
49-
pub mod mgs;
5050
pub mod norm;
5151
pub mod operator;
5252
pub mod opnorm;

0 commit comments

Comments
 (0)