diff --git a/src/free_functions.rs b/src/free_functions.rs index c1889cec8..a2ad6137c 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -9,6 +9,7 @@ use alloc::vec; #[cfg(not(feature = "std"))] use alloc::vec::Vec; +use meshgrid_impl::Meshgrid; #[allow(unused_imports)] use std::compile_error; use std::mem::{forget, size_of}; @@ -45,6 +46,8 @@ use crate::{imp_prelude::*, LayoutRef}; /// /// This macro uses `vec![]`, and has the same ownership semantics; /// elements are moved into the resulting `Array`. +/// If running with `no_std`, this may require that you `use alloc::vec` +/// before being able to use the `array!` macro. /// /// Use `array![...].into_shared()` to create an `ArcArray`. /// @@ -336,3 +339,408 @@ pub fn rcarr3(xs: &[[[A; M]; N]]) -> A { arr3(xs).into_shared() } + +/// The indexing order for [`meshgrid`]; see there for more details. +/// +/// Controls whether the first argument to `meshgrid` will fill the rows or columns of the outputs. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MeshIndex +{ + /// Cartesian indexing. + /// + /// The first argument of `meshgrid` will repeat over the columns of the output. + /// + /// Note: this is the default in `numpy`. + XY, + /// Matrix indexing. + /// + /// The first argument of `meshgrid` will repeat over the rows of the output. + IJ, +} + +mod meshgrid_impl +{ + use super::MeshIndex; + use crate::extension::nonnull::nonnull_debug_checked_from_ptr; + use crate::{ + ArrayBase, + ArrayRef1, + ArrayView, + ArrayView2, + ArrayView3, + ArrayView4, + ArrayView5, + ArrayView6, + Axis, + Data, + Dim, + IntoDimension, + Ix1, + LayoutRef1, + }; + + /// Construct the correct strides for the `idx`-th entry into meshgrid + fn construct_strides( + arr: &LayoutRef1, idx: usize, indexing: MeshIndex, + ) -> <[usize; N] as IntoDimension>::Dim + where [usize; N]: IntoDimension + { + let mut ret = [0; N]; + if idx < 2 && indexing == MeshIndex::XY { + ret[1 - idx] = arr.stride_of(Axis(0)) as usize; + } else { + ret[idx] = arr.stride_of(Axis(0)) as usize; + } + Dim(ret) + } + + /// Construct the correct shape for the `idx`-th entry into meshgrid + fn construct_shape( + arrays: [&LayoutRef1; N], indexing: MeshIndex, + ) -> <[usize; N] as IntoDimension>::Dim + where [usize; N]: IntoDimension + { + let mut ret = arrays.map(|a| a.len()); + if indexing == MeshIndex::XY { + ret.swap(0, 1); + } + Dim(ret) + } + + /// A trait to encapsulate static dispatch for [`meshgrid`](super::meshgrid); see there for more details. + /// + /// The inputs should always be some sort of 1D array. + /// The outputs should always be ND arrays where N is the number of inputs. + /// + /// Where possible, this trait tries to return array views rather than allocating additional memory. + pub trait Meshgrid + { + type Output; + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output; + } + + macro_rules! meshgrid_body { + ($count:literal, $indexing:expr, $(($arr:expr, $idx:literal)),+) => { + { + let shape = construct_shape([$($arr),+], $indexing); + ( + $({ + let strides = construct_strides::<_, $count>($arr, $idx, $indexing); + unsafe { ArrayView::new(nonnull_debug_checked_from_ptr($arr.as_ptr() as *mut A), shape, strides) } + }),+ + ) + } + }; + } + + impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1) + { + type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1)) + } + } + + impl<'a, 'b, S1, S2, A: 'b + 'a> Meshgrid for (&'a ArrayBase, &'b ArrayBase) + where + S1: Data, + S2: Data, + { + type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing) + } + } + + impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1) + { + type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2)) + } + } + + impl<'a, 'b, 'c, S1, S2, S3, A: 'b + 'a + 'c> Meshgrid + for (&'a ArrayBase, &'b ArrayBase, &'c ArrayBase) + where + S1: Data, + S2: Data, + S3: Data, + { + type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing) + } + } + + impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1, &'d ArrayRef1) + { + type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3)) + } + } + + impl<'a, 'b, 'c, 'd, S1, S2, S3, S4, A: 'a + 'b + 'c + 'd> Meshgrid + for (&'a ArrayBase, &'b ArrayBase, &'c ArrayBase, &'d ArrayBase) + where + S1: Data, + S2: Data, + S3: Data, + S4: Data, + { + type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing) + } + } + + impl<'a, 'b, 'c, 'd, 'e, A> Meshgrid + for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1, &'d ArrayRef1, &'e ArrayRef1) + { + type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4)) + } + } + + impl<'a, 'b, 'c, 'd, 'e, S1, S2, S3, S4, S5, A: 'a + 'b + 'c + 'd + 'e> Meshgrid + for ( + &'a ArrayBase, + &'b ArrayBase, + &'c ArrayBase, + &'d ArrayBase, + &'e ArrayBase, + ) + where + S1: Data, + S2: Data, + S3: Data, + S4: Data, + S5: Data, + { + type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing) + } + } + + impl<'a, 'b, 'c, 'd, 'e, 'f, A> Meshgrid + for ( + &'a ArrayRef1, + &'b ArrayRef1, + &'c ArrayRef1, + &'d ArrayRef1, + &'e ArrayRef1, + &'f ArrayRef1, + ) + { + type Output = ( + ArrayView6<'a, A>, + ArrayView6<'b, A>, + ArrayView6<'c, A>, + ArrayView6<'d, A>, + ArrayView6<'e, A>, + ArrayView6<'f, A>, + ); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5)) + } + } + + impl<'a, 'b, 'c, 'd, 'e, 'f, S1, S2, S3, S4, S5, S6, A: 'a + 'b + 'c + 'd + 'e + 'f> Meshgrid + for ( + &'a ArrayBase, + &'b ArrayBase, + &'c ArrayBase, + &'d ArrayBase, + &'e ArrayBase, + &'f ArrayBase, + ) + where + S1: Data, + S2: Data, + S3: Data, + S4: Data, + S5: Data, + S6: Data, + { + type Output = ( + ArrayView6<'a, A>, + ArrayView6<'b, A>, + ArrayView6<'c, A>, + ArrayView6<'d, A>, + ArrayView6<'e, A>, + ArrayView6<'f, A>, + ); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing) + } + } +} + +/// Create coordinate matrices from coordinate vectors. +/// +/// Given an N-tuple of 1D coordinate vectors, return an N-tuple of ND coordinate arrays. +/// This is particularly useful for computing the outputs of functions with N arguments over +/// regularly spaced grids. +/// +/// The `indexing` argument can be controlled by [`MeshIndex`] to support both Cartesian and +/// matrix indexing. In the two-dimensional case, inputs of length `N` and `M` will create +/// output arrays of size `(M, N)` when using [`MeshIndex::XY`] and size `(N, M)` when using +/// [`MeshIndex::IJ`]. +/// +/// # Example +/// ``` +/// use ndarray::{array, meshgrid, MeshIndex}; +/// +/// let arr1 = array![1, 2]; +/// let arr2 = array![3, 4]; +/// let arr3 = array![5, 6]; +/// +/// // Cartesian indexing +/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::XY); +/// assert_eq!(res1, array![ +/// [1, 2], +/// [1, 2], +/// ]); +/// assert_eq!(res2, array![ +/// [3, 3], +/// [4, 4], +/// ]); +/// +/// // Matrix indexing +/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::IJ); +/// assert_eq!(res1, array![ +/// [1, 1], +/// [2, 2], +/// ]); +/// assert_eq!(res2, array![ +/// [3, 4], +/// [3, 4], +/// ]); +/// +/// let (_, _, res3) = meshgrid((&arr1, &arr2, &arr3), MeshIndex::XY); +/// assert_eq!(res3, array![ +/// [[5, 6], +/// [5, 6]], +/// [[5, 6], +/// [5, 6]], +/// ]); +/// ``` +pub fn meshgrid(arrays: T, indexing: MeshIndex) -> T::Output +{ + Meshgrid::meshgrid(arrays, indexing) +} + +#[cfg(test)] +mod tests +{ + use super::s; + use crate::{meshgrid, Axis, MeshIndex}; + #[cfg(not(feature = "std"))] + use alloc::vec; + + #[test] + fn test_meshgrid2() + { + let x = array![1, 2, 3]; + let y = array![4, 5, 6, 7]; + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); + assert_eq!(xx, array![[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]); + assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]); + + let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ); + assert_eq!(xx, array![[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]); + assert_eq!(yy, array![[4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]); + } + + #[test] + fn test_meshgrid3() + { + let x = array![1, 2, 3]; + let y = array![4, 5, 6, 7]; + let z = array![-1, -2]; + let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY); + assert_eq!(xx, array![ + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + ]); + assert_eq!(yy, array![ + [[4, 4], [4, 4], [4, 4]], + [[5, 5], [5, 5], [5, 5]], + [[6, 6], [6, 6], [6, 6]], + [[7, 7], [7, 7], [7, 7]], + ]); + assert_eq!(zz, array![ + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + ]); + + let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ); + assert_eq!(xx, array![ + [[1, 1], [1, 1], [1, 1], [1, 1]], + [[2, 2], [2, 2], [2, 2], [2, 2]], + [[3, 3], [3, 3], [3, 3], [3, 3]], + ]); + assert_eq!(yy, array![ + [[4, 4], [5, 5], [6, 6], [7, 7]], + [[4, 4], [5, 5], [6, 6], [7, 7]], + [[4, 4], [5, 5], [6, 6], [7, 7]], + ]); + assert_eq!(zz, array![ + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + ]); + } + + #[test] + fn test_meshgrid_from_offset() + { + let x = array![1, 2, 3]; + let x = x.slice(s![1..]); + let y = array![4, 5, 6]; + let y = y.slice(s![1..]); + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); + assert_eq!(xx, array![[2, 3], [2, 3]]); + assert_eq!(yy, array![[5, 5], [6, 6]]); + } + + #[test] + fn test_meshgrid_neg_stride() + { + let x = array![1, 2, 3]; + let x = x.slice(s![..;-1]); + assert!(x.stride_of(Axis(0)) < 0); // Setup for test + let y = array![4, 5, 6]; + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); + assert_eq!(xx, array![[3, 2, 1], [3, 2, 1], [3, 2, 1]]); + assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6]]); + } +}