diff --git a/src/lib.rs b/src/lib.rs index 718f954..def49e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,4 +6,5 @@ pub mod matrix; pub mod shape; pub mod storage; pub mod tensor; +pub mod traits; pub mod vector; diff --git a/src/matrix.rs b/src/matrix.rs index 215be90..c1dac8d 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -7,6 +7,7 @@ use crate::error::ShapeError; use crate::shape; use crate::shape::Shape; use crate::tensor::DynamicTensor; +use crate::traits::MatMul; use crate::vector::DynamicVector; use num::{Float, Num}; @@ -71,25 +72,18 @@ impl DynamicMatrix { let result = self.tensor.min(axes); DynamicVector::from_tensor(result).unwrap() } +} - // Vector/Matrix Multiplication - pub fn matmul(&self, rhs: &Self) -> DynamicMatrix { - // Matrix-Matrix multiplication - assert_eq!(self.shape()[1], rhs.shape()[0]); - let mut result = DynamicTensor::zeros(&shape![self.shape()[0], rhs.shape()[1]].unwrap()); - for i in 0..self.shape()[0] { - for j in 0..rhs.shape()[1] { - let mut sum = T::zero(); - for k in 0..self.shape()[1] { - sum = sum + self[coord![i, k].unwrap()] * rhs[coord![k, j].unwrap()]; - } - result.set(&coord![i, j].unwrap(), sum).unwrap(); - } - } - DynamicMatrix::from_tensor(result).unwrap() +impl DynamicMatrix { + pub fn pow(&self, power: T) -> DynamicMatrix { + DynamicMatrix::from_tensor(self.tensor.pow(power)).unwrap() } +} + +impl MatMul> for DynamicMatrix { + type Output = DynamicVector; - pub fn vecmul(&self, rhs: &DynamicVector) -> DynamicVector { + fn matmul(self, rhs: &DynamicVector) -> DynamicVector { assert_eq!(self.shape()[1], rhs.shape()[0]); let mut result = DynamicTensor::zeros(&shape![self.shape()[0]].unwrap()); for i in 0..self.shape()[0] { @@ -103,12 +97,24 @@ impl DynamicMatrix { } } -impl DynamicMatrix { - pub fn pow(&self, power: T) -> DynamicMatrix { - DynamicMatrix::from_tensor(self.tensor.pow(power)).unwrap() +impl MatMul> for DynamicMatrix { + type Output = DynamicMatrix; + + fn matmul(self, rhs: &DynamicMatrix) -> DynamicMatrix { + assert_eq!(self.shape()[1], rhs.shape()[0]); + let mut result = DynamicTensor::zeros(&shape![self.shape()[0], rhs.shape()[1]].unwrap()); + for i in 0..self.shape()[0] { + for j in 0..rhs.shape()[1] { + let mut sum = T::zero(); + for k in 0..self.shape()[1] { + sum = sum + self[coord![i, k].unwrap()] * rhs[coord![k, j].unwrap()]; + } + result.set(&coord![i, j].unwrap(), sum).unwrap(); + } + } + DynamicMatrix::from_tensor(result).unwrap() } } - // Scalar Addition impl Add for DynamicMatrix { type Output = DynamicMatrix; @@ -420,7 +426,7 @@ mod tests { } #[test] - fn test_matmul() { + fn test_matmul_mat() { let shape = shape![2, 2].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; @@ -435,13 +441,13 @@ mod tests { } #[test] - fn test_vecmul() { + fn test_matmul_vec() { let shape = shape![2, 2].unwrap(); let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let vector_data = vec![1.0, 2.0]; let vector = DynamicVector::new(&vector_data).unwrap(); - let result = matrix.vecmul(&vector); + let result = matrix.matmul(&vector); assert_eq!(result.shape(), &shape![2].unwrap()); assert_eq!(result[0], 5.0); assert_eq!(result[1], 11.0); diff --git a/src/traits.rs b/src/traits.rs new file mode 100644 index 0000000..d2a8d1d --- /dev/null +++ b/src/traits.rs @@ -0,0 +1,5 @@ +pub trait MatMul { + type Output; + + fn matmul(self, rhs: &Rhs) -> Self::Output; +} diff --git a/src/vector.rs b/src/vector.rs index ff2b939..9c7a627 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -6,6 +6,7 @@ use crate::matrix::DynamicMatrix; use crate::shape; use crate::shape::Shape; use crate::tensor::DynamicTensor; +use crate::traits::MatMul; use num::Float; use num::Num; @@ -66,9 +67,18 @@ impl DynamicVector { let result = self.tensor.min(vec![]); DynamicVector::from_tensor(result).unwrap() } +} + +impl DynamicVector { + pub fn pow(&self, power: T) -> DynamicVector { + DynamicVector::from_tensor(self.tensor.pow(power)).unwrap() + } +} + +impl MatMul> for DynamicVector { + type Output = DynamicVector; - // Vector/Matrix Multiplication - pub fn vecmul(&self, rhs: &DynamicVector) -> DynamicVector { + fn matmul(self, rhs: &DynamicVector) -> DynamicVector { assert!(self.shape() == rhs.shape()); let mut result = T::zero(); for i in 0..self.size() { @@ -76,8 +86,12 @@ impl DynamicVector { } DynamicVector::new(&[result]).unwrap() } +} + +impl MatMul> for DynamicVector { + type Output = DynamicVector; - pub fn matmul(&self, rhs: &DynamicMatrix) -> DynamicVector { + fn matmul(self, rhs: &DynamicMatrix) -> DynamicVector { assert_eq!(self.shape()[0], rhs.shape()[0]); let mut result = DynamicTensor::zeros(&shape![rhs.shape()[1]].unwrap()); for j in 0..rhs.shape()[1] { @@ -91,12 +105,6 @@ impl DynamicVector { } } -impl DynamicVector { - pub fn pow(&self, power: T) -> DynamicVector { - DynamicVector::from_tensor(self.tensor.pow(power)).unwrap() - } -} - // Scalar Addition impl Add for DynamicVector { type Output = DynamicVector; @@ -376,18 +384,18 @@ mod tests { } #[test] - fn test_vecmul() { + fn test_matmul_vec() { let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); - let result = vector1.vecmul(&vector2); + let result = vector1.matmul(&vector2); assert_eq!(result[0], 40.0); assert_eq!(result.shape(), &shape![1].unwrap()); } #[test] - fn test_matmul() { + fn test_matmul_mat() { let data_vector = vec![1.0, 2.0]; let data_matrix = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data_vector).unwrap();