@@ -2,10 +2,15 @@ use ndarray::*;
2
2
use ndarray_linalg:: * ;
3
3
4
4
#[ test]
5
- fn eigen_vector_manual ( ) {
5
+ fn fixed ( ) {
6
6
let a = arr2 ( & [ [ 3.0 , 1.0 , 1.0 ] , [ 1.0 , 3.0 , 1.0 ] , [ 1.0 , 1.0 , 3.0 ] ] ) ;
7
7
let ( e, vecs) : ( Array1 < _ > , Array2 < _ > ) = ( & a) . eigh ( UPLO :: Upper ) . unwrap ( ) ;
8
8
assert_close_l2 ! ( & e, & arr1( & [ 2.0 , 2.0 , 5.0 ] ) , 1.0e-7 ) ;
9
+
10
+ // Check eigenvectors are orthogonalized
11
+ let s = vecs. t ( ) . dot ( & vecs) ;
12
+ assert_close_l2 ! ( & s, & Array :: eye( 3 ) , 1.0e-7 ) ;
13
+
9
14
for ( i, v) in vecs. axis_iter ( Axis ( 1 ) ) . enumerate ( ) {
10
15
let av = a. dot ( & v) ;
11
16
let ev = v. mapv ( |x| e[ i] * x) ;
@@ -14,12 +19,53 @@ fn eigen_vector_manual() {
14
19
}
15
20
16
21
#[ test]
17
- fn diagonalize ( ) {
18
- let a = arr2 ( & [ [ 3.0 , 1.0 , 1.0 ] , [ 1.0 , 3.0 , 1.0 ] , [ 1.0 , 1.0 , 3.0 ] ] ) ;
22
+ fn fixed_t ( ) {
23
+ let a = arr2 ( & [ [ 3.0 , 1.0 , 1.0 ] , [ 1.0 , 3.0 , 1.0 ] , [ 1.0 , 1.0 , 3.0 ] ] ) . reversed_axes ( ) ;
19
24
let ( e, vecs) : ( Array1 < _ > , Array2 < _ > ) = ( & a) . eigh ( UPLO :: Upper ) . unwrap ( ) ;
20
- let s = vecs. t ( ) . dot ( & a) . dot ( & vecs) ;
21
- for i in 0 ..3 {
22
- assert_rclose ! ( e[ i] , s[ ( i, i) ] , 1e-7 ) ;
25
+ assert_close_l2 ! ( & e, & arr1( & [ 2.0 , 2.0 , 5.0 ] ) , 1.0e-7 ) ;
26
+
27
+ // Check eigenvectors are orthogonalized
28
+ let s = vecs. t ( ) . dot ( & vecs) ;
29
+ assert_close_l2 ! ( & s, & Array :: eye( 3 ) , 1.0e-7 ) ;
30
+
31
+ for ( i, v) in vecs. axis_iter ( Axis ( 1 ) ) . enumerate ( ) {
32
+ let av = a. dot ( & v) ;
33
+ let ev = v. mapv ( |x| e[ i] * x) ;
34
+ assert_close_l2 ! ( & av, & ev, 1.0e-7 ) ;
35
+ }
36
+ }
37
+
38
+ #[ test]
39
+ fn fixed_lower ( ) {
40
+ let a = arr2 ( & [ [ 3.0 , 1.0 , 1.0 ] , [ 1.0 , 3.0 , 1.0 ] , [ 1.0 , 1.0 , 3.0 ] ] ) ;
41
+ let ( e, vecs) : ( Array1 < _ > , Array2 < _ > ) = ( & a) . eigh ( UPLO :: Lower ) . unwrap ( ) ;
42
+ assert_close_l2 ! ( & e, & arr1( & [ 2.0 , 2.0 , 5.0 ] ) , 1.0e-7 ) ;
43
+
44
+ // Check eigenvectors are orthogonalized
45
+ let s = vecs. t ( ) . dot ( & vecs) ;
46
+ assert_close_l2 ! ( & s, & Array :: eye( 3 ) , 1.0e-7 ) ;
47
+
48
+ for ( i, v) in vecs. axis_iter ( Axis ( 1 ) ) . enumerate ( ) {
49
+ let av = a. dot ( & v) ;
50
+ let ev = v. mapv ( |x| e[ i] * x) ;
51
+ assert_close_l2 ! ( & av, & ev, 1.0e-7 ) ;
52
+ }
53
+ }
54
+
55
+ #[ test]
56
+ fn fixed_t_lower ( ) {
57
+ let a = arr2 ( & [ [ 3.0 , 1.0 , 1.0 ] , [ 1.0 , 3.0 , 1.0 ] , [ 1.0 , 1.0 , 3.0 ] ] ) . reversed_axes ( ) ;
58
+ let ( e, vecs) : ( Array1 < _ > , Array2 < _ > ) = ( & a) . eigh ( UPLO :: Lower ) . unwrap ( ) ;
59
+ assert_close_l2 ! ( & e, & arr1( & [ 2.0 , 2.0 , 5.0 ] ) , 1.0e-7 ) ;
60
+
61
+ // Check eigenvectors are orthogonalized
62
+ let s = vecs. t ( ) . dot ( & vecs) ;
63
+ assert_close_l2 ! ( & s, & Array :: eye( 3 ) , 1.0e-7 ) ;
64
+
65
+ for ( i, v) in vecs. axis_iter ( Axis ( 1 ) ) . enumerate ( ) {
66
+ let av = a. dot ( & v) ;
67
+ let ev = v. mapv ( |x| e[ i] * x) ;
68
+ assert_close_l2 ! ( & av, & ev, 1.0e-7 ) ;
23
69
}
24
70
}
25
71
@@ -48,3 +94,29 @@ fn ssqrt_t() {
48
94
println ! ( "ss = {:?}" , & ss) ;
49
95
assert_close_l2 ! ( & ss, & ans, 1e-7 ) ;
50
96
}
97
+
98
+ #[ test]
99
+ fn ssqrt_lower ( ) {
100
+ let a: Array2 < f64 > = random_hpd ( 3 ) ;
101
+ let ans = a. clone ( ) ;
102
+ let s = a. ssqrt ( UPLO :: Lower ) . unwrap ( ) ;
103
+ println ! ( "a = {:?}" , & ans) ;
104
+ println ! ( "s = {:?}" , & s) ;
105
+ assert_close_l2 ! ( & s. t( ) , & s, 1e-7 ) ;
106
+ let ss = s. dot ( & s) ;
107
+ println ! ( "ss = {:?}" , & ss) ;
108
+ assert_close_l2 ! ( & ss, & ans, 1e-7 ) ;
109
+ }
110
+
111
+ #[ test]
112
+ fn ssqrt_t_lower ( ) {
113
+ let a: Array2 < f64 > = random_hpd ( 3 ) . reversed_axes ( ) ;
114
+ let ans = a. clone ( ) ;
115
+ let s = a. ssqrt ( UPLO :: Lower ) . unwrap ( ) ;
116
+ println ! ( "a = {:?}" , & ans) ;
117
+ println ! ( "s = {:?}" , & s) ;
118
+ assert_close_l2 ! ( & s. t( ) , & s, 1e-7 ) ;
119
+ let ss = s. dot ( & s) ;
120
+ println ! ( "ss = {:?}" , & ss) ;
121
+ assert_close_l2 ! ( & ss, & ans, 1e-7 ) ;
122
+ }
0 commit comments