Skip to content

Commit 6c40d61

Browse files
committed
Add function broadcast_with
1 parent 27c2059 commit 6c40d61

File tree

3 files changed

+66
-11
lines changed

3 files changed

+66
-11
lines changed

src/impl_methods.rs

+32-2
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use rawpointer::PointerExt;
1313

1414
use crate::imp_prelude::*;
1515

16-
use crate::arraytraits;
16+
use crate::{arraytraits, BroadcastShape};
1717
use crate::dimension;
1818
use crate::dimension::IntoDimension;
1919
use crate::dimension::{
2020
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
2121
stride_offset, Axes,
2222
};
23-
use crate::error::{self, ErrorKind, ShapeError};
23+
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2424
use crate::math_cell::MathCell;
2525
use crate::itertools::zip;
2626
use crate::zip::Zip;
@@ -1707,6 +1707,36 @@ where
17071707
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
17081708
}
17091709

1710+
/// Calculate the views of two ArrayBases after broadcasting each other, if possible.
1711+
///
1712+
/// Return `ShapeError` if their shapes can not be broadcast together.
1713+
///
1714+
/// ```
1715+
/// use ndarray::{arr1, arr2};
1716+
///
1717+
/// let a = arr2(&[[2], [3], [4]]);
1718+
/// let b = arr1(&[5, 6, 7]);
1719+
/// let (a1, b1) = a.broadcast_with(&b).unwrap();
1720+
/// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
1721+
/// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
1722+
/// ```
1723+
pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
1724+
Result<(ArrayView<'a, A, <D as BroadcastShape<E>>::Output>, ArrayView<'b, B, <D as BroadcastShape<E>>::Output>), ShapeError>
1725+
where
1726+
S: Data<Elem=A>,
1727+
S2: Data<Elem=B>,
1728+
D: Dimension + BroadcastShape<E>,
1729+
E: Dimension,
1730+
{
1731+
let shape = self.dim.broadcast_shape(&other.dim)?;
1732+
if let Some(view1) = self.broadcast(shape.clone()) {
1733+
if let Some(view2) = other.broadcast(shape) {
1734+
return Ok((view1, view2))
1735+
}
1736+
}
1737+
return Err(from_kind(ErrorKind::IncompatibleShape));
1738+
}
1739+
17101740
/// Swap axes `ax` and `bx`.
17111741
///
17121742
/// This does not move any data, it just adjusts the array’s dimensions

src/impl_ops.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,7 @@ where
106106
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
107107
out
108108
} else {
109-
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
110-
let lhs = self.broadcast(shape.clone()).unwrap();
111-
let rhs = rhs.broadcast(shape).unwrap();
109+
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
112110
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
113111
}
114112
}
@@ -143,9 +141,7 @@ where
143141
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
144142
out
145143
} else {
146-
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
147-
let lhs = self.broadcast(shape.clone()).unwrap();
148-
let rhs = rhs.broadcast(shape).unwrap();
144+
let (rhs, lhs) = rhs.broadcast_with(self).unwrap();
149145
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
150146
}
151147
}
@@ -171,9 +167,7 @@ where
171167
{
172168
type Output = Array<A, <D as BroadcastShape<E>>::Output>;
173169
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
174-
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
175-
let lhs = self.broadcast(shape.clone()).unwrap();
176-
let rhs = rhs.broadcast(shape).unwrap();
170+
let (lhs, rhs) = self.broadcast_with(rhs).unwrap();
177171
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
178172
}
179173
}

tests/broadcast.rs

+31
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use ndarray::prelude::*;
2+
use ndarray::{ShapeError, ErrorKind, arr3};
23

34
#[test]
45
#[cfg(feature = "std")]
@@ -81,3 +82,33 @@ fn test_broadcast_1d() {
8182
println!("b2=\n{:?}", b2);
8283
assert_eq!(b0, b2);
8384
}
85+
86+
#[test]
87+
fn test_broadcast_with() {
88+
let a = arr2(&[[1., 2.], [3., 4.]]);
89+
let b = aview0(&1.);
90+
let (a1, b1) = a.broadcast_with(&b).unwrap();
91+
assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]]));
92+
assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]]));
93+
94+
let a = arr2(&[[2], [3], [4]]);
95+
let b = arr1(&[5, 6, 7]);
96+
let (a1, b1) = a.broadcast_with(&b).unwrap();
97+
assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
98+
assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
99+
100+
// Negative strides and non-contiguous memory
101+
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
102+
let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap();
103+
let a = s.slice(s![..;-1,..;2,..]);
104+
let b = s.slice(s![..2, -1, ..]);
105+
let (a1, b1) = a.broadcast_with(&b).unwrap();
106+
assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]]));
107+
assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]]));
108+
109+
// ShapeError
110+
let a = arr2(&[[2, 2], [3, 3], [4, 4]]);
111+
let b = arr1(&[5, 6, 7]);
112+
let e = a.broadcast_with(&b);
113+
assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
114+
}

0 commit comments

Comments
 (0)