@@ -567,21 +567,17 @@ int main(int argc, const char** argv)
567567 using LayoutC = cutlass::layout::RowMajor;
568568 using LayoutD = cutlass::layout::RowMajor;
569569
570- using GmemTiledCopyA = XE_2D_U16x32x32_LD_N ;
571- using GmemTiledCopyB = XE_2D_U16x32x32_LD_V ;
570+ using GmemTiledCopyA = void ; // XE_LOAD_2D<16, 32, 32> ;
571+ using GmemTiledCopyB = void ; // XE_LOAD_2D_VNNI<16, 32, 32> ;
572572
573573 // Workgroup-level tile
574574 using TileShape = Shape<_256, _256, _32>;
575575
576- using TiledMma =
577- TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
578- Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
579- Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
580- Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
576+ using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8 , ElementAccumulator, ElementA>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
581577
582578 constexpr int PipelineStages = 2 ;
583579 // Dispatch to grouped gemm algorithm
584- using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group <PipelineStages>;
580+ using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup <PipelineStages>;
585581 using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
586582
587583 using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
0 commit comments