Skip to content

Commit 3d9d9f4

Browse files
committed
feat: add eye method to matrix
1 parent 3b23684 commit 3d9d9f4

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

src/matrix.rs

+23
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ impl<T: Num + PartialOrd + Copy> DynamicMatrix<T> {
3131
let data = vec![value; shape.size()];
3232
DynamicMatrix::new(shape, &data)
3333
}
34+
pub fn eye(shape: &Shape) -> Result<DynamicMatrix<T>, ShapeError> {
35+
let mut result = DynamicMatrix::zeros(shape).unwrap();
36+
for i in 0..shape[0] {
37+
result.set(&coord![i, i], T::one()).unwrap();
38+
}
39+
Ok(result)
40+
}
3441
pub fn zeros(shape: &Shape) -> Result<DynamicMatrix<T>, ShapeError> { Self::fill(shape, T::zero()) }
3542
pub fn ones(shape: &Shape) -> Result<DynamicMatrix<T>, ShapeError> { Self::fill(shape, T::one()) }
3643

@@ -277,6 +284,22 @@ mod tests {
277284
assert_eq!(matrix[coord![1, 1]], 3.0);
278285
}
279286

287+
#[test]
288+
fn test_eye() {
289+
let shape = shape![3, 3].unwrap();
290+
let matrix = DynamicMatrix::<f64>::eye(&shape).unwrap();
291+
assert_eq!(matrix.shape(), &shape);
292+
assert_eq!(matrix[coord![0, 0]], 1.0);
293+
assert_eq!(matrix[coord![0, 1]], 0.0);
294+
assert_eq!(matrix[coord![0, 2]], 0.0);
295+
assert_eq!(matrix[coord![1, 0]], 0.0);
296+
assert_eq!(matrix[coord![1, 1]], 1.0);
297+
assert_eq!(matrix[coord![1, 2]], 0.0);
298+
assert_eq!(matrix[coord![2, 0]], 0.0);
299+
assert_eq!(matrix[coord![2, 1]], 0.0);
300+
assert_eq!(matrix[coord![2, 2]], 1.0);
301+
}
302+
280303
#[test]
281304
fn test_zeros() {
282305
let shape = shape![2, 2].unwrap();

0 commit comments

Comments
 (0)