Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add traits to help type checks #117

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ void normalize_impl(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ViewType>
auto get_coefficients(ViewType, Direction direction,
Normalization normalization, std::size_t fft_size) {
using value_type =
KokkosFFT::Impl::real_type_t<typename ViewType::non_const_value_type>;
using value_type = KokkosFFT::Impl::base_floating_point_type<
typename ViewType::non_const_value_type>;
value_type coef = 1;
[[maybe_unused]] bool to_normalize = false;

Expand Down
134 changes: 134 additions & 0 deletions common/src/KokkosFFT_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#ifndef KOKKOSFFT_TRAITS_HPP
#define KOKKOSFFT_TRAITS_HPP

#include <Kokkos_Core.hpp>

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct base_floating_point {
using value_type = T;
};

template <typename T>
struct base_floating_point<Kokkos::complex<T>> {
using value_type = T;
};

/// \brief Helper to extract the base floating point type from a complex type
template <typename T>
using base_floating_point_type = typename base_floating_point<T>::value_type;

template <typename T, typename Enable = void>
struct is_real : std::false_type {};

template <typename T>
struct is_real<
T, std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>>
: std::true_type {};

/// \brief Helper to check if a type is an acceptable real type (float/double)
/// for Kokkos-FFT
template <typename T>
inline constexpr bool is_real_v = is_real<T>::value;

template <typename T, typename Enable = void>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<
Kokkos::complex<T>,
std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>>
: std::true_type {};

/// \brief Helper to check if a type is an acceptable complex type
/// (Kokkos::complex<float>/Kokkos::complex<double>) for Kokkos-FFT
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;

// is value type admissible for KokkosFFT
template <typename T, typename Enable = void>
struct is_admissible_value_type : std::false_type {};

template <typename T>
struct is_admissible_value_type<
T, std::enable_if_t<is_real_v<T> || is_complex_v<T>>> : std::true_type {};

template <typename T>
struct is_admissible_value_type<
T, std::enable_if_t<Kokkos::is_view<T>::value &&
(is_real_v<typename T::non_const_value_type> ||
is_complex_v<typename T::non_const_value_type>)>>
: std::true_type {};

/// \brief Helper to check if a type is an acceptable value type
/// (float/double/Kokkos::complex<float>/Kokkos::complex<double>) for Kokkos-FFT
/// When applied to Kokkos::View, then check if a value type is an
/// acceptable real/complex type.
template <typename T>
inline constexpr bool is_admissible_value_type_v =
is_admissible_value_type<T>::value;

template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
Kokkos::is_view<ViewType>::value &&
(std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>)>>
: std::true_type {};

/// \brief Helper to check if a View layout is an acceptable layout type
/// (Kokkos::LayoutLeft/Kokkos::LayoutRight) for Kokkos-FFT
template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

template <typename ViewType, typename Enable = void>
struct is_admissible_view : std::false_type {};

template <typename ViewType>
struct is_admissible_view<
ViewType, std::enable_if_t<Kokkos::is_view<ViewType>::value &&
is_layout_left_or_right_v<ViewType> &&
is_admissible_value_type_v<ViewType>>>
: std::true_type {};

/// \brief Helper to check if a View is an acceptable for Kokkos-FFT. Values and
/// layout are checked
template <typename ViewType>
inline constexpr bool is_admissible_view_v =
is_admissible_view<ViewType>::value;

/// \brief Helper to define a managable View type from the original view type
template <typename T>
struct managable_view_type {
using type = Kokkos::View<typename T::data_type, typename T::array_layout,
typename T::memory_space,
Kokkos::MemoryTraits<T::memory_traits::impl_value &
~unsigned(Kokkos::Unmanaged)>>;
};

/// \brief Helper to define a complex 1D View type from a real/complex 1D View
/// type, while keeping other properties
template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
using value_type = typename ViewType::non_const_value_type;
using float_type = KokkosFFT::Impl::base_floating_point_type<value_type>;
using complex_type = Kokkos::complex<float_type>;
using array_layout_type = typename ViewType::array_layout;
using type = Kokkos::View<complex_type*, array_layout_type, ExecutionSpace>;
};

} // namespace Impl
} // namespace KokkosFFT

#endif
52 changes: 1 addition & 51 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,60 +10,10 @@
#include <set>
#include <algorithm>
#include <numeric>
#include "KokkosFFT_traits.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct real_type {
using type = T;
};

template <typename T>
struct real_type<Kokkos::complex<T>> {
using type = T;
};

template <typename T>
struct managable_view_type {
using type = Kokkos::View<typename T::data_type, typename T::array_layout,
typename T::memory_space,
Kokkos::MemoryTraits<T::memory_traits::impl_value &
~unsigned(Kokkos::Unmanaged)>>;
};

template <typename T>
using real_type_t = typename real_type<T>::type;

template <typename T>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<Kokkos::complex<T>> : std::true_type {};

template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>>>
: std::true_type {};

template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
using value_type = typename ViewType::non_const_value_type;
using float_type = KokkosFFT::Impl::real_type_t<value_type>;
using complex_type = Kokkos::complex<float_type>;
using array_layout_type = typename ViewType::array_layout;
using type = Kokkos::View<complex_type*, array_layout_type, ExecutionSpace>;
};

template <typename ViewType>
auto convert_negative_axis(ViewType, int _axis = -1) {
Expand Down
3 changes: 2 additions & 1 deletion common/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
add_executable(unit-tests-kokkos-fft-common
Test_Main.cpp
Test_Utils.cpp
Test_Traits.cpp
Test_Normalization.cpp
Test_Transpose.cpp
Test_Layouts.cpp
Expand All @@ -20,4 +21,4 @@ target_link_libraries(unit-tests-kokkos-fft-common PUBLIC common GTest::gtest)

# Enable GoogleTest
include(GoogleTest)
gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600)
gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600)
170 changes: 170 additions & 0 deletions common/unit_test/Test_Traits.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#include <gtest/gtest.h>
#include "KokkosFFT_traits.hpp"
#include "Test_Utils.hpp"

using real_types = ::testing::Types<float, double, long double>;
using view_types =
::testing::Types<std::pair<float, Kokkos::LayoutLeft>,
std::pair<float, Kokkos::LayoutRight>,
std::pair<float, Kokkos::LayoutStride>,
std::pair<double, Kokkos::LayoutLeft>,
std::pair<double, Kokkos::LayoutRight>,
std::pair<double, Kokkos::LayoutStride>,
std::pair<long double, Kokkos::LayoutLeft>,
std::pair<long double, Kokkos::LayoutRight>,
std::pair<long double, Kokkos::LayoutStride>>;

template <typename T>
struct RealAndComplexTypes : public ::testing::Test {
using real_type = T;
using complex_type = Kokkos::complex<T>;
};

template <typename T>
struct RealAndComplexViewTypes : public ::testing::Test {
using real_type = typename T::first_type;
using complex_type = Kokkos::complex<real_type>;
using layout_type = typename T::second_type;
};

TYPED_TEST_SUITE(RealAndComplexTypes, real_types);
TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types);

// Tests for real type deduction
template <typename RealType, typename ComplexType>
void test_get_real_type() {
using real_type_from_RealType =
KokkosFFT::Impl::base_floating_point_type<RealType>;
using real_type_from_ComplexType =
KokkosFFT::Impl::base_floating_point_type<ComplexType>;

static_assert(std::is_same_v<real_type_from_RealType, RealType>,
"Real type not deduced correctly from real type");
static_assert(std::is_same_v<real_type_from_ComplexType, RealType>,
"Real type not deduced correctly from complex type");
}

// Tests for admissible real types (float or double)
template <typename T>
void test_admissible_real_type() {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
static_assert(KokkosFFT::Impl::is_real_v<T>,
"Real type must be float or double");
} else {
static_assert(!KokkosFFT::Impl::is_real_v<T>,
"Real type must be float or double");
}
}

template <typename T>
void test_admissible_complex_type() {
using real_type = KokkosFFT::Impl::base_floating_point_type<T>;
if constexpr (std::is_same_v<real_type, float> ||
std::is_same_v<real_type, double>) {
static_assert(KokkosFFT::Impl::is_complex_v<T>,
"Complex type must be Kokkos::complex<float> or "
"Kokkos::complex<double>");
} else {
static_assert(!KokkosFFT::Impl::is_complex_v<T>,
"Complex type must be Kokkos::complex<float> or "
"Kokkos::complex<double>");
}
}

TYPED_TEST(RealAndComplexTypes, get_real_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;

test_get_real_type<real_type, complex_type>();
}

TYPED_TEST(RealAndComplexTypes, admissible_real_type) {
using real_type = typename TestFixture::real_type;

test_admissible_real_type<real_type>();
}

TYPED_TEST(RealAndComplexTypes, admissible_complex_type) {
using complex_type = typename TestFixture::complex_type;

test_admissible_complex_type<complex_type>();
}

// Tests for admissible view types
template <typename T, typename LayoutType>
void test_admissible_value_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
using real_type = KokkosFFT::Impl::base_floating_point_type<T>;
if constexpr (std::is_same_v<real_type, float> ||
std::is_same_v<real_type, double>) {
static_assert(KokkosFFT::Impl::is_admissible_value_type_v<ViewType>,
"Real type must be float or double");
} else {
static_assert(!KokkosFFT::Impl::is_admissible_value_type_v<ViewType>,
"Real type must be float or double");
}
}

template <typename T, typename LayoutType>
void test_admissible_layout_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft> ||
std::is_same_v<LayoutType, Kokkos::LayoutRight>) {
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"View Layout must be either LayoutLeft or LayoutRight.");
} else {
static_assert(!KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"View Layout must be either LayoutLeft or LayoutRight.");
}
}

template <typename T, typename LayoutType>
void test_admissible_view_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
using real_type = KokkosFFT::Impl::base_floating_point_type<T>;
if constexpr (
(std::is_same_v<real_type, float> || std::is_same_v<real_type, double>)&&(
std::is_same_v<LayoutType, Kokkos::LayoutLeft> ||
std::is_same_v<LayoutType, Kokkos::LayoutRight>)) {
static_assert(KokkosFFT::Impl::is_admissible_view_v<ViewType>,
"View value type must be float, double, "
"Kokkos::Complex<float>, Kokkos::Complex<double>. Layout "
"must be either LayoutLeft or LayoutRight.");
} else {
static_assert(!KokkosFFT::Impl::is_admissible_view_v<ViewType>,
"View value type must be float, double, "
"Kokkos::Complex<float>, Kokkos::Complex<double>. Layout "
"must be either LayoutLeft or LayoutRight.");
}
}

TYPED_TEST(RealAndComplexViewTypes, admissible_value_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_value_type<real_type, layout_type>();
test_admissible_value_type<complex_type, layout_type>();
}

TYPED_TEST(RealAndComplexViewTypes, admissible_layout_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_layout_type<real_type, layout_type>();
test_admissible_layout_type<complex_type, layout_type>();
}

TYPED_TEST(RealAndComplexViewTypes, admissible_view_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_view_type<real_type, layout_type>();
test_admissible_view_type<complex_type, layout_type>();
}
Loading
Loading