|
| 1 | +use crate::error::*; |
| 2 | +use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; |
| 3 | + |
| 4 | +/// Calculate the common shape for a pair of array shapes, that they can be broadcasted |
| 5 | +/// to. Return an error if the shapes are not compatible. |
| 6 | +/// |
| 7 | +/// Uses the [NumPy broadcasting rules] |
| 8 | +// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). |
| 9 | +pub(crate) fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError> |
| 10 | +where |
| 11 | + D1: Dimension, |
| 12 | + D2: Dimension, |
| 13 | + Output: Dimension, |
| 14 | +{ |
| 15 | + let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); |
| 16 | + // Swap the order if d2 is longer. |
| 17 | + if overflow { |
| 18 | + return co_broadcast::<D2, D1, Output>(shape2, shape1); |
| 19 | + } |
| 20 | + // The output should be the same length as shape1. |
| 21 | + let mut out = Output::zeros(shape1.ndim()); |
| 22 | + for (out, s) in izip!(out.slice_mut(), shape1.slice()) { |
| 23 | + *out = *s; |
| 24 | + } |
| 25 | + for (out, s2) in izip!(&mut out.slice_mut()[k..], shape2.slice()) { |
| 26 | + if *out != *s2 { |
| 27 | + if *out == 1 { |
| 28 | + *out = *s2 |
| 29 | + } else if *s2 != 1 { |
| 30 | + return Err(from_kind(ErrorKind::IncompatibleShape)); |
| 31 | + } |
| 32 | + } |
| 33 | + } |
| 34 | + Ok(out) |
| 35 | +} |
| 36 | + |
| 37 | +pub trait DimMax<Other: Dimension> { |
| 38 | + /// The resulting dimension type after broadcasting. |
| 39 | + type Output: Dimension; |
| 40 | +} |
| 41 | + |
| 42 | +/// Dimensions of the same type remain unchanged when co_broadcast. |
| 43 | +/// So you can directly use D as the resulting type. |
| 44 | +/// (Instead of <D as DimMax<D>>::BroadcastOutput) |
| 45 | +impl<D: Dimension> DimMax<D> for D { |
| 46 | + type Output = D; |
| 47 | +} |
| 48 | + |
| 49 | +macro_rules! impl_broadcast_distinct_fixed { |
| 50 | + ($smaller:ty, $larger:ty) => { |
| 51 | + impl DimMax<$larger> for $smaller { |
| 52 | + type Output = $larger; |
| 53 | + } |
| 54 | + |
| 55 | + impl DimMax<$smaller> for $larger { |
| 56 | + type Output = $larger; |
| 57 | + } |
| 58 | + }; |
| 59 | +} |
| 60 | + |
| 61 | +impl_broadcast_distinct_fixed!(Ix0, Ix1); |
| 62 | +impl_broadcast_distinct_fixed!(Ix0, Ix2); |
| 63 | +impl_broadcast_distinct_fixed!(Ix0, Ix3); |
| 64 | +impl_broadcast_distinct_fixed!(Ix0, Ix4); |
| 65 | +impl_broadcast_distinct_fixed!(Ix0, Ix5); |
| 66 | +impl_broadcast_distinct_fixed!(Ix0, Ix6); |
| 67 | +impl_broadcast_distinct_fixed!(Ix1, Ix2); |
| 68 | +impl_broadcast_distinct_fixed!(Ix1, Ix3); |
| 69 | +impl_broadcast_distinct_fixed!(Ix1, Ix4); |
| 70 | +impl_broadcast_distinct_fixed!(Ix1, Ix5); |
| 71 | +impl_broadcast_distinct_fixed!(Ix1, Ix6); |
| 72 | +impl_broadcast_distinct_fixed!(Ix2, Ix3); |
| 73 | +impl_broadcast_distinct_fixed!(Ix2, Ix4); |
| 74 | +impl_broadcast_distinct_fixed!(Ix2, Ix5); |
| 75 | +impl_broadcast_distinct_fixed!(Ix2, Ix6); |
| 76 | +impl_broadcast_distinct_fixed!(Ix3, Ix4); |
| 77 | +impl_broadcast_distinct_fixed!(Ix3, Ix5); |
| 78 | +impl_broadcast_distinct_fixed!(Ix3, Ix6); |
| 79 | +impl_broadcast_distinct_fixed!(Ix4, Ix5); |
| 80 | +impl_broadcast_distinct_fixed!(Ix4, Ix6); |
| 81 | +impl_broadcast_distinct_fixed!(Ix5, Ix6); |
| 82 | +impl_broadcast_distinct_fixed!(Ix0, IxDyn); |
| 83 | +impl_broadcast_distinct_fixed!(Ix1, IxDyn); |
| 84 | +impl_broadcast_distinct_fixed!(Ix2, IxDyn); |
| 85 | +impl_broadcast_distinct_fixed!(Ix3, IxDyn); |
| 86 | +impl_broadcast_distinct_fixed!(Ix4, IxDyn); |
| 87 | +impl_broadcast_distinct_fixed!(Ix5, IxDyn); |
| 88 | +impl_broadcast_distinct_fixed!(Ix6, IxDyn); |
| 89 | + |
| 90 | + |
| 91 | +#[cfg(test)] |
| 92 | +#[cfg(feature = "std")] |
| 93 | +mod tests { |
| 94 | + use super::co_broadcast; |
| 95 | + use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind}; |
| 96 | + |
| 97 | + #[test] |
| 98 | + fn test_broadcast_shape() { |
| 99 | + fn test_co<D1, D2>( |
| 100 | + d1: &D1, |
| 101 | + d2: &D2, |
| 102 | + r: Result<<D1 as DimMax<D2>>::Output, ShapeError>, |
| 103 | + ) where |
| 104 | + D1: Dimension + DimMax<D2>, |
| 105 | + D2: Dimension, |
| 106 | + { |
| 107 | + let d = co_broadcast::<D1, D2, <D1 as DimMax<D2>>::Output>(&d1, d2); |
| 108 | + assert_eq!(d, r); |
| 109 | + } |
| 110 | + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); |
| 111 | + test_co( |
| 112 | + &Dim([1, 2, 2]), |
| 113 | + &Dim([1, 3, 4]), |
| 114 | + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), |
| 115 | + ); |
| 116 | + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); |
| 117 | + let v = vec![1, 2, 3, 4, 5, 6, 7]; |
| 118 | + test_co( |
| 119 | + &Dim(vec![1, 1, 3, 1, 5, 1, 7]), |
| 120 | + &Dim([2, 1, 4, 1, 6, 1]), |
| 121 | + Ok(Dim(IxDynImpl::from(v.as_slice()))), |
| 122 | + ); |
| 123 | + let d = Dim([1, 2, 1, 3]); |
| 124 | + test_co(&d, &d, Ok(d)); |
| 125 | + test_co( |
| 126 | + &Dim([2, 1, 2]).into_dyn(), |
| 127 | + &Dim(0), |
| 128 | + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), |
| 129 | + ); |
| 130 | + test_co( |
| 131 | + &Dim([2, 1, 1]), |
| 132 | + &Dim([0, 0, 1, 3, 4]), |
| 133 | + Ok(Dim([0, 0, 2, 3, 4])), |
| 134 | + ); |
| 135 | + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); |
| 136 | + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); |
| 137 | + test_co( |
| 138 | + &Dim([1, 3, 0, 1, 1]), |
| 139 | + &Dim([1, 2, 3, 1]), |
| 140 | + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)), |
| 141 | + ); |
| 142 | + } |
| 143 | +} |
0 commit comments