Skip to content

Commit f24e447

Browse files
authored
Merge pull request #140 from rust-math/random_generate
random unitary/regular matrices
2 parents 0c1e4f6 + 5102019 commit f24e447

File tree

2 files changed

+68
-38
lines changed

2 files changed

+68
-38
lines changed

src/generate.rs

+28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::ops::*;
66

77
use super::convert::*;
88
use super::error::*;
9+
use super::qr::*;
910
use super::types::*;
1011

1112
/// Hermite conjugate matrix
@@ -34,6 +35,33 @@ where
3435
ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng))
3536
}
3637

38+
/// Generate random unitary matrix using QR decomposition
39+
///
40+
/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.
41+
pub fn random_unitary<A>(n: usize) -> Array2<A>
42+
where
43+
A: Scalar + RandNormal,
44+
{
45+
let a: Array2<A> = random((n, n));
46+
let (q, _r) = a.qr_into().unwrap();
47+
q
48+
}
49+
50+
/// Generate random regular matrix
51+
///
52+
/// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose.
53+
pub fn random_regular<A>(n: usize) -> Array2<A>
54+
where
55+
A: Scalar + RandNormal,
56+
{
57+
let a: Array2<A> = random((n, n));
58+
let (q, mut r) = a.qr_into().unwrap();
59+
for i in 0..n {
60+
r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs());
61+
}
62+
q.dot(&r)
63+
}
64+
3765
/// Random Hermite matrix
3866
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
3967
where

tests/det.rs

+40-38
Original file line numberDiff line numberDiff line change
@@ -100,46 +100,48 @@ fn det_zero_nonsquare() {
100100

101101
#[test]
102102
fn det() {
103-
macro_rules! det {
104-
($elem:ty, $shape:expr, $rtol:expr) => {
105-
let a: Array2<$elem> = random($shape);
106-
println!("a = \n{:?}", a);
107-
let det = det_naive(&a);
108-
let sign = det.div_real(det.abs());
109-
let ln_det = det.abs().ln();
110-
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
111-
{
112-
let result = a.factorize().unwrap().sln_det().unwrap();
113-
assert_rclose!(result.0, sign, $rtol);
114-
assert_rclose!(result.1, ln_det, $rtol);
115-
}
116-
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
117-
{
118-
let result = a.factorize().unwrap().sln_det_into().unwrap();
119-
assert_rclose!(result.0, sign, $rtol);
120-
assert_rclose!(result.1, ln_det, $rtol);
121-
}
122-
assert_rclose!(a.det().unwrap(), det, $rtol);
123-
{
124-
let result = a.sln_det().unwrap();
125-
assert_rclose!(result.0, sign, $rtol);
126-
assert_rclose!(result.1, ln_det, $rtol);
127-
}
128-
assert_rclose!(a.clone().det_into().unwrap(), det, $rtol);
129-
{
130-
let result = a.sln_det_into().unwrap();
131-
assert_rclose!(result.0, sign, $rtol);
132-
assert_rclose!(result.1, ln_det, $rtol);
133-
}
134-
};
103+
fn det_impl<A, Tol>(a: Array2<A>, rtol: Tol)
104+
where
105+
A: Scalar<Real = Tol>,
106+
Tol: RealScalar<Real = Tol>,
107+
{
108+
let det = det_naive(&a);
109+
let sign = det.div_real(det.abs());
110+
let ln_det = det.abs().ln();
111+
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, rtol);
112+
{
113+
let result = a.factorize().unwrap().sln_det().unwrap();
114+
assert_rclose!(result.0, sign, rtol);
115+
assert_rclose!(result.1, ln_det, rtol);
116+
}
117+
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, rtol);
118+
{
119+
let result = a.factorize().unwrap().sln_det_into().unwrap();
120+
assert_rclose!(result.0, sign, rtol);
121+
assert_rclose!(result.1, ln_det, rtol);
122+
}
123+
assert_rclose!(a.det().unwrap(), det, rtol);
124+
{
125+
let result = a.sln_det().unwrap();
126+
assert_rclose!(result.0, sign, rtol);
127+
assert_rclose!(result.1, ln_det, rtol);
128+
}
129+
assert_rclose!(a.clone().det_into().unwrap(), det, rtol);
130+
{
131+
let result = a.sln_det_into().unwrap();
132+
assert_rclose!(result.0, sign, rtol);
133+
assert_rclose!(result.1, ln_det, rtol);
134+
}
135135
}
136136
for rows in 1..5 {
137-
for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
138-
det!(f64, shape, 1e-9);
139-
det!(f32, shape, 1e-4);
140-
det!(c64, shape, 1e-9);
141-
det!(c32, shape, 1e-4);
142-
}
137+
det_impl(random_regular::<f64>(rows), 1e-9);
138+
det_impl(random_regular::<f32>(rows), 1e-4);
139+
det_impl(random_regular::<c64>(rows), 1e-9);
140+
det_impl(random_regular::<c32>(rows), 1e-4);
141+
det_impl(random_regular::<f64>(rows).t().to_owned(), 1e-9);
142+
det_impl(random_regular::<f32>(rows).t().to_owned(), 1e-4);
143+
det_impl(random_regular::<c64>(rows).t().to_owned(), 1e-9);
144+
det_impl(random_regular::<c32>(rows).t().to_owned(), 1e-4);
143145
}
144146
}
145147

0 commit comments

Comments
 (0)