Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add traits to help type checks #117

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions common/src/KokkosFFT_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#ifndef KOKKOSFFT_TRAITS_HPP
#define KOKKOSFFT_TRAITS_HPP

#include <Kokkos_Core.hpp>
#include <vector>
#include <set>
#include <algorithm>
#include <numeric>

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct real_type {
using type = T;
};

template <typename T>
struct real_type<Kokkos::complex<T>> {
using type = T;
};

template <typename T>
using real_type_t = typename real_type<T>::type;
yasahi-hpc marked this conversation as resolved.
Show resolved Hide resolved

template <typename T, typename Enable = void>
struct is_real : std::false_type {};

template <typename T>
struct is_real<
T, std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>>
: std::true_type {};

template <typename T>
inline constexpr bool is_real_v = is_real<T>::value;

template <typename T, typename Enable = void>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<
Kokkos::complex<T>,
std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>>
: std::true_type {};

template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;

// is value type admissible for KokkosFFT
template <typename T, typename Enable = void>
struct is_admissible_value_type : std::false_type {};

template <typename T>
struct is_admissible_value_type<
T, std::enable_if_t<is_real_v<T> || is_complex_v<T>>> : std::true_type {};

template <typename T>
struct is_admissible_value_type<
T, std::enable_if_t<Kokkos::is_view<T>::value &&
(is_real_v<typename T::non_const_value_type> ||
is_complex_v<typename T::non_const_value_type>)>>
: std::true_type {};

template <typename T>
inline constexpr bool is_admissible_value_type_v =
is_admissible_value_type<T>::value;

// is layout admissible for KokkosFFT
template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
Kokkos::is_view<ViewType>::value &&
(std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>)>>
: std::true_type {};

template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

// is view admissible for KokkosFFT
template <typename ViewType, typename Enable = void>
struct is_admissible_view : std::false_type {};

template <typename ViewType>
struct is_admissible_view<
ViewType, std::enable_if_t<Kokkos::is_view<ViewType>::value &&
is_layout_left_or_right_v<ViewType> &&
is_admissible_value_type_v<ViewType>>>
: std::true_type {};

template <typename ViewType>
inline constexpr bool is_admissible_view_v =
is_admissible_view<ViewType>::value;

template <typename T>
struct managable_view_type {
using type = Kokkos::View<typename T::data_type, typename T::array_layout,
typename T::memory_space,
Kokkos::MemoryTraits<T::memory_traits::impl_value &
~unsigned(Kokkos::Unmanaged)>>;
};

template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
using value_type = typename ViewType::non_const_value_type;
using float_type = KokkosFFT::Impl::real_type_t<value_type>;
using complex_type = Kokkos::complex<float_type>;
using array_layout_type = typename ViewType::array_layout;
using type = Kokkos::View<complex_type*, array_layout_type, ExecutionSpace>;
};

} // namespace Impl
} // namespace KokkosFFT

#endif
52 changes: 1 addition & 51 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,60 +10,10 @@
#include <set>
#include <algorithm>
#include <numeric>
#include "KokkosFFT_traits.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct real_type {
using type = T;
};

template <typename T>
struct real_type<Kokkos::complex<T>> {
using type = T;
};

template <typename T>
struct managable_view_type {
using type = Kokkos::View<typename T::data_type, typename T::array_layout,
typename T::memory_space,
Kokkos::MemoryTraits<T::memory_traits::impl_value &
~unsigned(Kokkos::Unmanaged)>>;
};

template <typename T>
using real_type_t = typename real_type<T>::type;

template <typename T>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<Kokkos::complex<T>> : std::true_type {};

template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>>>
: std::true_type {};

template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
using value_type = typename ViewType::non_const_value_type;
using float_type = KokkosFFT::Impl::real_type_t<value_type>;
using complex_type = Kokkos::complex<float_type>;
using array_layout_type = typename ViewType::array_layout;
using type = Kokkos::View<complex_type*, array_layout_type, ExecutionSpace>;
};

template <typename ViewType>
auto convert_negative_axis(ViewType, int _axis = -1) {
Expand Down
3 changes: 2 additions & 1 deletion common/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
add_executable(unit-tests-kokkos-fft-common
Test_Main.cpp
Test_Utils.cpp
Test_Traits.cpp
Test_Normalization.cpp
Test_Transpose.cpp
Test_Layouts.cpp
Expand All @@ -20,4 +21,4 @@ target_link_libraries(unit-tests-kokkos-fft-common PUBLIC common GTest::gtest)

# Enable GoogleTest
include(GoogleTest)
gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600)
gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600)
168 changes: 168 additions & 0 deletions common/unit_test/Test_Traits.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#include <gtest/gtest.h>
#include "KokkosFFT_traits.hpp"
#include "Test_Utils.hpp"

using real_types = ::testing::Types<float, double, long double>;
using view_types =
::testing::Types<std::pair<float, Kokkos::LayoutLeft>,
std::pair<float, Kokkos::LayoutRight>,
std::pair<float, Kokkos::LayoutStride>,
std::pair<double, Kokkos::LayoutLeft>,
std::pair<double, Kokkos::LayoutRight>,
std::pair<double, Kokkos::LayoutStride>,
std::pair<long double, Kokkos::LayoutLeft>,
std::pair<long double, Kokkos::LayoutRight>,
std::pair<long double, Kokkos::LayoutStride>>;

template <typename T>
struct RealAndComplexTypes : public ::testing::Test {
using real_type = T;
using complex_type = Kokkos::complex<T>;
};

template <typename T>
struct RealAndComplexViewTypes : public ::testing::Test {
using real_type = typename T::first_type;
using complex_type = Kokkos::complex<real_type>;
using layout_type = typename T::second_type;
};

TYPED_TEST_SUITE(RealAndComplexTypes, real_types);
TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types);

// Tests for real type deduction
template <typename RealType, typename ComplexType>
void test_get_real_type() {
using real_type_from_RealType = KokkosFFT::Impl::real_type_t<RealType>;
using real_type_from_ComplexType = KokkosFFT::Impl::real_type_t<ComplexType>;

static_assert(std::is_same_v<real_type_from_RealType, RealType>,
"Real type not deduced correctly from real type");
static_assert(std::is_same_v<real_type_from_ComplexType, RealType>,
"Real type not deduced correctly from complex type");
}

// Tests for admissible real types (float or double)
template <typename T>
void test_admissible_real_type() {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
static_assert(KokkosFFT::Impl::is_real_v<T>,
"Real type must be float or double");
} else {
static_assert(!KokkosFFT::Impl::is_real_v<T>,
"Real type must be float or double");
}
}

template <typename T>
void test_admissible_complex_type() {
using real_type = KokkosFFT::Impl::real_type_t<T>;
if constexpr (std::is_same_v<real_type, float> ||
std::is_same_v<real_type, double>) {
static_assert(KokkosFFT::Impl::is_complex_v<T>,
"Complex type must be Kokkos::complex<float> or "
"Kokkos::complex<double>");
} else {
static_assert(!KokkosFFT::Impl::is_complex_v<T>,
"Complex type must be Kokkos::complex<float> or "
"Kokkos::complex<double>");
}
}

TYPED_TEST(RealAndComplexTypes, get_real_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;

test_get_real_type<real_type, complex_type>();
}

TYPED_TEST(RealAndComplexTypes, admissible_real_type) {
using real_type = typename TestFixture::real_type;

test_admissible_real_type<real_type>();
}

TYPED_TEST(RealAndComplexTypes, admissible_complex_type) {
using complex_type = typename TestFixture::complex_type;

test_admissible_complex_type<complex_type>();
}

// Tests for admissible view types
template <typename T, typename LayoutType>
void test_admissible_value_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
using real_type = KokkosFFT::Impl::real_type_t<T>;
if constexpr (std::is_same_v<real_type, float> ||
std::is_same_v<real_type, double>) {
static_assert(KokkosFFT::Impl::is_admissible_value_type_v<ViewType>,
"Real type must be float or double");
} else {
static_assert(!KokkosFFT::Impl::is_admissible_value_type_v<ViewType>,
"Real type must be float or double");
}
}

template <typename T, typename LayoutType>
void test_admissible_layout_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft> ||
std::is_same_v<LayoutType, Kokkos::LayoutRight>) {
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"View Layout must be either LayoutLeft or LayoutRight.");
} else {
static_assert(!KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"View Layout must be either LayoutLeft or LayoutRight.");
}
}

template <typename T, typename LayoutType>
void test_admissible_view_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
using real_type = KokkosFFT::Impl::real_type_t<T>;
if constexpr (
(std::is_same_v<real_type, float> || std::is_same_v<real_type, double>)&&(
std::is_same_v<LayoutType, Kokkos::LayoutLeft> ||
std::is_same_v<LayoutType, Kokkos::LayoutRight>)) {
static_assert(KokkosFFT::Impl::is_admissible_view_v<ViewType>,
"View value type must be float, double, "
"Kokkos::Complex<float>, Kokkos::Complex<double>. Layout "
"must be either LayoutLeft or LayoutRight.");
} else {
static_assert(!KokkosFFT::Impl::is_admissible_view_v<ViewType>,
"View value type must be float, double, "
"Kokkos::Complex<float>, Kokkos::Complex<double>. Layout "
"must be either LayoutLeft or LayoutRight.");
}
}

TYPED_TEST(RealAndComplexViewTypes, admissible_value_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_value_type<real_type, layout_type>();
test_admissible_value_type<complex_type, layout_type>();
}

TYPED_TEST(RealAndComplexViewTypes, admissible_layout_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_layout_type<real_type, layout_type>();
test_admissible_layout_type<complex_type, layout_type>();
}

TYPED_TEST(RealAndComplexViewTypes, admissible_view_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_view_type<real_type, layout_type>();
test_admissible_view_type<complex_type, layout_type>();
}
2 changes: 1 addition & 1 deletion common/unit_test/Test_Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#ifndef TEST_TYPES_HPP
#define TEST_TYPES_HPP

#include <Kokkos_Core.hpp>
#include <Kokkos_Complex.hpp>
using execution_space = Kokkos::DefaultExecutionSpace;
template <typename T>
Expand Down
Loading