Skip to content

Commit f0e9acf

Browse files
committed
Change Mask's T to the Simd element type
1 parent 32ba8ed commit f0e9acf

File tree

10 files changed

+363
-247
lines changed

10 files changed

+363
-247
lines changed

crates/core_simd/src/masks.rs

Lines changed: 180 additions & 68 deletions
Large diffs are not rendered by default.

crates/core_simd/src/select.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
use crate::simd::{
2-
FixEndianness, LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount,
3-
};
1+
use crate::simd::{FixEndianness, LaneCount, Mask, Simd, SimdElement, SupportedLaneCount};
42

53
/// Choose elements from two vectors using a mask.
64
///
@@ -58,7 +56,7 @@ pub trait Select<T> {
5856
impl<T, U, const N: usize> Select<Simd<T, N>> for Mask<U, N>
5957
where
6058
T: SimdElement,
61-
U: MaskElement,
59+
U: SimdElement,
6260
LaneCount<N>: SupportedLaneCount,
6361
{
6462
#[inline]
@@ -133,14 +131,19 @@ where
133131

134132
impl<T, U, const N: usize> Select<Mask<T, N>> for Mask<U, N>
135133
where
136-
T: MaskElement,
137-
U: MaskElement,
134+
T: SimdElement,
135+
U: SimdElement,
138136
LaneCount<N>: SupportedLaneCount,
139137
{
140138
#[inline]
141139
fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
142-
let selected: Simd<T, N> =
143-
Select::select(self, true_values.to_simd(), false_values.to_simd());
140+
// Safety:
141+
// simd_as between masks is always safe (they're vectors of ints).
142+
// simd_select uses a mask that matches the width and number of elements
143+
let selected: Simd<T::Mask, N> = unsafe {
144+
let mask: Simd<T::Mask, N> = core::intrinsics::simd::simd_as(self.to_simd());
145+
core::intrinsics::simd::simd_select(mask, true_values.to_simd(), false_values.to_simd())
146+
};
144147

145148
// Safety: all values come from masks
146149
unsafe { Mask::from_simd_unchecked(selected) }
@@ -149,12 +152,12 @@ where
149152

150153
impl<T, const N: usize> Select<Mask<T, N>> for u64
151154
where
152-
T: MaskElement,
155+
T: SimdElement,
153156
LaneCount<N>: SupportedLaneCount,
154157
{
155158
#[inline]
156159
fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
157-
let selected: Simd<T, N> =
160+
let selected: Simd<T::Mask, N> =
158161
Select::select(self, true_values.to_simd(), false_values.to_simd());
159162

160163
// Safety: all values come from masks

crates/core_simd/src/simd/cmp/eq.rs

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ macro_rules! impl_number {
2424
where
2525
LaneCount<N>: SupportedLaneCount,
2626
{
27-
type Mask = Mask<<$number as SimdElement>::Mask, N>;
27+
type Mask = Mask<$number, N>;
2828

2929
#[inline]
3030
fn simd_eq(self, other: Self) -> Self::Mask {
@@ -46,65 +46,67 @@ macro_rules! impl_number {
4646

4747
impl_number! { f32, f64, u8, u16, u32, u64, usize, i8, i16, i32, i64, isize }
4848

49-
macro_rules! impl_mask {
50-
{ $($integer:ty),* } => {
51-
$(
52-
impl<const N: usize> SimdPartialEq for Mask<$integer, N>
53-
where
54-
LaneCount<N>: SupportedLaneCount,
55-
{
56-
type Mask = Self;
49+
// Masks compare lane-wise by comparing their underlying integer representations
50+
impl<T, const N: usize> SimdPartialEq for Mask<T, N>
51+
where
52+
T: SimdElement,
53+
LaneCount<N>: SupportedLaneCount,
54+
{
55+
type Mask = Self;
5756

58-
#[inline]
59-
fn simd_eq(self, other: Self) -> Self::Mask {
60-
// Safety: `self` is a vector, and the result of the comparison
61-
// is always a valid mask.
62-
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_eq(self.to_simd(), other.to_simd())) }
63-
}
57+
#[inline]
58+
fn simd_eq(self, other: Self) -> Self::Mask {
59+
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
60+
unsafe {
61+
Self::from_simd_unchecked(core::intrinsics::simd::simd_eq(
62+
self.to_simd(),
63+
other.to_simd(),
64+
))
65+
}
66+
}
6467

65-
#[inline]
66-
fn simd_ne(self, other: Self) -> Self::Mask {
67-
// Safety: `self` is a vector, and the result of the comparison
68-
// is always a valid mask.
69-
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ne(self.to_simd(), other.to_simd())) }
70-
}
68+
#[inline]
69+
fn simd_ne(self, other: Self) -> Self::Mask {
70+
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
71+
unsafe {
72+
Self::from_simd_unchecked(core::intrinsics::simd::simd_ne(
73+
self.to_simd(),
74+
other.to_simd(),
75+
))
7176
}
72-
)*
7377
}
7478
}
7579

76-
impl_mask! { i8, i16, i32, i64, isize }
77-
7880
impl<T, const N: usize> SimdPartialEq for Simd<*const T, N>
7981
where
8082
LaneCount<N>: SupportedLaneCount,
8183
{
82-
type Mask = Mask<isize, N>;
84+
type Mask = Mask<*const T, N>;
8385

8486
#[inline]
8587
fn simd_eq(self, other: Self) -> Self::Mask {
86-
self.addr().simd_eq(other.addr())
88+
self.addr().simd_eq(other.addr()).cast::<*const T>()
8789
}
8890

8991
#[inline]
9092
fn simd_ne(self, other: Self) -> Self::Mask {
91-
self.addr().simd_ne(other.addr())
93+
self.addr().simd_ne(other.addr()).cast::<*const T>()
9294
}
9395
}
9496

9597
impl<T, const N: usize> SimdPartialEq for Simd<*mut T, N>
9698
where
9799
LaneCount<N>: SupportedLaneCount,
98100
{
99-
type Mask = Mask<isize, N>;
101+
type Mask = Mask<*mut T, N>;
100102

101103
#[inline]
102104
fn simd_eq(self, other: Self) -> Self::Mask {
103-
self.addr().simd_eq(other.addr())
105+
self.addr().simd_eq(other.addr()).cast::<*mut T>()
104106
}
105107

106108
#[inline]
107109
fn simd_ne(self, other: Self) -> Self::Mask {
108-
self.addr().simd_ne(other.addr())
110+
self.addr().simd_ne(other.addr()).cast::<*mut T>()
109111
}
110112
}

crates/core_simd/src/simd/cmp/ord.rs

Lines changed: 79 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::simd::{
2-
LaneCount, Mask, Select, Simd, SupportedLaneCount,
2+
LaneCount, Mask, Select, Simd, SimdElement, SupportedLaneCount,
33
cmp::SimdPartialEq,
44
ptr::{SimdConstPtr, SimdMutPtr},
55
};
@@ -152,94 +152,108 @@ macro_rules! impl_float {
152152

153153
impl_float! { f32, f64 }
154154

155-
macro_rules! impl_mask {
156-
{ $($integer:ty),* } => {
157-
$(
158-
impl<const N: usize> SimdPartialOrd for Mask<$integer, N>
159-
where
160-
LaneCount<N>: SupportedLaneCount,
161-
{
162-
#[inline]
163-
fn simd_lt(self, other: Self) -> Self::Mask {
164-
// Safety: `self` is a vector, and the result of the comparison
165-
// is always a valid mask.
166-
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_lt(self.to_simd(), other.to_simd())) }
167-
}
155+
impl<T, const N: usize> SimdPartialOrd for Mask<T, N>
156+
where
157+
T: SimdElement,
158+
LaneCount<N>: SupportedLaneCount,
159+
{
160+
#[inline]
161+
fn simd_lt(self, other: Self) -> Self::Mask {
162+
// Use intrinsic to avoid extra bounds on T.
163+
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
164+
unsafe {
165+
Self::from_simd_unchecked(core::intrinsics::simd::simd_lt(
166+
self.to_simd(),
167+
other.to_simd(),
168+
))
169+
}
170+
}
168171

169-
#[inline]
170-
fn simd_le(self, other: Self) -> Self::Mask {
171-
// Safety: `self` is a vector, and the result of the comparison
172-
// is always a valid mask.
173-
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_le(self.to_simd(), other.to_simd())) }
174-
}
172+
#[inline]
173+
fn simd_le(self, other: Self) -> Self::Mask {
174+
// Use intrinsic to avoid extra bounds on T.
175+
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
176+
unsafe {
177+
Self::from_simd_unchecked(core::intrinsics::simd::simd_le(
178+
self.to_simd(),
179+
other.to_simd(),
180+
))
181+
}
182+
}
175183

176-
#[inline]
177-
fn simd_gt(self, other: Self) -> Self::Mask {
178-
// Safety: `self` is a vector, and the result of the comparison
179-
// is always a valid mask.
180-
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_gt(self.to_simd(), other.to_simd())) }
181-
}
184+
#[inline]
185+
fn simd_gt(self, other: Self) -> Self::Mask {
186+
// Use intrinsic to avoid extra bounds on T.
187+
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
188+
unsafe {
189+
Self::from_simd_unchecked(core::intrinsics::simd::simd_gt(
190+
self.to_simd(),
191+
other.to_simd(),
192+
))
193+
}
194+
}
182195

183-
#[inline]
184-
fn simd_ge(self, other: Self) -> Self::Mask {
185-
// Safety: `self` is a vector, and the result of the comparison
186-
// is always a valid mask.
187-
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ge(self.to_simd(), other.to_simd())) }
188-
}
196+
#[inline]
197+
fn simd_ge(self, other: Self) -> Self::Mask {
198+
// Use intrinsic to avoid extra bounds on T.
199+
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
200+
unsafe {
201+
Self::from_simd_unchecked(core::intrinsics::simd::simd_ge(
202+
self.to_simd(),
203+
other.to_simd(),
204+
))
189205
}
206+
}
207+
}
190208

191-
impl<const N: usize> SimdOrd for Mask<$integer, N>
192-
where
193-
LaneCount<N>: SupportedLaneCount,
194-
{
195-
#[inline]
196-
fn simd_max(self, other: Self) -> Self {
197-
self.simd_gt(other).select(other, self)
198-
}
209+
impl<T, const N: usize> SimdOrd for Mask<T, N>
210+
where
211+
T: SimdElement,
212+
LaneCount<N>: SupportedLaneCount,
213+
{
214+
#[inline]
215+
fn simd_max(self, other: Self) -> Self {
216+
self.simd_gt(other).select(other, self)
217+
}
199218

200-
#[inline]
201-
fn simd_min(self, other: Self) -> Self {
202-
self.simd_lt(other).select(other, self)
203-
}
219+
#[inline]
220+
fn simd_min(self, other: Self) -> Self {
221+
self.simd_lt(other).select(other, self)
222+
}
204223

205-
#[inline]
206-
#[track_caller]
207-
fn simd_clamp(self, min: Self, max: Self) -> Self {
208-
assert!(
209-
min.simd_le(max).all(),
210-
"each element in `min` must be less than or equal to the corresponding element in `max`",
211-
);
212-
self.simd_max(min).simd_min(max)
213-
}
214-
}
215-
)*
224+
#[inline]
225+
#[track_caller]
226+
fn simd_clamp(self, min: Self, max: Self) -> Self {
227+
assert!(
228+
min.simd_le(max).all(),
229+
"each element in `min` must be less than or equal to the corresponding element in `max`",
230+
);
231+
self.simd_max(min).simd_min(max)
216232
}
217233
}
218234

219-
impl_mask! { i8, i16, i32, i64, isize }
220-
221235
impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N>
222236
where
223237
LaneCount<N>: SupportedLaneCount,
224238
{
225239
#[inline]
226240
fn simd_lt(self, other: Self) -> Self::Mask {
227-
self.addr().simd_lt(other.addr())
241+
self.addr().simd_lt(other.addr()).cast::<*const T>()
228242
}
229243

230244
#[inline]
231245
fn simd_le(self, other: Self) -> Self::Mask {
232-
self.addr().simd_le(other.addr())
246+
self.addr().simd_le(other.addr()).cast::<*const T>()
233247
}
234248

235249
#[inline]
236250
fn simd_gt(self, other: Self) -> Self::Mask {
237-
self.addr().simd_gt(other.addr())
251+
self.addr().simd_gt(other.addr()).cast::<*const T>()
238252
}
239253

240254
#[inline]
241255
fn simd_ge(self, other: Self) -> Self::Mask {
242-
self.addr().simd_ge(other.addr())
256+
self.addr().simd_ge(other.addr()).cast::<*const T>()
243257
}
244258
}
245259

@@ -274,22 +288,22 @@ where
274288
{
275289
#[inline]
276290
fn simd_lt(self, other: Self) -> Self::Mask {
277-
self.addr().simd_lt(other.addr())
291+
self.addr().simd_lt(other.addr()).cast::<*mut T>()
278292
}
279293

280294
#[inline]
281295
fn simd_le(self, other: Self) -> Self::Mask {
282-
self.addr().simd_le(other.addr())
296+
self.addr().simd_le(other.addr()).cast::<*mut T>()
283297
}
284298

285299
#[inline]
286300
fn simd_gt(self, other: Self) -> Self::Mask {
287-
self.addr().simd_gt(other.addr())
301+
self.addr().simd_gt(other.addr()).cast::<*mut T>()
288302
}
289303

290304
#[inline]
291305
fn simd_ge(self, other: Self) -> Self::Mask {
292-
self.addr().simd_ge(other.addr())
306+
self.addr().simd_ge(other.addr()).cast::<*mut T>()
293307
}
294308
}
295309

crates/core_simd/src/simd/num/float.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ macro_rules! impl_trait {
250250
where
251251
LaneCount<N>: SupportedLaneCount,
252252
{
253-
type Mask = Mask<<$mask_ty as SimdElement>::Mask, N>;
253+
type Mask = Mask<$ty, N>;
254254
type Scalar = $ty;
255255
type Bits = Simd<$bits_ty, N>;
256256
type Cast<T: SimdElement> = Simd<T, N>;
@@ -345,7 +345,7 @@ macro_rules! impl_trait {
345345
#[inline]
346346
fn is_sign_negative(self) -> Self::Mask {
347347
let sign_bits = self.to_bits() & Simd::splat((!0 >> 1) + 1);
348-
sign_bits.simd_gt(Simd::splat(0))
348+
sign_bits.simd_gt(Simd::splat(0)).cast::<$ty>()
349349
}
350350

351351
#[inline]
@@ -367,8 +367,11 @@ macro_rules! impl_trait {
367367
fn is_subnormal(self) -> Self::Mask {
368368
// On some architectures (e.g. armv7 and some ppc) subnormals are flushed to zero,
369369
// so this comparison must be done with integers.
370-
let not_zero = self.abs().to_bits().simd_ne(Self::splat(0.0).to_bits());
371-
not_zero & (self.to_bits() & Self::splat(Self::Scalar::INFINITY).to_bits()).simd_eq(Simd::splat(0))
370+
let not_zero = self.abs().to_bits().simd_ne(Self::splat(0.0).to_bits()).cast::<$ty>();
371+
let exp_zero = (self.to_bits() & Self::splat(Self::Scalar::INFINITY).to_bits())
372+
.simd_eq(Simd::splat(0))
373+
.cast::<$ty>();
374+
not_zero & exp_zero
372375
}
373376

374377
#[inline]

0 commit comments

Comments
 (0)