(v2 of) Add BitsFromMask, promoting from detail::.
Provide DFromM on all targets except SVE/RVV.
Also split mask_test into mask_set_test,
remove an unused overload in scalar, and
modernize overloads (SFINAE instead of type tags).
arm_sve required moving some sections before their first usage.

PiperOrigin-RevId: 701919058
jan-wassenberg authored and copybara-github committed Dec 2, 2024
1 parent b13a46f commit 62c0a79
Showing 16 changed files with 1,122 additions and 980 deletions.
1 change: 1 addition & 0 deletions BUILD
@@ -514,6 +514,7 @@ HWY_TESTS = [
("hwy/tests/", "mask_combine_test"),
("hwy/tests/", "mask_convert_test"),
("hwy/tests/", "mask_mem_test"),
("hwy/tests/", "mask_set_test"),
("hwy/tests/", "mask_slide_test"),
("hwy/tests/", "mask_test"),
("hwy/tests/", "masked_arithmetic_test"),
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -747,6 +747,7 @@ set(HWY_TEST_FILES
  hwy/tests/mask_combine_test.cc
  hwy/tests/mask_convert_test.cc
  hwy/tests/mask_mem_test.cc
+  hwy/tests/mask_set_test.cc
  hwy/tests/mask_slide_test.cc
  hwy/tests/mask_test.cc
  hwy/tests/masked_arithmetic_test.cc
39 changes: 22 additions & 17 deletions g3doc/quick_reference.md
@@ -1153,6 +1153,12 @@ encoding depends on the platform).
* <code>V **VecFromMask**(D, M m)</code>: returns 0 in lane `i` if `m[i] ==
false`, otherwise all bits set.

* <code>uint64_t **BitsFromMask**(D, M m)</code>: returns bits `b` such that
`(b >> i) & 1` indicates whether `m[i]` was set, and any remaining bits in
the `uint64_t` are zero. This is only available if `HWY_MAX_BYTES <= 64`,
because 512-bit vectors are the longest for which there are no more than 64
lanes and thus mask bits.

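To illustrate the new op, a minimal usage sketch (`FirstAbove` is a hypothetical helper, the usual foreach_target boilerplate is omitted, and the `HWY_MAX_BYTES <= 64` restriction above applies):

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Index of the first lane of `v` exceeding `threshold`, or -1 if none.
template <class D>
intptr_t FirstAbove(D d, hn::Vec<D> v, hn::TFromD<D> threshold) {
  const uint64_t bits = hn::BitsFromMask(d, hn::Gt(v, hn::Set(d, threshold)));
  // Bit i corresponds to lane i; bits above the lane count are zero.
  return bits ? static_cast<intptr_t>(hwy::Num0BitsBelowLS1Bit_Nonzero64(bits))
              : -1;
}
```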
* <code>size_t **StoreMaskBits**(D, M m, uint8_t* p)</code>: stores a bit
array indicating whether `m[i]` is true, in ascending order of `i`, filling
the bits of each byte from least to most significant, then proceeding to the
@@ -1163,11 +1169,11 @@ encoding depends on the platform).
Mask&lt;DFrom&gt; m)</code>: Promotes `m` to a mask with a lane type of
`TFromD<DTo>`, `DFrom` is `Rebind<TFrom, DTo>`.

-`PromoteMaskTo(d_to, d_from, m)` is equivalent to
-`MaskFromVec(BitCast(d_to, PromoteTo(di_to, BitCast(di_from,
-VecFromMask(d_from, m)))))`, where `di_from` is `RebindToSigned<DFrom>()`
-and `di_from` is `RebindToSigned<DFrom>()`, but
-`PromoteMaskTo(d_to, d_from, m)` is more efficient on some targets.
+`PromoteMaskTo(d_to, d_from, m)` is equivalent to `MaskFromVec(BitCast(d_to,
+PromoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m)))))`, where `di_to`
+is `RebindToSigned<DTo>()` and `di_from` is `RebindToSigned<DFrom>()`, but
+`PromoteMaskTo(d_to, d_from, m)` is more efficient on some targets.

PromoteMaskTo requires that `sizeof(TFromD<DFrom>) < sizeof(TFromD<DTo>)` be
true.
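As an illustration of the equivalence above, a sketch (hypothetical `PositiveAsWideMask`, boilerplate omitted):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Computes (v > 0) on narrow lanes, then widens the mask for use with d_to.
// Per the signature above, DFrom is Rebind<TFromD<DFrom>, DTo>.
template <class DTo, class DFrom>
hn::Mask<DTo> PositiveAsWideMask(DTo d_to, DFrom d_from, hn::Vec<DFrom> v) {
  const hn::Mask<DFrom> m = hn::Gt(v, hn::Zero(d_from));
  return hn::PromoteMaskTo(d_to, d_from, m);
  // Equivalent, per the text, to: MaskFromVec(BitCast(d_to, PromoteTo(di_to,
  // BitCast(di_from, VecFromMask(d_from, m))))) with di_to/di_from the signed
  // rebinds, but typically cheaper.
}
```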
@@ -1176,11 +1182,11 @@ encoding depends on the platform).
Mask&lt;DFrom&gt; m)</code>: Demotes `m` to a mask with a lane type of
`TFromD<DTo>`, `DFrom` is `Rebind<TFrom, DTo>`.

-`DemoteMaskTo(d_to, d_from, m)` is equivalent to
-`MaskFromVec(BitCast(d_to, DemoteTo(di_to, BitCast(di_from,
-VecFromMask(d_from, m)))))`, where `di_from` is `RebindToSigned<DFrom>()`
-and `di_from` is `RebindToSigned<DFrom>()`, but
-`DemoteMaskTo(d_to, d_from, m)` is more efficient on some targets.
+`DemoteMaskTo(d_to, d_from, m)` is equivalent to `MaskFromVec(BitCast(d_to,
+DemoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m)))))`, where `di_to`
+is `RebindToSigned<DTo>()` and `di_from` is `RebindToSigned<DFrom>()`, but
+`DemoteMaskTo(d_to, d_from, m)` is more efficient on some targets.

DemoteMaskTo requires that `sizeof(TFromD<DFrom>) > sizeof(TFromD<DTo>)` be
true.
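Conversely, a sketch for the demoting direction (hypothetical `KeepWhereWidePositive`, boilerplate omitted):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Applies a comparison computed on wide lanes (e.g. i32) to narrow data
// (e.g. i16): returns n[i] where w[i] > 0, else 0. d_to and d_from have the
// same lane count (Rebind), so d_to is a partial vector.
template <class DTo, class DFrom>
hn::Vec<DTo> KeepWhereWidePositive(DTo d_to, DFrom d_from, hn::Vec<DFrom> w,
                                   hn::Vec<DTo> n) {
  const hn::Mask<DFrom> m_wide = hn::Gt(w, hn::Zero(d_from));
  return hn::IfThenElseZero(hn::DemoteMaskTo(d_to, d_from, m_wide), n);
}
```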
@@ -1189,16 +1195,15 @@ encoding depends on the platform).
whose `LowerHalf` is the first argument and whose `UpperHalf` is the second
argument; `M2` is `Mask<Half<DFrom>>`; `DTo` is `Repartition<TTo, DFrom>`.

-OrderedDemote2MasksTo requires that
-`sizeof(TFromD<DTo>) == sizeof(TFromD<DFrom>) * 2` be true.
+OrderedDemote2MasksTo requires that `sizeof(TFromD<DTo>) ==
+sizeof(TFromD<DFrom>) * 2` be true.

`OrderedDemote2MasksTo(d_to, d_from, a, b)` is equivalent to
`MaskFromVec(BitCast(d_to, OrderedDemote2To(di_to, va, vb)))`, where `va` is
-`BitCast(di_from, MaskFromVec(d_from, a))`, `vb` is
-`BitCast(di_from, MaskFromVec(d_from, b))`, `di_to` is
-`RebindToSigned<DTo>()`, and `di_from` is `RebindToSigned<DFrom>()`, but
-`OrderedDemote2MasksTo(d_to, d_from, a, b)` is more efficient on some
-targets.
+`BitCast(di_from, VecFromMask(d_from, a))`, `vb` is `BitCast(di_from,
+VecFromMask(d_from, b))`, `di_to` is `RebindToSigned<DTo>()`, and `di_from`
+is `RebindToSigned<DFrom>()`, but `OrderedDemote2MasksTo(d_to, d_from, a,
+b)` is more efficient on some targets.

OrderedDemote2MasksTo is only available if `HWY_TARGET != HWY_SCALAR` is
true.
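And for the two-to-one case, a sketch (hypothetical `ConcatDemotedMasks`; per the equivalence formula, the two arguments are assumed to be masks of `d_from`):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Packs the masks of two wide (e.g. i32) vectors into one narrow (e.g. i16)
// mask: lower half from `a`, upper half from `b`. Non-scalar targets only.
template <class DFrom, class DTo = hn::Repartition<int16_t, DFrom>>
hn::Mask<DTo> ConcatDemotedMasks(DFrom d_from, hn::Mask<DFrom> a,
                                 hn::Mask<DFrom> b) {
  return hn::OrderedDemote2MasksTo(DTo(), d_from, a, b);
}
```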
129 changes: 64 additions & 65 deletions hwy/ops/arm_neon-inl.h
@@ -21,6 +21,7 @@
// Arm NEON intrinsics are documented at:
// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]

#include "hwy/base.h"
#include "hwy/ops/shared-inl.h"

HWY_DIAGNOSTICS(push)
@@ -8921,14 +8922,22 @@ HWY_INLINE uint64_t NibblesFromMask(D d, MFromD<D> mask) {
  return nib & ((1ull << (d.MaxBytes() * 4)) - 1);
}

-template <typename T>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, Mask128<T> mask) {
+// Returns only the lowest N (= lane count) bits of the BitsFromMask result.
+template <class D>
+constexpr uint64_t OnlyActive(D d, uint64_t bits) {
+  return (d.MaxBytes() >= 8) ? bits : (bits & ((1ull << d.MaxLanes()) - 1));
+}
+
+}  // namespace detail
+
+template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  alignas(16) static constexpr uint8_t kSliceLanes[16] = {
      1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80,
  };
-  const Full128<uint8_t> du;
+  const RebindToUnsigned<D> du;
  const Vec128<uint8_t> values =
-      BitCast(du, VecFromMask(Full128<T>(), mask)) & Load(du, kSliceLanes);
+      BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes);

#if HWY_ARCH_ARM_A64
  // Can't vaddv - we need two separate bytes (16 bits).
@@ -8945,126 +8954,114 @@ HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, Mask128<T> mask) {
#endif
}

-template <typename T, size_t N, HWY_IF_V_SIZE_LE(T, N, 8)>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, Mask128<T, N> mask) {
+template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  // Upper lanes of partial loads are undefined. OnlyActive will fix this if
  // we load all kSliceLanes so the upper lanes do not pollute the valid bits.
  alignas(8) static constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8,
                                                        0x10, 0x20, 0x40, 0x80};
-  const DFromM<decltype(mask)> d;
  const RebindToUnsigned<decltype(d)> du;
-  const Vec128<uint8_t, N> slice(Load(Full64<uint8_t>(), kSliceLanes).raw);
-  const Vec128<uint8_t, N> values = BitCast(du, VecFromMask(d, mask)) & slice;
+  using VU = VFromD<decltype(du)>;
+  const VU slice(Load(Full64<uint8_t>(), kSliceLanes).raw);
+  const VU values = BitCast(du, VecFromMask(d, mask)) & slice;

#if HWY_ARCH_ARM_A64
-  return vaddv_u8(values.raw);
+  return detail::OnlyActive(d, vaddv_u8(values.raw));
#else
  const uint16x4_t x2 = vpaddl_u8(values.raw);
  const uint32x2_t x4 = vpaddl_u16(x2);
  const uint64x1_t x8 = vpaddl_u32(x4);
-  return vget_lane_u64(x8, 0);
+  return detail::OnlyActive(d, vget_lane_u64(x8, 0));
#endif
}
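Aside: the kSliceLanes pattern above computes the mask bits by AND-ing each all-ones lane with `1 << i` and then summing lanes horizontally (`vaddv`/`vpaddl`). A scalar model of what it computes, including the OnlyActive cleanup (illustrative only; the 16-lane u8 path, which combines two 8-bit halves, is glossed over):

```c++
#include <stddef.h>
#include <stdint.h>

// Scalar model of the NEON BitsFromMask technique.
uint64_t BitsFromMaskModel(const bool* lanes, size_t num_lanes) {
  uint64_t sum = 0;
  for (size_t i = 0; i < num_lanes; ++i) {
    // An all-ones (true) lane ANDed with kSliceLanes[i] == 1 << i yields
    // exactly bit i; the horizontal add of disjoint powers of two equals OR.
    if (lanes[i]) sum += uint64_t{1} << i;
  }
  // OnlyActive: clear bits at/above the lane count, needed when undefined
  // upper lanes of a partial load contributed garbage bits.
  const uint64_t valid = (num_lanes >= 64) ? ~uint64_t{0}
                                           : (uint64_t{1} << num_lanes) - 1;
  return sum & valid;
}
```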

-template <typename T>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, Mask128<T> mask) {
+template <class D, HWY_IF_T_SIZE_D(D, 2), HWY_IF_V_SIZE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  alignas(16) static constexpr uint16_t kSliceLanes[8] = {
      1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80};
-  const Full128<T> d;
-  const Full128<uint16_t> du;
+  const RebindToUnsigned<D> du;
  const Vec128<uint16_t> values =
      BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes);
#if HWY_ARCH_ARM_A64
-  return vaddvq_u16(values.raw);
+  return detail::OnlyActive(d, vaddvq_u16(values.raw));
#else
  const uint32x4_t x2 = vpaddlq_u16(values.raw);
  const uint64x2_t x4 = vpaddlq_u32(x2);
-  return vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1);
+  return detail::OnlyActive(d, vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1));
#endif
}

-template <typename T, size_t N, HWY_IF_V_SIZE_LE(T, N, 8)>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, Mask128<T, N> mask) {
+template <class D, HWY_IF_T_SIZE_D(D, 2), HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  // Upper lanes of partial loads are undefined. OnlyActive will fix this if
  // we load all kSliceLanes so the upper lanes do not pollute the valid bits.
  alignas(8) static constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8};
-  const DFromM<decltype(mask)> d;
  const RebindToUnsigned<decltype(d)> du;
-  const Vec128<uint16_t, N> slice(Load(Full64<uint16_t>(), kSliceLanes).raw);
-  const Vec128<uint16_t, N> values = BitCast(du, VecFromMask(d, mask)) & slice;
+  using VU = VFromD<decltype(du)>;
+  const VU slice(Load(Full64<uint16_t>(), kSliceLanes).raw);
+  const VU values = BitCast(du, VecFromMask(d, mask)) & slice;
#if HWY_ARCH_ARM_A64
-  return vaddv_u16(values.raw);
+  return detail::OnlyActive(d, vaddv_u16(values.raw));
#else
  const uint32x2_t x2 = vpaddl_u16(values.raw);
  const uint64x1_t x4 = vpaddl_u32(x2);
-  return vget_lane_u64(x4, 0);
+  return detail::OnlyActive(d, vget_lane_u64(x4, 0));
#endif
}

-template <typename T>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, Mask128<T> mask) {
+template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_V_SIZE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  alignas(16) static constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8};
-  const Full128<T> d;
-  const Full128<uint32_t> du;
+  const RebindToUnsigned<D> du;
  const Vec128<uint32_t> values =
      BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes);
#if HWY_ARCH_ARM_A64
-  return vaddvq_u32(values.raw);
+  return detail::OnlyActive(d, vaddvq_u32(values.raw));
#else
  const uint64x2_t x2 = vpaddlq_u32(values.raw);
-  return vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1);
+  return detail::OnlyActive(d, vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1));
#endif
}

-template <typename T, size_t N, HWY_IF_V_SIZE_LE(T, N, 8)>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, Mask128<T, N> mask) {
+template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  // Upper lanes of partial loads are undefined. OnlyActive will fix this if
  // we load all kSliceLanes so the upper lanes do not pollute the valid bits.
  alignas(8) static constexpr uint32_t kSliceLanes[2] = {1, 2};
-  const DFromM<decltype(mask)> d;
  const RebindToUnsigned<decltype(d)> du;
-  const Vec128<uint32_t, N> slice(Load(Full64<uint32_t>(), kSliceLanes).raw);
-  const Vec128<uint32_t, N> values = BitCast(du, VecFromMask(d, mask)) & slice;
+  using VU = VFromD<decltype(du)>;
+  const VU slice(Load(Full64<uint32_t>(), kSliceLanes).raw);
+  const VU values = BitCast(du, VecFromMask(d, mask)) & slice;
#if HWY_ARCH_ARM_A64
-  return vaddv_u32(values.raw);
+  return detail::OnlyActive(d, vaddv_u32(values.raw));
#else
  const uint64x1_t x2 = vpaddl_u32(values.raw);
-  return vget_lane_u64(x2, 0);
+  return detail::OnlyActive(d, vget_lane_u64(x2, 0));
#endif
}

-template <typename T>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, Mask128<T> m) {
+template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_V_SIZE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
  alignas(16) static constexpr uint64_t kSliceLanes[2] = {1, 2};
-  const Full128<T> d;
-  const Full128<uint64_t> du;
+  const RebindToUnsigned<decltype(d)> du;
  const Vec128<uint64_t> values =
-      BitCast(du, VecFromMask(d, m)) & Load(du, kSliceLanes);
+      BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes);
#if HWY_ARCH_ARM_A64
-  return vaddvq_u64(values.raw);
+  return detail::OnlyActive(d, vaddvq_u64(values.raw));
#else
-  return vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1);
+  return detail::OnlyActive(
+      d, vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1));
#endif
}

-template <typename T>
-HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, Mask128<T, 1> m) {
-  const Full64<T> d;
-  const Full64<uint64_t> du;
-  const Vec64<uint64_t> values = BitCast(du, VecFromMask(d, m)) & Set(du, 1);
+template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+  const RebindToUnsigned<decltype(d)> du;
+  const Vec64<uint64_t> values = BitCast(du, VecFromMask(d, mask)) & Set(du, 1);
  return vget_lane_u64(values.raw, 0);
}

-// Returns the lowest N for the BitsFromMask result.
-template <typename T, size_t N>
-constexpr uint64_t OnlyActive(uint64_t bits) {
-  return ((N * sizeof(T)) >= 8) ? bits : (bits & ((1ull << N) - 1));
-}
-
-template <typename T, size_t N>
-HWY_INLINE uint64_t BitsFromMask(Mask128<T, N> mask) {
-  return OnlyActive<T, N>(BitsFromMask(hwy::SizeTag<sizeof(T)>(), mask));
-}
+namespace detail {

// Returns number of lanes whose mask is set.
//
@@ -9184,7 +9181,7 @@ HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
// `p` points to at least 8 writable bytes.
template <class D>
HWY_API size_t StoreMaskBits(D d, MFromD<D> mask, uint8_t* bits) {
-  const uint64_t mask_bits = detail::BitsFromMask(mask);
+  const uint64_t mask_bits = BitsFromMask(d, mask);
  const size_t kNumBytes = (d.MaxLanes() + 7) / 8;
  CopyBytes<kNumBytes>(&mask_bits, bits);
  return kNumBytes;
@@ -9672,7 +9669,8 @@ HWY_API Vec128<T, N> Compress(Vec128<T, N> v, Mask128<T, N> mask) {
// General case, 2 or 4 byte lanes
template <typename T, size_t N, HWY_IF_T_SIZE_ONE_OF(T, (1 << 2) | (1 << 4))>
HWY_API Vec128<T, N> Compress(Vec128<T, N> v, Mask128<T, N> mask) {
-  return detail::Compress(v, detail::BitsFromMask(mask));
+  const DFromV<decltype(v)> d;
+  return detail::Compress(v, BitsFromMask(d, mask));
}

// Single lane: no-op
@@ -9696,12 +9694,13 @@ HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) {
// General case, 2 or 4 byte lanes
template <typename T, size_t N, HWY_IF_T_SIZE_ONE_OF(T, (1 << 2) | (1 << 4))>
HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) {
+  const DFromV<decltype(v)> d;
  // For partial vectors, we cannot pull the Not() into the table because
  // BitsFromMask clears the upper bits.
  if (N < 16 / sizeof(T)) {
-    return detail::Compress(v, detail::BitsFromMask(Not(mask)));
+    return detail::Compress(v, BitsFromMask(d, Not(mask)));
  }
-  return detail::CompressNot(v, detail::BitsFromMask(mask));
+  return detail::CompressNot(v, BitsFromMask(d, mask));
}

// ------------------------------ CompressBlocksNot
@@ -9729,7 +9728,7 @@ HWY_INLINE Vec128<T, N> CompressBits(Vec128<T, N> v,
template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressStore(VFromD<D> v, MFromD<D> mask, D d,
                             TFromD<D>* HWY_RESTRICT unaligned) {
-  const uint64_t mask_bits = detail::BitsFromMask(mask);
+  const uint64_t mask_bits = BitsFromMask(d, mask);
  StoreU(detail::Compress(v, mask_bits), d, unaligned);
  return PopCount(mask_bits);
}
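For context, a typical CompressStore use is filtering an array: it packs the selected lanes to the front and returns how many were written, driven by the same BitsFromMask bits. A sketch (hypothetical `CopyIfPositive`, remainder handling omitted):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Writes the positive elements of in[0, n) to out, returning the count.
// Assumes n is a multiple of Lanes(d); out needs room for a full extra
// vector because CompressStore may write up to Lanes(d) lanes.
size_t CopyIfPositive(const int32_t* in, size_t n, int32_t* out) {
  const hn::ScalableTag<int32_t> d;
  const size_t N = hn::Lanes(d);
  size_t written = 0;
  for (size_t i = 0; i < n; i += N) {
    const auto v = hn::LoadU(d, in + i);
    written += hn::CompressStore(v, hn::Gt(v, hn::Zero(d)), d, out + written);
  }
  return written;
}
```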
@@ -9739,7 +9738,7 @@ template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  const RebindToUnsigned<decltype(d)> du;  // so we can support fp16/bf16
-  const uint64_t mask_bits = detail::BitsFromMask(m);
+  const uint64_t mask_bits = BitsFromMask(d, m);
  const size_t count = PopCount(mask_bits);
  const MFromD<D> store_mask = RebindMask(d, FirstN(du, count));
  const VFromD<decltype(du)> compressed =
(12 more changed files not shown.)
