Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
7ba35c9
Add missing copyright statements
SamiAario-AMD Oct 16, 2025
44f44a6
Use ck_tile::host_tensor_descriptor instead of a custom lambda
SamiAario-AMD Oct 16, 2025
4375203
Refactor use of check_data_type in test classes
SamiAario-AMD Oct 20, 2025
780238b
Use TEST_SUITE_NAME with TYPED_TEST_SUITE
SamiAario-AMD Oct 20, 2025
91aa958
Remove an unused namespace
SamiAario-AMD Oct 22, 2025
ae0ea1b
Make dim3 const
SamiAario-AMD Oct 22, 2025
9d1822b
Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 13, 2025
05c8e76
Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 22, 2025
de30a68
Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 22, 2025
c074fff
Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
f77b4db
Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
742220a
Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
ba89f63
Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
5986915
Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile
SamiAario-AMD Oct 23, 2025
8818847
Add missing precision type combinations to CompV4 from CompV3
SamiAario-AMD Oct 23, 2025
382737d
Move the INT8 tests around for consistency with KernelTypesCompV3Wmma
SamiAario-AMD Oct 23, 2025
8ec762a
Add missing precision type combinations to CompV3Wmma from CompV3
SamiAario-AMD Oct 23, 2025
2182d7c
Remove the basic and universal tests and their dependencies
SamiAario-AMD Oct 23, 2025
b34d16d
On __gfx950__, avoid using transposed loading of A with datatype pk_i…
SamiAario-AMD Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,21 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
// Transposed loading of A is only considered when A is column-major, and it is
// switched off whenever B holds pk_int4_t: combining pk_int4_t with transposed
// loading causes numerical errors.
static constexpr bool is_a_load_tr =
    !std::is_same_v<BDataType, pk_int4_t> &&
    std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;

// Transposed loading of B is only considered when B is row-major, and it is
// switched off for pk_int4_t B data because that combination causes numerical
// errors.
static constexpr bool is_b_load_tr =
    !std::is_same_v<BDataType, pk_int4_t> &&
    std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
#else
static constexpr bool is_a_load_tr = false;
static constexpr bool is_b_load_tr = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,27 @@ template <typename Derived>
struct UniversalGemmBasePolicy
{
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
template <typename Problem>
static constexpr bool is_a_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_a_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
}();

template <typename Problem>
static constexpr bool is_b_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
}();
#else
template <typename Problem>
static constexpr bool is_a_load_tr = false;
Expand Down
48 changes: 10 additions & 38 deletions test/ck_tile/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
)
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)

target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE --save-temps -Wno-gnu-line-marker)
add_gtest_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target ")
endif()

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
Expand All @@ -71,7 +34,16 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
endif()

if(GPU_TARGETS MATCHES "gfx11|gfx12")
# On Radeon devices, build the WMMA version instead
# On Radeon devices, build the WMMA version instead
# Define architecture macros for compile-time detection.
# At most one of -DARCH_GFX12 / -DARCH_GFX11 is appended, to both
# EXAMPLE_GEMM_COMPILE_OPTIONS and EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS.
# gfx12 is tested first, so a GPU_TARGETS value that matches both patterns
# only gets -DARCH_GFX12.
# NOTE(review): EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS appears to be copied
# from the V4 list before this point, so it presumably does not receive these
# macros - confirm whether that is intended.
if(GPU_TARGETS MATCHES "gfx12")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DARCH_GFX12)
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DARCH_GFX12)
elseif(GPU_TARGETS MATCHES "gfx11")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DARCH_GFX11)
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DARCH_GFX11)
endif()

add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
Expand Down
13 changes: 0 additions & 13 deletions test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp

This file was deleted.

13 changes: 0 additions & 13 deletions test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions test/ck_tile/gemm/test_gemm_pipeline_basic_cases.hpp

This file was deleted.

13 changes: 0 additions & 13 deletions test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp

This file was deleted.

14 changes: 0 additions & 14 deletions test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp

This file was deleted.

Loading