
Commit 2c56c1e

Rebased to use the Sep 24 g++ host compilation fix.

1 parent 1cac8dd · commit 2c56c1e
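The change is mechanical: every call into the syclcompat:: namespace is renamed to compat::, with call signatures left unchanged. Purely as an illustration (this shim is not part of the commit; the <syclcompat.hpp> header path and the alias below are assumptions), a downstream build that still only ships syclcompat could satisfy the new spelling with a namespace alias:

// Hypothetical bridging shim -- NOT part of this commit. Assumes the old
// syclcompat header is available as <syclcompat.hpp> (DPC++'s layout).
#if __has_include(<syclcompat.hpp>)
  #include <syclcompat.hpp>
  namespace compat = syclcompat;  // lets compat:: call sites resolve
#endif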

File tree

8 files changed (+28, -27 lines)

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_bshd.hpp

Lines changed: 2 additions & 2 deletions
@@ -240,7 +240,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_,
     TiledMmaQK tiled_mma;
     // To make all threads in a warp have the same global tensors pass in the
     // index of thread 0 in each warp
-    auto sg = syclcompat::get_nd_item<1>().get_sub_group();
+    auto sg = compat::get_nd_item<1>().get_sub_group();
     auto first_thread_in_sg_idx =
         sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
     auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);

@@ -336,7 +336,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_,
     // Register spill
     Tensor gV_ = take<0, 3>(
         local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _)));
-    auto sg = syclcompat::get_nd_item<1>().get_sub_group();
+    auto sg = compat::get_nd_item<1>().get_sub_group();
     auto first_thread_in_sg_idx =
         sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
     auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
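For context, the idiom both hunks touch, with explanatory comments added here (a sketch; it assumes compat::get_nd_item mirrors the syclcompat API and returns the current sycl::nd_item):

// Every lane derives the linear id of lane 0 of its own subgroup, so all
// lanes agree on one index and therefore slice the same global tile.
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx =
    sg.get_group_id()[0]             // this subgroup's index in the work-group
    * DispatchPolicy::SubgroupSize;  // times lanes per subgroup
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);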

applications/flash_attention_v2/collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.hpp

Lines changed: 2 additions & 2 deletions
@@ -195,7 +195,7 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_,
     constexpr int FragsM = shape<1>(FragOutLayout{});
     constexpr int FragsN = size(select<2, 3>(shape(FragOutLayout{})));

-    auto g = syclcompat::get_nd_item<1>().get_sub_group();
+    auto g = compat::get_nd_item<1>().get_sub_group();
     auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data(),
                                Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
     float tLSE_reg = {-INFINITY};

@@ -260,7 +260,7 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_,
     copy(params.xe_store_o, final_out_reg, tOgO);

     // Generating the LSE for backward training
-    auto sg = syclcompat::get_nd_item<1>().get_sub_group();
+    auto sg = compat::get_nd_item<1>().get_sub_group();
     int lane_id = static_cast<int>(sg.get_local_linear_id());
     int sub_group_id = get_sub_group_id();
     const int BLK_M = size(select<0>(TileShapeOutput{}));
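The "LSE for backward training" in the second hunk is the per-row log-sum-exp statistic that flash attention stores so the backward pass can reconstruct the softmax without redoing the row reduction. A host-side reference for what the statistic is (a sketch for illustration only; lse_row is a name invented here, not the kernel's code):

#include <cmath>
#include <vector>

// Numerically stable log-sum-exp of one row of attention scores:
// lse = m + log(sum_j exp(s_j - m)), where m = max_j s_j.
float lse_row(const std::vector<float> &s) {
  float m = -INFINITY;
  for (float x : s) m = std::fmax(m, x);
  float acc = 0.0f;
  for (float x : s) acc += std::exp(x - m);
  return m + std::log(acc);
}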

applications/flash_attention_v2/collective/xe_flash_attn_sdpa_fwd_bshd_softmax_epilogue.hpp

Lines changed: 3 additions & 3 deletions
@@ -106,7 +106,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16,
             class FragSum>
   CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max,
                                      FragSum &sum) {
-    auto g = syclcompat::get_nd_item<1>().get_sub_group();
+    auto g = compat::get_nd_item<1>().get_sub_group();
     const auto max_scale = max * params.scale;
     CUTLASS_PRAGMA_UNROLL
     for (int indx = 0; indx < Vec * FragsM; indx++) {

@@ -123,7 +123,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16,

   template <int Vec, int FragsM, int FragsN, class FragSrc, class FragMax>
   CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) {
-    auto g = syclcompat::get_nd_item<1>().get_sub_group();
+    auto g = compat::get_nd_item<1>().get_sub_group();
     CUTLASS_PRAGMA_UNROLL
     for (int indx = 0; indx < Vec * FragsM; indx++) {
       auto maxptr = group_broadcast(g, max, indx);

@@ -155,7 +155,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16,
            " No. of attention rows per subgroup should be >= 1 MMA Atom "
            "worth of rows.");
     if (!is_first) {
-      auto g = syclcompat::get_nd_item<1>().get_sub_group();
+      auto g = compat::get_nd_item<1>().get_sub_group();
       Element max_scale{max * params.scale};
       Element exp_scale{
           sycl::native::exp2(max_prev * params.scale - max_scale)};
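All three hunks only change how the subgroup handle is obtained; the subgroup collectives themselves (group_broadcast, and the reductions behind reduce_max) are standard SYCL 2020 and untouched. A standalone sketch of those primitives, independent of the compat layer (every name here is illustrative):

#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;
  float *data = sycl::malloc_shared<float>(32, q);
  for (int i = 0; i < 32; ++i) data[i] = float(i);
  q.parallel_for(sycl::nd_range<1>{32, 32}, [=](sycl::nd_item<1> it) {
     auto sg = it.get_sub_group();
     float v = data[it.get_global_id(0)];
     // Subgroup-wide max, then a broadcast from lane 0 -- the same building
     // blocks the reduce_max()/group_broadcast calls above are built on.
     float m = sycl::reduce_over_group(sg, v, sycl::maximum<float>());
     data[it.get_global_id(0)] = sycl::group_broadcast(sg, m, 0) - v;
   }).wait();
  sycl::free(data, q);
}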

applications/flash_attention_v2/kernel/tile_scheduler_sdpa_fwd_bshd.hpp

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ struct XeFlashPersistentTileScheduler {
   }

   template <int Num_SGs> static dim3 get_grid_shape(Params const &params) {
-    auto queue = syclcompat::get_default_queue();
+    auto queue = compat::get_default_queue();
     auto dev = queue.get_device();
     const size_t maxSubgroups =
         dev.template get_info<sycl::info::device::max_num_sub_groups>();
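get_grid_shape sizes the persistent grid from device occupancy limits; the rename only changes where the default queue comes from, not the queries. The portable SYCL device queries involved, in isolation (a sketch; the scheduler's actual grid arithmetic is not reproduced):

#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
  sycl::queue q;  // stands in for compat::get_default_queue()
  auto dev = q.get_device();
  const size_t max_subgroups =
      dev.get_info<sycl::info::device::max_num_sub_groups>();
  const size_t compute_units =
      dev.get_info<sycl::info::device::max_compute_units>();
  std::printf("compute units: %zu, max subgroups per work-group: %zu\n",
              compute_units, max_subgroups);
}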
Lines changed: 19 additions & 19 deletions
@@ -197,7 +197,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

   template <typename SrcT, typename DstT>
   void convert_fp8_to_fp16(const SrcT *d_src, DstT *d_dst, size_t size) {
-    syclcompat::get_default_queue()
+    compat::get_default_queue()
         .parallel_for(
             size,
             [=](auto indx) { d_dst[indx] = static_cast<DstT>(d_src[indx]); })
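convert_fp8_to_fp16 is a plain element-wise cast kernel launched on the compat layer's default queue. The same helper in plain SYCL, for reference (a sketch; convert_on_device is a name invented here):

#include <sycl/sycl.hpp>

// Cast `size` elements from SrcT to DstT on the device, then block until done.
template <typename SrcT, typename DstT>
void convert_on_device(sycl::queue &q, const SrcT *d_src, DstT *d_dst,
                       size_t size) {
  q.parallel_for(sycl::range<1>{size}, [=](sycl::id<1> i) {
     d_dst[i] = static_cast<DstT>(d_src[i]);
   }).wait();
}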
@@ -298,9 +298,9 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
         seq_len_qo * seq_len_kv, // batch_stride_S
         seq_len_qo * seq_len_kv  // batch_stride_S
     );
-    syclcompat::wait();
+    compat::wait();
     std::vector<ElementAccumulator> host_S(block_S.size());
-    syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(),
+    compat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(),
                                            host_S.size());

     // delete this memory as it is no longer needed
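Note the typed copies: the calls pass host_S.size(), an element count, so compat::memcpy<T> evidently counts elements of T, like syclcompat's typed overload rather than the byte-counting C memcpy. The plain-SYCL equivalent of this wait-then-copy pattern, for comparison (a sketch; copy_to_host is a name invented here):

#include <sycl/sycl.hpp>

// q.memcpy counts bytes, so scale the element count by sizeof(T).
template <typename T>
void copy_to_host(sycl::queue &q, T *dst, const T *src, size_t n) {
  q.wait();  // ensure the producing kernel has finished
  q.memcpy(dst, src, n * sizeof(T)).wait();
}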
@@ -378,7 +378,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
     cutlass::DeviceAllocation<ElementV_> block_P;
     block_P.reset(host_P.size());

-    syclcompat::memcpy<ElementV_>(block_P.get(), host_P.data(),
+    compat::memcpy<ElementV_>(block_P.get(), host_P.data(),
                                   host_P.size());

     cutlass::TensorRef ref_P(block_P.get(),

@@ -401,12 +401,12 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
         seq_len_qo * head_size_vo // batch_stride_O
     );

-    syclcompat::wait();
+    compat::wait();
     // delete this memory as it is no longer needed
     block_P.reset();

     std::vector<ElementAccumulator> vec_acc(block_acc.size());
-    syclcompat::memcpy<ElementAccumulator>(
+    compat::memcpy<ElementAccumulator>(
         vec_acc.data(), block_acc.get(), vec_acc.size());

     // delete this memory as it is no longer needed

@@ -434,11 +434,11 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
       offset_o += seq_len_qo * num_heads_q * head_size_vo;
     } // end of batch loop

-    syclcompat::wait();
-    syclcompat::memcpy<ElementOutput>(block_ref_O.get(), host_O.data(),
+    compat::wait();
+    compat::memcpy<ElementOutput>(block_ref_O.get(), host_O.data(),
                                       host_O.size());
-    syclcompat::wait();
-    syclcompat::memcpy<float>(block_ref_LSE.get(), host_LSE.data(),
+    compat::wait();
+    compat::memcpy<float>(block_ref_LSE.get(), host_LSE.data(),
                               host_LSE.size());

     // Check if output from CUTLASS kernel and reference kernel are equal or not

@@ -613,29 +613,29 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
     // configure smem size and carveout
     int smem_size = FMHAPrefillKernel::SharedStorageSize;

-    const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z);
-    const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z);
+    const auto sycl_block = compat::dim3(block.x, block.y, block.z);
+    const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z);

     // Launch parameters depend on whether SYCL compiler supports work-group scratch
     // memory extension
 #if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
-    using namespace syclcompat::experimental;
+    using namespace compat::experimental;
     auto event = launch<cutlass::device_kernel<FMHAPrefillKernel>>(
         launch_policy{sycl_grid, sycl_block,
                       local_mem_size{static_cast<std::size_t>(smem_size)},
                       kernel_properties{sycl_exp::sub_group_size<
                           FMHAPrefillKernel::DispatchPolicy::SubgroupSize>}},
         params);
 #else
-    syclcompat::experimental::launch_properties launch_props{
+    compat::experimental::launch_properties launch_props{
         sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
     };
-    syclcompat::experimental::kernel_properties kernel_props{
+    compat::experimental::kernel_properties kernel_props{
         sycl::ext::oneapi::experimental::sub_group_size<
             FMHAPrefillKernel::DispatchPolicy::SubgroupSize>};
-    syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block,
+    compat::experimental::launch_policy policy{sycl_grid, sycl_block,
                                                    launch_props, kernel_props};
-    auto event = syclcompat::experimental::launch<
+    auto event = compat::experimental::launch<
         cutlass::device_kernel<FMHAPrefillKernel>>(policy, params);
 #endif
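The #if above selects how the kernel receives its dynamic shared local memory: compilers without the work-group scratch extension go through the compat layer's local_mem_size launch property, while newer ones pass sycl::ext::oneapi::experimental::work_group_scratch_size directly. SYCL extensions advertise themselves via feature-test macros, so the choice is purely compile-time; a minimal probe (a sketch):

#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
#if defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
  std::puts("work-group scratch extension available: property-based path");
#else
  std::puts("extension absent: compat local_mem_size fallback path");
#endif
}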

@@ -681,7 +681,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
     // Run the GEMM
     run(params);

-    syclcompat::wait();
+    compat::wait();

     // Verify that the result is correct
     bool passed = verify(problem_size, options.is_causal, options.softmax_scale);

@@ -697,7 +697,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
     for (int i = 0; i < options.iterations; ++i) {
       run(params);
     }
-    syclcompat::wait();
+    compat::wait();
     // when seq_len_qo is not equal to seq_len_kv we use bottom up approach
     // for the masking. Following changes will adjust the effective_seq_len_kv
     // when masking applied for such cases

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@ if(CUTLASS_ENABLE_SYCL)
   04_bmg_grouped_gemm
   05_bmg_gemm_with_epilogues
   06_bmg_flash_attention
+  06a_bmg_flash_attention_sdpa_fwd_bshd
   07_bmg_dual_gemm
   08_bmg_gemm_f8
   09_bmg_grouped_gemm_f8
