@@ -31,6 +31,13 @@ impl<T: Num + PartialOrd + Copy> DynamicMatrix<T> {
31
31
let data = vec ! [ value; shape. size( ) ] ;
32
32
DynamicMatrix :: new ( shape, & data)
33
33
}
34
+ pub fn eye ( shape : & Shape ) -> Result < DynamicMatrix < T > , ShapeError > {
35
+ let mut result = DynamicMatrix :: zeros ( shape) . unwrap ( ) ;
36
+ for i in 0 ..shape[ 0 ] {
37
+ result. set ( & coord ! [ i, i] , T :: one ( ) ) . unwrap ( ) ;
38
+ }
39
+ Ok ( result)
40
+ }
34
41
pub fn zeros ( shape : & Shape ) -> Result < DynamicMatrix < T > , ShapeError > { Self :: fill ( shape, T :: zero ( ) ) }
35
42
pub fn ones ( shape : & Shape ) -> Result < DynamicMatrix < T > , ShapeError > { Self :: fill ( shape, T :: one ( ) ) }
36
43
@@ -277,6 +284,22 @@ mod tests {
277
284
assert_eq ! ( matrix[ coord![ 1 , 1 ] ] , 3.0 ) ;
278
285
}
279
286
287
+ #[ test]
288
+ fn test_eye ( ) {
289
+ let shape = shape ! [ 3 , 3 ] . unwrap ( ) ;
290
+ let matrix = DynamicMatrix :: < f64 > :: eye ( & shape) . unwrap ( ) ;
291
+ assert_eq ! ( matrix. shape( ) , & shape) ;
292
+ assert_eq ! ( matrix[ coord![ 0 , 0 ] ] , 1.0 ) ;
293
+ assert_eq ! ( matrix[ coord![ 0 , 1 ] ] , 0.0 ) ;
294
+ assert_eq ! ( matrix[ coord![ 0 , 2 ] ] , 0.0 ) ;
295
+ assert_eq ! ( matrix[ coord![ 1 , 0 ] ] , 0.0 ) ;
296
+ assert_eq ! ( matrix[ coord![ 1 , 1 ] ] , 1.0 ) ;
297
+ assert_eq ! ( matrix[ coord![ 1 , 2 ] ] , 0.0 ) ;
298
+ assert_eq ! ( matrix[ coord![ 2 , 0 ] ] , 0.0 ) ;
299
+ assert_eq ! ( matrix[ coord![ 2 , 1 ] ] , 0.0 ) ;
300
+ assert_eq ! ( matrix[ coord![ 2 , 2 ] ] , 1.0 ) ;
301
+ }
302
+
280
303
#[ test]
281
304
fn test_zeros ( ) {
282
305
let shape = shape ! [ 2 , 2 ] . unwrap ( ) ;
0 commit comments