Skip to content

Commit 7d57bc1

Browse files
[CK_BUILDER] Forward convolution builder improvements (#3179)
Proposed changes Improve the forward convolution builder implementation and addressed leftover feedback left from PR #3138. Main changes Refactored tests such that they reflect better the builder pattern. The templates and types for the convolution algorithm concepts are created via factory that facilitates programmatic creation of the device op instances. Moved tests into anonymous namespace. The convolution factory had lot of if-else constructs when CK Builder types were converted into CK library types. I had initially trouble in using static_assert in the default branch of switch as the static_assert was evaluated at compile time even for valid types. However, if we change the static_assert to throw "<error message>", it will result in a compile-time error only if the default branch is actually hit. This assumes that the function is consteval. Hence, changed all conversions in the convolution factory to use switch, which is more intuitive. Removed the explicit device op definition from convolution signature and the corresponding predicate file. The device ops are defined by the corresponding concepts. This allowed to remove lot of boilerplate code from the convolution factory. Adde inheritance and convolution algorithm specialization to handle device ops that are specialization of a more generic ones. The large tensor support is more naturally expressed by this pattern. Added support for the FP8 data type. * WIP: Builder for expected test results. * Improve ckb fwd conv instance tests. * clang-format * Change if-else statements into switch in conv factory. * Fix clang-formatting. * Removed unnecessary includes. * Added missing copyright. * Remove explicit device op flag from from convolution signature. * Add missing concept. * Fix build. * clang-format * Add test for building conv fwd FP8 instances. * Add missing header to instance traits. * Clean-up recently added instances. * Introduce inheritance and specialization. * Use builder to build conv algorithm templates and types. * clang-format * Fix conv description tests. --------- Co-authored-by: John Shumway <[email protected]>
1 parent ca2ee0e commit 7d57bc1

26 files changed

+957
-1450
lines changed

experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ concept AccessOrderDescriptor = requires(T t) {
9595
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
9696
};
9797

98-
// No requirements yet for a ConvAlgorithm concept.
98+
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
99+
// concept.
99100
template <typename T>
100101
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
101102

@@ -183,6 +184,12 @@ concept SpecifiesLoopScheduler = requires {
183184
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
184185
};
185186

187+
template <typename T>
188+
concept SpecifiesLargeTensorSupport = requires {
189+
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
190+
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
191+
};
192+
186193
/******************************************** */
187194
/* DL-specific descriptors and requirements */
188195
/******************************************** */
@@ -204,21 +211,9 @@ concept DlThreadClusterDescriptor = requires(T t) {
204211
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
205212
};
206213

207-
// Concept for DL block transfer K0_M0_M1_K1 format
208-
template <typename T>
209-
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
210-
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
211-
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
212-
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
213-
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
214-
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
215-
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
216-
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
217-
};
218-
219-
// Concept for DL block transfer K0_N0_N1_K1 format
214+
// Concept for DL block transfer
220215
template <typename T>
221-
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
216+
concept DlBlockTransferDescriptor = requires(T t) {
222217
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
223218
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
224219
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
@@ -228,9 +223,9 @@ concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
228223
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
229224
};
230225

231-
// Concept for DL C thread transfer
226+
// Concept for DL epilogue
232227
template <typename T>
233-
concept DlCThreadTransferDescriptor = requires(T t) {
228+
concept DlEpilogueDescriptor = requires(T t) {
234229
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
235230
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
236231
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
@@ -239,31 +234,63 @@ concept DlCThreadTransferDescriptor = requires(T t) {
239234
// Concept to check if algorithm specifies DL thread config
240235
template <typename T>
241236
concept SpecifiesDlThreadConfig = requires {
242-
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
237+
{ T::thread_config } -> DlThreadConfigDescriptor;
243238
};
244239

245240
// Concept to check if algorithm specifies DL thread cluster
246241
template <typename T>
247242
concept SpecifiesDlThreadCluster = requires {
248-
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
243+
{ T::thread_cluster } -> DlThreadClusterDescriptor;
249244
};
250245

251-
// Concept to check if algorithm specifies DL A block transfer
246+
// Concept to check if algorithm specifies DL block transfer
252247
template <typename T>
253-
concept SpecifiesDlBlockTransferA = requires {
254-
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
248+
concept SpecifiesDlBlockTransfer = requires {
249+
{ T::block_transfer_a } -> DlBlockTransferDescriptor;
250+
{ T::block_transfer_b } -> DlBlockTransferDescriptor;
255251
};
256252

257-
// Concept to check if algorithm specifies DL B block transfer
253+
// Concept to check if algorithm specifies DL C thread transfer
258254
template <typename T>
259-
concept SpecifiesDlBlockTransferB = requires {
260-
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
255+
concept SpecifiesDlEpilogue = requires {
256+
{ T::epilogue_c } -> DlEpilogueDescriptor;
261257
};
262258

263-
// Concept to check if algorithm specifies DL C thread transfer
259+
/******************************************** */
260+
/* Concepts for the different device ops */
261+
/******************************************** */
262+
264263
template <typename T>
265-
concept SpecifiesDlCThreadTransfer = requires {
266-
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
267-
};
264+
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
265+
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
266+
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
267+
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
268+
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
269+
270+
template <typename T>
271+
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
272+
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
273+
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
274+
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
275+
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
276+
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
277+
278+
template <typename T>
279+
concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle =
280+
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
281+
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
282+
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
283+
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
284+
285+
template <typename T>
286+
concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
287+
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConcSpecialization<T> &&
288+
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
289+
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
290+
291+
template <typename T>
292+
concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
293+
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(T::base_algorithm)> &&
294+
SpecifiesLargeTensorSupport<T>;
268295

269296
} // namespace ck_tile::builder

0 commit comments

Comments
 (0)