Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move self_greater to the is_greater callable #2055

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions include/eve/arch/cpu/wide.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <eve/module/core/regular/logical_not.hpp>
#include <eve/module/core/regular/shl.hpp>
#include <eve/module/core/regular/shr.hpp>
#include <eve/module/core/regular/is_greater.hpp>
#include <eve/memory/soa_ptr.hpp>
#include <eve/traits/product_type.hpp>

Expand Down Expand Up @@ -890,32 +891,32 @@ namespace eve
}

//! @brief Element-wise greater-than comparison between eve::wide
friend EVE_FORCEINLINE auto operator>(wide v, wide w) noexcept
friend EVE_FORCEINLINE auto operator>(wide a, wide b) noexcept
#if !defined(EVE_DOXYGEN_INVOKED)
requires(supports_ordering_v<Type>)
#endif
{
return detail::self_greater(v, w);
return is_greater(a, b);
}

//! @brief Element-wise greater-than comparison between an eve::wide and a scalar
template<scalar_value S>
friend EVE_FORCEINLINE auto operator>(wide v, S w) noexcept
friend EVE_FORCEINLINE auto operator>(wide w, S s) noexcept
#if !defined(EVE_DOXYGEN_INVOKED)
requires(supports_ordering_v<Type>)
#endif
{
return v > wide {w};
return is_greater(w, s);
}

//! @brief Element-wise greater-than comparison between a scalar and an eve::wide
template<scalar_value S>
friend EVE_FORCEINLINE auto operator>(S v, wide w) noexcept
friend EVE_FORCEINLINE auto operator>(S s, wide w) noexcept
#if !defined(EVE_DOXYGEN_INVOKED)
requires(supports_ordering_v<Type>)
#endif
{
return wide {v} > w;
return is_greater(s, w);
}

//! @brief Element-wise greater-or-equal comparison between eve::wide
Expand Down
35 changes: 0 additions & 35 deletions include/eve/detail/function/simd/arm/neon/friends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,41 +62,6 @@ namespace eve::detail
return !(v == w);
}

// Arm NEON backend of the element-wise `v > w` comparison for wides using an arm_abi.
// The category of wide<T, N> is computed at compile time and selects the matching
// vcgt (64-bit register) or vcgtq (128-bit register) intrinsic; the result is returned
// as logical<wide<T, N>>.
template<typename T, typename N>
EVE_FORCEINLINE logical<wide<T, N>> self_greater( wide<T, N> v
, wide<T, N> w
) noexcept
requires arm_abi<abi_t<T, N>>
{
constexpr auto cat = categorize<wide<T, N>>();

// 8/16/32-bit integers and 32-bit floats, 128-bit registers first, then 64-bit registers.
if constexpr( cat == category::int32x4 ) return vcgtq_s32(v, w);
else if constexpr( cat == category::int16x8 ) return vcgtq_s16(v, w);
else if constexpr( cat == category::int8x16 ) return vcgtq_s8(v, w);
else if constexpr( cat == category::uint32x4 ) return vcgtq_u32(v, w);
else if constexpr( cat == category::uint16x8 ) return vcgtq_u16(v, w);
else if constexpr( cat == category::uint8x16 ) return vcgtq_u8(v, w);
else if constexpr( cat == category::float32x4) return vcgtq_f32(v, w);
else if constexpr( cat == category::int32x2 ) return vcgt_s32(v, w);
else if constexpr( cat == category::int16x4 ) return vcgt_s16(v, w);
else if constexpr( cat == category::int8x8 ) return vcgt_s8(v, w);
else if constexpr( cat == category::uint32x2 ) return vcgt_u32(v, w);
else if constexpr( cat == category::uint16x4 ) return vcgt_u16(v, w);
else if constexpr( cat == category::uint8x8 ) return vcgt_u8(v, w);
else if constexpr( cat == category::float32x2) return vcgt_f32(v, w);
// 64-bit element intrinsics are only used when the current API is at least asimd.
else if constexpr( current_api >= asimd)
{
if constexpr( cat == category::float64x1) return vcgt_f64(v, w);
else if constexpr( cat == category::int64x1) return vcgt_s64(v, w);
else if constexpr( cat == category::uint64x1) return vcgt_u64(v, w);
else if constexpr( cat == category::float64x2) return vcgtq_f64(v, w);
else if constexpr( cat == category::int64x2) return vcgtq_s64(v, w);
else if constexpr( cat == category::uint64x2) return vcgtq_u64(v, w);
}
// Below asimd, 8-byte elements go through map with a lambda computing `e > f`
// wrapped in as_logical_t.
else if constexpr( sizeof(T) == 8 )
return map([]<typename E>(E const& e, E const& f){ return as_logical_t<E>(e > f); }, v, w);
}

template<typename T, typename N>
EVE_FORCEINLINE logical<wide<T, N>> self_geq(wide<T, N> v, wide<T, N> w) noexcept
requires arm_abi<abi_t<T, N>>
Expand Down
5 changes: 0 additions & 5 deletions include/eve/detail/function/simd/arm/sve/friends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ EVE_FORCEINLINE auto
self_neq(wide<T, N> v, wide<T, N> w) noexcept -> as_logical_t<wide<T, N>>
requires sve_abi<abi_t<T, N>> { return svcmpne(sve_true<T>(), v, w); }

// SVE backend of the element-wise `v > w` comparison.
// Calls svcmpgt with the predicate obtained from sve_true<T>() and returns the outcome
// as as_logical_t<wide<T, N>>.
template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE auto
self_greater(wide<T, N> v, wide<T, N> w) noexcept -> as_logical_t<wide<T, N>>
requires sve_abi<abi_t<T, N>>
{
  auto governing = sve_true<T>();
  return svcmpgt(governing, v, w);
}

template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE auto
self_leq(wide<T, N> v, wide<T, N> w) noexcept -> as_logical_t<wide<T, N>>
Expand Down
14 changes: 0 additions & 14 deletions include/eve/detail/function/simd/common/friends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,6 @@ namespace eve::detail
}
}

// Generic fallback of the element-wise `v > w` comparison for any simd_value.
//  - Product types: both operands are turned into kumi tuples, compared with `>`, and the
//    outcome is converted to the element type of as_logical_t<Wide>.
//  - Any other wide: a lambda computing `lhs > rhs` (wrapped in as_logical_t) is handed to
//    apply_over together with both operands.
template<simd_value Wide>
EVE_FORCEINLINE auto self_greater(Wide const& v,Wide const& w) noexcept
{
  if constexpr( !product_type<Wide> )
  {
    constexpr auto lane_gt = []<typename E>(E const& lhs, E const& rhs)
    {
      return as_logical_t<E>(lhs > rhs);
    };

    return apply_over(lane_gt, v, w);
  }
  else
  {
    auto const target = as_element<as_logical_t<Wide>>();
    return convert(kumi::to_tuple(v) > kumi::to_tuple(w), target);
  }
}

template<simd_value Wide>
EVE_FORCEINLINE auto self_geq(Wide const& v,Wide const& w) noexcept
{
Expand Down
7 changes: 0 additions & 7 deletions include/eve/detail/function/simd/ppc/friends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ namespace eve::detail
return logical<wide<T,N>>(vec_cmpne(v.storage(), w.storage()));
}

// PPC backend of the element-wise `v > w` comparison.
// Applies vec_cmpgt to the underlying storage of both operands and wraps the outcome
// in logical<wide<T, N>>.
template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE auto self_greater(wide<T, N> const &v, wide<T, N> const &w) noexcept
requires ppc_abi<abi_t<T, N>>
{
  using result_t = logical<wide<T, N>>;

  return result_t(vec_cmpgt(v.storage(), w.storage()));
}

template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE auto self_geq(wide<T, N> const &v, wide<T, N> const &w) noexcept
requires ppc_abi<abi_t<T, N>>
Expand Down
31 changes: 0 additions & 31 deletions include/eve/detail/function/simd/riscv/friends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,6 @@ namespace eve::detail

// *_impl in separate functions, as otherwise compiler can not
// choose overload between riscv-specific and common one
// RVV kernel shared by the self_greater overloads: element-wise `lhs > rhs`, where rhs is
// either a wide<T, N> or a scalar.
//  - A scalar whose type is not T is cast to T and sent back through self_greater.
//  - Otherwise the comparison intrinsic is picked from the category of wide<T, N>
//    (signed integer, unsigned integer or floating point) and called with N::value as its
//    last argument.
template<plain_scalar_value T, typename N, value U>
EVE_FORCEINLINE auto
self_greater_impl(wide<T, N> lhs, U rhs) noexcept -> logical<wide<T, N>>
requires rvv_abi<abi_t<T, N>> && (std::same_as<wide<T, N>, U> || scalar_value<U>)
{
  constexpr bool foreign_scalar = scalar_value<U> && !std::same_as<T, U>;

  if constexpr( !foreign_scalar )
  {
    constexpr auto kind = categorize<wide<T, N>>();

    if constexpr( match(kind, category::int_) )
    {
      return __riscv_vmsgt(lhs, rhs, N::value);
    }
    else if constexpr( match(kind, category::uint_) )
    {
      return __riscv_vmsgtu(lhs, rhs, N::value);
    }
    else if constexpr( match(kind, category::float_) )
    {
      return __riscv_vmfgt(lhs, rhs, N::value);
    }
  }
  else
  {
    return self_greater(lhs, static_cast<T>(rhs));
  }
}

// RVV element-wise `lhs > rhs` between two wides of the same type.
// Thin entry point: all the work is done by self_greater_impl.
template<plain_scalar_value T, typename N>
EVE_FORCEINLINE auto
self_greater(wide<T, N> lhs, wide<T, N> rhs) noexcept -> logical<wide<T, N>>
requires rvv_abi<abi_t<T, N>>
{
return self_greater_impl(lhs, rhs);
}

// RVV element-wise `lhs > rhs` between a wide and any value convertible to its element type T.
// Thin entry point: rhs is passed unchanged to self_greater_impl.
template<plain_scalar_value T, typename N>
EVE_FORCEINLINE auto
self_greater(wide<T, N> lhs, std::convertible_to<T> auto rhs) noexcept -> logical<wide<T, N>>
requires rvv_abi<abi_t<T, N>>
{
return self_greater_impl(lhs, rhs);
}

template<plain_scalar_value T, typename N, value U>
EVE_FORCEINLINE auto
Expand Down
83 changes: 0 additions & 83 deletions include/eve/detail/function/simd/x86/friends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,89 +179,6 @@ self_neq(logical<wide<T, N>> v, logical<wide<T, N>> w) noexcept requires x86_abi
else { return bit_cast(v.bits() ^ w.bits(), as(v)); }
}

//================================================================================================
// x86 backend of the element-wise `v > w` comparison for wides using an x86_abi.
// The category of wide<T, N> and the current API level select, at compile time, which
// intrinsic (or fallback) produces the as_logical_t<wide<T, N>> result.
template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE as_logical_t<wide<T, N>>
self_greater(wide<T, N> v, wide<T, N> w) noexcept requires x86_abi<abi_t<T, N>>
{
constexpr auto c = categorize<wide<T, N>>();
// Integer form of the gt_oq predicate, passed to the generic *_cmp_ps / *_cmp_pd intrinsics.
// NOTE(review): presumably matches Intel's _CMP_GT_OQ (ordered, quiet) - confirm in cmp_flt.
constexpr auto f = to_integer(cmp_flt::gt_oq);

if constexpr( current_api >= avx512 )
{
// AVX-512: every category maps to a *_mask comparison whose result is wrapped in the
// mask type (mask8 / mask16 / mask32 / mask64) used for that category.
if constexpr( c == category::float32x16 ) return mask16 {_mm512_cmp_ps_mask(v, w, f)};
else if constexpr( c == category::float32x8 ) return mask8 {_mm256_cmp_ps_mask(v, w, f)};
else if constexpr( c == category::float32x4 ) return mask8 {_mm_cmp_ps_mask(v, w, f)};
else if constexpr( c == category::float64x8 ) return mask8 {_mm512_cmp_pd_mask(v, w, f)};
else if constexpr( c == category::float64x4 ) return mask8 {_mm256_cmp_pd_mask(v, w, f)};
else if constexpr( c == category::float64x2 ) return mask8 {_mm_cmp_pd_mask(v, w, f)};
else if constexpr( c == category::uint64x8 ) return mask8 {_mm512_cmpgt_epu64_mask(v, w)};
else if constexpr( c == category::uint64x4 ) return mask8 {_mm256_cmpgt_epu64_mask(v, w)};
else if constexpr( c == category::uint64x2 ) return mask8 {_mm_cmpgt_epu64_mask(v, w)};
else if constexpr( c == category::uint32x16 ) return mask16 {_mm512_cmpgt_epu32_mask(v, w)};
else if constexpr( c == category::uint32x8 ) return mask8 {_mm256_cmpgt_epu32_mask(v, w)};
else if constexpr( c == category::uint32x4 ) return mask8 {_mm_cmpgt_epu32_mask(v, w)};
else if constexpr( c == category::uint16x32 ) return mask32 {_mm512_cmpgt_epu16_mask(v, w)};
else if constexpr( c == category::uint16x16 ) return mask16 {_mm256_cmpgt_epu16_mask(v, w)};
else if constexpr( c == category::uint16x8 ) return mask8 {_mm_cmpgt_epu16_mask(v, w)};
else if constexpr( c == category::uint8x64 ) return mask64 {_mm512_cmpgt_epu8_mask(v, w)};
else if constexpr( c == category::uint8x32 ) return mask32 {_mm256_cmpgt_epu8_mask(v, w)};
else if constexpr( c == category::uint8x16 ) return mask16 {_mm_cmpgt_epu8_mask(v, w)};
else if constexpr( c == category::int64x8 ) return mask8 {_mm512_cmpgt_epi64_mask(v, w)};
else if constexpr( c == category::int64x4 ) return mask8 {_mm256_cmpgt_epi64_mask(v, w)};
else if constexpr( c == category::int64x2 ) return mask8 {_mm_cmpgt_epi64_mask(v, w)};
else if constexpr( c == category::int32x16 ) return mask16 {_mm512_cmpgt_epi32_mask(v, w)};
else if constexpr( c == category::int32x8 ) return mask8 {_mm256_cmpgt_epi32_mask(v, w)};
else if constexpr( c == category::int32x4 ) return mask8 {_mm_cmpgt_epi32_mask(v, w)};
else if constexpr( c == category::int16x32 ) return mask32 {_mm512_cmpgt_epi16_mask(v, w)};
else if constexpr( c == category::int16x16 ) return mask16 {_mm256_cmpgt_epi16_mask(v, w)};
else if constexpr( c == category::int16x8 ) return mask8 {_mm_cmpgt_epi16_mask(v, w)};
else if constexpr( c == category::int8x64 ) return mask64 {_mm512_cmpgt_epi8_mask(v, w)};
else if constexpr( c == category::int8x32 ) return mask32 {_mm256_cmpgt_epi8_mask(v, w)};
else if constexpr( c == category::int8x16 ) return mask16 {_mm_cmpgt_epi8_mask(v, w)};
}
else
{
// Below AVX-512: floating point categories have a direct comparison intrinsic.
if constexpr( c == category::float32x8 ) return _mm256_cmp_ps(v, w, f);
else if constexpr( c == category::float64x4 ) return _mm256_cmp_pd(v, w, f);
else if constexpr( c == category::float32x4 ) return _mm_cmpgt_ps(v, w);
else if constexpr( c == category::float64x2 ) return _mm_cmpgt_pd(v, w);
else
{
constexpr auto use_avx2 = current_api >= avx2;
constexpr auto use_sse4 = current_api >= sse4_2;

// Lambda computing `ev > fv` wrapped in as_logical_t, used by the map / aggregate fallbacks.
constexpr auto gt = []<typename E>(E ev, E fv) { return as_logical_t<E>(ev > fv); };

// Only signed integer comparisons are used in this branch: both unsigned operands are
// reinterpreted as signed, the sign mask is subtracted from each, and the signed `>`
// of the shifted values is bit_cast to the logical type.
[[maybe_unused]] auto unsigned_cmp = [](auto vv, auto vw)
{
using l_t = logical<wide<T, N>>;
auto const sm = signmask(as<as_integer_t<wide<T, N>, signed>>());
return bit_cast((bit_cast(vv, as(sm)) - sm) > (bit_cast(vw, as(sm)) - sm), as<l_t> {});
};

// Branch order matters: the 256-bit categories are only handled here under AVX2, and the
// SSE4.2 int64x2 intrinsic must be tested before the plain int64x2 map fallback.
if constexpr( use_avx2 && c == category::int64x4 ) return _mm256_cmpgt_epi64(v, w);
else if constexpr( use_avx2 && c == category::uint64x4 ) return unsigned_cmp(v, w);
else if constexpr( use_avx2 && c == category::int32x8 ) return _mm256_cmpgt_epi32(v, w);
else if constexpr( use_avx2 && c == category::uint32x8 ) return unsigned_cmp(v, w);
else if constexpr( use_avx2 && c == category::int16x16 ) return _mm256_cmpgt_epi16(v, w);
else if constexpr( use_avx2 && c == category::uint16x16 ) return unsigned_cmp(v, w);
else if constexpr( use_avx2 && c == category::int8x32 ) return _mm256_cmpgt_epi8(v, w);
else if constexpr( use_avx2 && c == category::uint8x32 ) return unsigned_cmp(v, w);
else if constexpr( use_sse4 && c == category::int64x2 ) return _mm_cmpgt_epi64(v, w);
else if constexpr( c == category::int64x2 ) return map(gt, v, w);
else if constexpr( c == category::int32x4 ) return _mm_cmpgt_epi32(v, w);
else if constexpr( c == category::int16x8 ) return _mm_cmpgt_epi16(v, w);
else if constexpr( c == category::int8x16 ) return _mm_cmpgt_epi8(v, w);
else if constexpr( c == category::uint64x2 ) return unsigned_cmp(v, w);
else if constexpr( c == category::uint32x4 ) return unsigned_cmp(v, w);
else if constexpr( c == category::uint16x8 ) return unsigned_cmp(v, w);
else if constexpr( c == category::uint8x16 ) return unsigned_cmp(v, w);
// Every remaining category (e.g. 256-bit integers without AVX2) goes through aggregate.
// NOTE(review): presumably aggregate applies gt on sub-parts of the wide - confirm.
else return aggregate(gt, v, w);
}
}
}

//================================================================================================
template<arithmetic_scalar_value T, typename N>
EVE_FORCEINLINE as_logical_t<wide<T, N>>
Expand Down
4 changes: 4 additions & 0 deletions include/eve/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ struct ignore_first;
struct ignore_last;
struct keep_between;
struct ignore_extrema;

// Forward declarations of the is_greater / is_less callable class templates (O is their
// options type), presumably so that other headers can name these types without including
// the headers that define them.
template<typename O> struct is_greater_t;
template<typename O> struct is_less_t;

}

namespace eve::detail
Expand Down
35 changes: 35 additions & 0 deletions include/eve/module/core/regular/impl/is_greater.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

#include <eve/module/core/regular/abs.hpp>
#include <eve/module/core/regular/if_else.hpp>
#include <eve/module/core/regular/fam.hpp>
#include <eve/module/core/regular/next.hpp>
#include <eve/module/core/regular/max.hpp>
#include <eve/module/core/regular/is_less.hpp>
#include <eve/traits/as_logical.hpp>
#include <eve/module/core/detail/tolerance.hpp>

namespace eve::detail
{
  //! Generic (cpu_) implementation behind eve::is_greater; returns as_logical_t<T>.
  //!
  //! @param o  Options of the call, queried for `definitely`.
  //! @param a  Left-hand operand.
  //! @param b  Right-hand operand.
  //!
  //!  - With the `definitely` option, a tolerance `tol` is read from the option
  //!    (T{} is handed to value() to obtain it) and the test becomes:
  //!      * integral tolerance : a > eve::next(b, tol)
  //!        (presumably b advanced by tol representable steps - confirm in next)
  //!      * any other tolerance: a > fam(b, tol, max(|a|, |b|))
  //!        (presumably b + tol * max(|a|, |b|) - confirm in fam)
  //!  - Without it, the comparison is delegated to is_less with the operands swapped.
  template<callable_options O, value T>
  EVE_FORCEINLINE constexpr as_logical_t<T> is_greater_(EVE_REQUIRES(cpu_), O const & o, T a, T b) noexcept
  {
    if constexpr(O::contains(definitely))
    {
      auto tol = o[definitely].value(T{});
      if constexpr(integral_value<decltype(tol)>) return a > eve::next(b, tol);
      else return a > fam(b, tol, eve::max(eve::abs(a), eve::abs(b)));
    }
    else
    {
      // a > b is evaluated as b < a.
      return is_less(b, a);
    }
  }
}
5 changes: 5 additions & 0 deletions include/eve/module/core/regular/impl/is_less.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
//==================================================================================================
#pragma once

#include <eve/module/core/regular/if_else.hpp>
#include <eve/module/core/regular/fam.hpp>
#include <eve/module/core/regular/prev.hpp>
#include <eve/module/core/regular/max.hpp>

namespace eve::detail
{
template<callable_options O, typename T>
Expand Down
5 changes: 2 additions & 3 deletions include/eve/module/core/regular/impl/max.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <eve/module/core/regular/bit_and.hpp>
#include <eve/module/core/regular/if_else.hpp>
#include <eve/module/core/regular/is_eqz.hpp>
#include <eve/module/core/regular/is_greater.hpp>
#include <eve/module/core/regular/is_less.hpp>
#include <eve/module/core/regular/is_nan.hpp>
#include <eve/module/core/regular/is_ordered.hpp>
Expand Down Expand Up @@ -98,8 +97,8 @@ namespace eve::detail
EVE_FORCEINLINE constexpr auto
max_(EVE_REQUIRES(cpu_), O const &, Callable const & f) noexcept
{
if constexpr( std::same_as<Callable, eve::callable_is_less_> ) return eve::max;
else if constexpr( std::same_as<Callable, eve::callable_is_greater_> ) return eve::min;
if constexpr( std::same_as<Callable, eve::is_less_t<eve::options<>>> ) return eve::max;
else if constexpr( std::same_as<Callable, eve::is_greater_t<eve::options<>>> ) return eve::min;
else
{
return [f](auto x, auto y){ return eve::if_else(f(y, x), x, y); };
Expand Down
7 changes: 4 additions & 3 deletions include/eve/module/core/regular/impl/min.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <eve/module/core/regular/if_else.hpp>
#include <eve/module/core/regular/is_eqz.hpp>
#include <eve/module/core/regular/is_nan.hpp>
#include <eve/module/core/regular/is_greater.hpp>
#include <eve/module/core/regular/is_less.hpp>
#include <eve/module/core/regular/max.hpp>
#include <eve/module/core/regular/is_ordered.hpp>
Expand Down Expand Up @@ -102,9 +101,11 @@ namespace eve::detail
EVE_FORCEINLINE constexpr auto
min_(EVE_REQUIRES(cpu_), O const &, Callable const& f) noexcept
{
if constexpr( std::same_as<Callable, callable_is_less_> ) return eve::min;
else if constexpr( std::same_as<Callable, callable_is_greater_> ) return eve::max;
if constexpr( std::same_as<Callable, eve::is_less_t<eve::options<>>> ) return eve::min;
else if constexpr( std::same_as<Callable, eve::is_greater_t<eve::options<>>> ) return eve::max;
else
{
return [f](auto x, auto y){ return eve::if_else(f(y, x), y, x); };
}
}
}
Loading