diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 6fb850b50..524fe2d14 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -137,6 +137,21 @@ where S: Data } self.dot_generic(rhs) } + + /// Outer product of two 1D arrays. + /// + /// The outer product of two vectors a (of dimension M) and b (of dimension N) + /// is defined as an (M*N)-dimensional matrix whose ij-th element is a_i * b_j. + /// This implementation essentially calls `dot` by reshaping the vectors. + pub fn outer(&self, b: &ArrayBase) -> Array + where + S2: Data, + A: LinalgScalar, + { + let (size_a, size_b) = (self.shape()[0], b.shape()[0]); + let b_reshaped = b.view().into_shape((1, size_b)).unwrap(); + self.view().into_shape((size_a, 1)).unwrap().dot(&b_reshaped) + } } /// Return a pointer to the starting element in BLAS's view. diff --git a/tests/array.rs b/tests/array.rs index 8f01d0636..660ff4a67 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -81,6 +81,13 @@ fn test_mat_mul() assert_eq!(c.dot(&a), a); } +#[test] +fn test_outer_product() { + let a: Array1 = array![2., 4.]; + let b: Array1 = array![3., 5.]; + assert_eq!(a.outer(&b), array![[6., 10.], [12., 20.]]); +} + #[deny(unsafe_code)] #[test] fn test_slice()