Skip to content

Commit b5687f8

Browse files
authored
Merge pull request #898 from SparrowLii/co_broadcast
Implement co-broadcasting in operator overloading
2 parents 5bd5891 + 03cfdfc commit b5687f8

File tree

8 files changed

+304
-29
lines changed

8 files changed

+304
-29
lines changed

src/data_traits.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ use std::ptr::NonNull;
1616
use alloc::sync::Arc;
1717
use alloc::vec::Vec;
1818

19-
use crate::{
20-
ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
21-
};
19+
use crate::{ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};
2220

2321
/// Array representation trait.
2422
///
@@ -414,7 +412,6 @@ pub unsafe trait DataOwned: Data {
414412
/// Corresponding owned data with MaybeUninit elements
415413
type MaybeUninit: DataOwned<Elem = MaybeUninit<Self::Elem>>
416414
+ RawDataSubst<Self::Elem, Output=Self>;
417-
418415
#[doc(hidden)]
419416
fn new(elements: Vec<Self::Elem>) -> Self;
420417

@@ -440,6 +437,7 @@ unsafe impl<A> DataOwned for OwnedRepr<A> {
440437
fn new(elements: Vec<A>) -> Self {
441438
OwnedRepr::from(elements)
442439
}
440+
443441
fn into_shared(self) -> OwnedArcRepr<A> {
444442
OwnedArcRepr(Arc::new(self))
445443
}
@@ -622,3 +620,4 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a mut A> {
622620
ViewRepr::new()
623621
}
624622
}
623+

src/dimension/broadcast.rs

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use crate::error::*;
2+
use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
3+
4+
/// Calculate the common shape for a pair of array shapes, that they can be broadcasted
5+
/// to. Return an error if the shapes are not compatible.
6+
///
7+
/// Uses the [NumPy broadcasting rules]
8+
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9+
pub(crate) fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
10+
where
11+
D1: Dimension,
12+
D2: Dimension,
13+
Output: Dimension,
14+
{
15+
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
16+
// Swap the order if d2 is longer.
17+
if overflow {
18+
return co_broadcast::<D2, D1, Output>(shape2, shape1);
19+
}
20+
// The output should be the same length as shape1.
21+
let mut out = Output::zeros(shape1.ndim());
22+
for (out, s) in izip!(out.slice_mut(), shape1.slice()) {
23+
*out = *s;
24+
}
25+
for (out, s2) in izip!(&mut out.slice_mut()[k..], shape2.slice()) {
26+
if *out != *s2 {
27+
if *out == 1 {
28+
*out = *s2
29+
} else if *s2 != 1 {
30+
return Err(from_kind(ErrorKind::IncompatibleShape));
31+
}
32+
}
33+
}
34+
Ok(out)
35+
}
36+
37+
pub trait DimMax<Other: Dimension> {
38+
/// The resulting dimension type after broadcasting.
39+
type Output: Dimension;
40+
}
41+
42+
/// Dimensions of the same type remain unchanged when co_broadcast.
43+
/// So you can directly use D as the resulting type.
44+
/// (Instead of <D as DimMax<D>>::BroadcastOutput)
45+
impl<D: Dimension> DimMax<D> for D {
46+
type Output = D;
47+
}
48+
49+
macro_rules! impl_broadcast_distinct_fixed {
50+
($smaller:ty, $larger:ty) => {
51+
impl DimMax<$larger> for $smaller {
52+
type Output = $larger;
53+
}
54+
55+
impl DimMax<$smaller> for $larger {
56+
type Output = $larger;
57+
}
58+
};
59+
}
60+
61+
impl_broadcast_distinct_fixed!(Ix0, Ix1);
62+
impl_broadcast_distinct_fixed!(Ix0, Ix2);
63+
impl_broadcast_distinct_fixed!(Ix0, Ix3);
64+
impl_broadcast_distinct_fixed!(Ix0, Ix4);
65+
impl_broadcast_distinct_fixed!(Ix0, Ix5);
66+
impl_broadcast_distinct_fixed!(Ix0, Ix6);
67+
impl_broadcast_distinct_fixed!(Ix1, Ix2);
68+
impl_broadcast_distinct_fixed!(Ix1, Ix3);
69+
impl_broadcast_distinct_fixed!(Ix1, Ix4);
70+
impl_broadcast_distinct_fixed!(Ix1, Ix5);
71+
impl_broadcast_distinct_fixed!(Ix1, Ix6);
72+
impl_broadcast_distinct_fixed!(Ix2, Ix3);
73+
impl_broadcast_distinct_fixed!(Ix2, Ix4);
74+
impl_broadcast_distinct_fixed!(Ix2, Ix5);
75+
impl_broadcast_distinct_fixed!(Ix2, Ix6);
76+
impl_broadcast_distinct_fixed!(Ix3, Ix4);
77+
impl_broadcast_distinct_fixed!(Ix3, Ix5);
78+
impl_broadcast_distinct_fixed!(Ix3, Ix6);
79+
impl_broadcast_distinct_fixed!(Ix4, Ix5);
80+
impl_broadcast_distinct_fixed!(Ix4, Ix6);
81+
impl_broadcast_distinct_fixed!(Ix5, Ix6);
82+
impl_broadcast_distinct_fixed!(Ix0, IxDyn);
83+
impl_broadcast_distinct_fixed!(Ix1, IxDyn);
84+
impl_broadcast_distinct_fixed!(Ix2, IxDyn);
85+
impl_broadcast_distinct_fixed!(Ix3, IxDyn);
86+
impl_broadcast_distinct_fixed!(Ix4, IxDyn);
87+
impl_broadcast_distinct_fixed!(Ix5, IxDyn);
88+
impl_broadcast_distinct_fixed!(Ix6, IxDyn);
89+
90+
91+
#[cfg(test)]
92+
#[cfg(feature = "std")]
93+
mod tests {
94+
use super::co_broadcast;
95+
use crate::{Dimension, Dim, DimMax, ShapeError, Ix0, IxDynImpl, ErrorKind};
96+
97+
#[test]
98+
fn test_broadcast_shape() {
99+
fn test_co<D1, D2>(
100+
d1: &D1,
101+
d2: &D2,
102+
r: Result<<D1 as DimMax<D2>>::Output, ShapeError>,
103+
) where
104+
D1: Dimension + DimMax<D2>,
105+
D2: Dimension,
106+
{
107+
let d = co_broadcast::<D1, D2, <D1 as DimMax<D2>>::Output>(&d1, d2);
108+
assert_eq!(d, r);
109+
}
110+
test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3])));
111+
test_co(
112+
&Dim([1, 2, 2]),
113+
&Dim([1, 3, 4]),
114+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
115+
);
116+
test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5])));
117+
let v = vec![1, 2, 3, 4, 5, 6, 7];
118+
test_co(
119+
&Dim(vec![1, 1, 3, 1, 5, 1, 7]),
120+
&Dim([2, 1, 4, 1, 6, 1]),
121+
Ok(Dim(IxDynImpl::from(v.as_slice()))),
122+
);
123+
let d = Dim([1, 2, 1, 3]);
124+
test_co(&d, &d, Ok(d));
125+
test_co(
126+
&Dim([2, 1, 2]).into_dyn(),
127+
&Dim(0),
128+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
129+
);
130+
test_co(
131+
&Dim([2, 1, 1]),
132+
&Dim([0, 0, 1, 3, 4]),
133+
Ok(Dim([0, 0, 2, 3, 4])),
134+
);
135+
test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0])));
136+
test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0])));
137+
test_co(
138+
&Dim([1, 3, 0, 1, 1]),
139+
&Dim([1, 2, 3, 1]),
140+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)),
141+
);
142+
}
143+
}

src/dimension/dimension_trait.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use super::axes_of;
1515
use super::conversion::Convert;
1616
use super::{stride_offset, stride_offset_checked};
1717
use crate::itertools::{enumerate, zip};
18-
use crate::Axis;
18+
use crate::{Axis, DimMax};
1919
use crate::IntoDimension;
2020
use crate::RemoveAxis;
2121
use crate::{ArrayView1, ArrayViewMut1};
@@ -46,6 +46,11 @@ pub trait Dimension:
4646
+ MulAssign
4747
+ for<'x> MulAssign<&'x Self>
4848
+ MulAssign<usize>
49+
+ DimMax<Ix0, Output=Self>
50+
+ DimMax<Self, Output=Self>
51+
+ DimMax<IxDyn, Output=IxDyn>
52+
+ DimMax<<Self as Dimension>::Smaller, Output=Self>
53+
+ DimMax<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
4954
{
5055
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
5156
/// `Some(ndim)`, and for variable-size dimension representations (e.g.

src/dimension/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use num_integer::div_floor;
1212

1313
pub use self::axes::{axes_of, Axes, AxisDescription};
1414
pub use self::axis::Axis;
15+
pub use self::broadcast::DimMax;
1516
pub use self::conversion::IntoDimension;
1617
pub use self::dim::*;
1718
pub use self::dimension_trait::Dimension;
@@ -28,6 +29,7 @@ use std::mem;
2829
mod macros;
2930
mod axes;
3031
mod axis;
32+
pub(crate) mod broadcast;
3133
mod conversion;
3234
pub mod dim;
3335
mod dimension_trait;

src/impl_methods.rs

+28-3
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ use rawpointer::PointerExt;
1414

1515
use crate::imp_prelude::*;
1616

17-
use crate::arraytraits;
17+
use crate::{arraytraits, DimMax};
1818
use crate::dimension;
1919
use crate::dimension::IntoDimension;
2020
use crate::dimension::{
2121
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
2222
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2323
};
24-
use crate::error::{self, ErrorKind, ShapeError};
24+
use crate::dimension::broadcast::co_broadcast;
25+
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2526
use crate::math_cell::MathCell;
2627
use crate::itertools::zip;
2728
use crate::zip::Zip;
@@ -1766,6 +1767,28 @@ where
17661767
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
17671768
}
17681769

1770+
/// For two arrays or views, find their common shape if possible and
1771+
/// broadcast them as array views into that shape.
1772+
///
1773+
/// Return `ShapeError` if their shapes can not be broadcast together.
1774+
#[allow(clippy::type_complexity)]
1775+
pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
1776+
Result<(ArrayView<'a, A, DimMaxOf<D, E>>, ArrayView<'b, B, DimMaxOf<D, E>>), ShapeError>
1777+
where
1778+
S: Data<Elem=A>,
1779+
S2: Data<Elem=B>,
1780+
D: Dimension + DimMax<E>,
1781+
E: Dimension,
1782+
{
1783+
let shape = co_broadcast::<D, E, <D as DimMax<E>>::Output>(&self.dim, &other.dim)?;
1784+
if let Some(view1) = self.broadcast(shape.clone()) {
1785+
if let Some(view2) = other.broadcast(shape) {
1786+
return Ok((view1, view2));
1787+
}
1788+
}
1789+
Err(from_kind(ErrorKind::IncompatibleShape))
1790+
}
1791+
17691792
/// Swap axes `ax` and `bx`.
17701793
///
17711794
/// This does not move any data, it just adjusts the array’s dimensions
@@ -2013,7 +2036,7 @@ where
20132036
self.map_inplace(move |elt| *elt = x.clone());
20142037
}
20152038

2016-
fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
2039+
pub(crate) fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
20172040
where
20182041
S: DataMut,
20192042
S2: Data<Elem = B>,
@@ -2443,3 +2466,5 @@ unsafe fn unlimited_transmute<A, B>(data: A) -> B {
24432466
let old_data = ManuallyDrop::new(data);
24442467
(&*old_data as *const A as *const B).read()
24452468
}
2469+
2470+
type DimMaxOf<A, B> = <A as DimMax<B>>::Output;

0 commit comments

Comments
 (0)