Commit 177c85c

Python ops support improvements and test fixes (#595)

Improvements:
- Add Xe12 and Xe20 support for C++ generation and remove hardcoded values (PVC was previously encoded as 11)
- More combinations for generator.py (covers void for ElementC)
- Python test fixes

Tests done:
- Python tests
- Compilation of generated files
- Torch generated-file test

Note about C++ file changes:
- Gemm_operation3x.hpp: cast to ElementCompute added to avoid compiler errors
- xe_epilogue.hpp: support for ElementC as a void type
1 parent 7ab29af commit 177c85c

File tree: 18 files changed, +400 −176 lines

include/cutlass/epilogue/collective/xe_epilogue.hpp

Lines changed: 13 additions & 7 deletions
@@ -120,10 +120,6 @@ class CollectiveEpilogue<
 
   using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
 
-  using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>;
-  using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})));
-  using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<Trait_C, ElementC>{}, Layout<CopyThreadShape>{}, val_layout_load_C{}));
-
   using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
   using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
   using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));
@@ -132,6 +128,13 @@ class CollectiveEpilogue<
   constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>;
   constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
 
+  using NonVoidElementC = conditional_t<is_source_supported, ElementC, ElementD>;
+  using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>;
+  using NonVoidTrait_C = conditional_t<is_source_supported, Trait_C, Trait_D>;
+  using val_layout_load_C = decltype(make_layout(shape_div(typename NonVoidTrait_C::BlockShape{}, CopyThreadShape{})));
+  using NonVoidValLayoutLoad_C = conditional_t<is_source_supported, val_layout_load_C, val_layout_store_D>;
+  using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<NonVoidTrait_C, NonVoidElementC>{}, Layout<CopyThreadShape>{}, NonVoidValLayoutLoad_C{}));
+
   constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();
   constexpr static bool is_m_major_D = detail::is_m_major<StrideD>();
 
@@ -348,7 +351,7 @@ class CollectiveEpilogue<
     auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
     Tensor tCgD = thread_xe_store_d.partition_D(gD);
 
-    Tensor trC = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
+    Tensor trC = make_tensor<NonVoidElementC>(Shape<Int<FragmentSize>>{});
     Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
 
     // Because Sm90 uses shared memory, they are not tied to using the same accumulator values
@@ -407,9 +410,12 @@ class CollectiveEpilogue<
       CUTLASS_PRAGMA_UNROLL
      for (int epi_m = 0; epi_m < FragsM; epi_m++) {
         cst_callbacks.begin_loop(epi_m, epi_n);
-
+
+        //avoid evaluating xe_load_c when ElementC is void during compilation
         if (is_C_load_needed) {
-          copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
+          if constexpr (is_source_supported) {
+            copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
+          }
         }
 
         cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
python/cutlass_cppgen/backend/evt/passes/util.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
 
 # Map from the CC of the kernel to the EVT implementation that the CC targets
 cc_map = {
+    12: 12, # Intel Xe12 PVC
+    20: 20, # Intel Xe20 BMG
     80: 80,
     86: 80,
     89: 80,
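
The two new entries make the Intel Xe CCs resolve to their own EVT implementations, while the existing SM86/SM89 entries keep falling back to SM80. A minimal sketch of the lookup, assuming the pass simply indexes cc_map with the kernel's CC (the helper name evt_target_cc is invented for illustration):

# Mapping as extended by this commit (excerpt).
cc_map = {
    12: 12,  # Intel Xe12 (PVC) has a dedicated EVT implementation
    20: 20,  # Intel Xe20 (BMG) has a dedicated EVT implementation
    80: 80,
    86: 80,  # SM86 reuses the SM80 EVT implementation
    89: 80,
}

def evt_target_cc(kernel_cc: int) -> int:
    # Hypothetical helper: resolve which EVT implementation a kernel CC targets.
    return cc_map[kernel_cc]

assert evt_target_cc(20) == 20
assert evt_target_cc(89) == 80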

python/cutlass_cppgen/backend/gemm_operation.py

Lines changed: 11 additions & 5 deletions
@@ -39,6 +39,7 @@
 cuda = lazy_import("cuda.cuda")
 cudart = lazy_import("cuda.cudart")
 from cutlass_library import SubstituteTemplate
+from cutlass_library.arch_constants import is_intel_xe_arch
 import numpy as np
 
 import dpctl
@@ -915,7 +916,7 @@ def get_device_workspace_size(self, arguments):
         return 0
 
     def initialize(self):
-        if self.operation.arch == 11:
+        if is_intel_xe_arch(self.operation.arch):
             return
 
         err, = cuda.cuFuncSetAttribute(
@@ -1318,7 +1319,7 @@ def __init__(self, operation_suffix=""):
 
     def emit(self, operation):
         # Support built-in epilogue functors or user-defined functions
-        if operation.arch == 11:
+        if is_intel_xe_arch(operation.arch):
             stage_count_type = "cutlass::gemm::collective::StageCountAuto"
         elif operation.tile_description.stages is None or operation.tile_description.stages == 0:
             stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>"
@@ -1340,7 +1341,7 @@ def emit(self, operation):
         if operation.tile_description.tile_scheduler is not None:
             tschedule = operation.tile_description.tile_scheduler
 
-        arch = "cutlass::arch::IntelXe" if operation.arch == 11 else f"cutlass::arch::Sm{operation.arch}"
+        arch = f"cutlass::arch::Xe{operation.arch}" if is_intel_xe_arch(operation.arch) else f"cutlass::arch::Sm{operation.arch}"
         values = {
             "operation_name": operation.procedural_name(),
             "operation_suffix": self.operation_suffix,
@@ -1718,10 +1719,15 @@ def epilogue_schedule_name_3x(self):
     def procedural_name(self):
         """The full procedural name indicates architecture, extended name, tile size, and layout."""
         opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
-        if self.api == ApiVersion.v3x and (self.arch >= 90 or self.arch == 11):
-            kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
+        if self.api == ApiVersion.v3x and (self.arch >= 90 or is_intel_xe_arch(self.arch)):
+            arch_prefix="sm"
+            if is_intel_xe_arch(self.arch):
+                arch_prefix="Xe"
+
+            kernel_name_template = "cutlass{p}_{sm_or_xe}{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
             return kernel_name_template.format(
                 p=self.prefix,
+                sm_or_xe=arch_prefix,
                 ar=self.arch,
                 op=opcode_class_name,
                 ex=self.extended_name_3x(),
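
With the new {sm_or_xe} placeholder, v3x kernel names carry an Xe prefix on Intel parts instead of the previous sm prefix. A hedged illustration of the template in isolation (every field value below is invented, not taken from a real kernel; a real GemmOperation fills them from its tile description, layouts, and schedules):

kernel_name_template = ("cutlass{p}_{sm_or_xe}{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}"
                        "_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}")

# Invented example values, for illustration only.
print(kernel_name_template.format(
    p="3x", sm_or_xe="Xe", ar=20, op="tensorop", ex="f16_f16_f32",
    tbm=256, tbn=256, tbk=32, cm=1, cn=1, ck=1,
    l="tnn", s="auto", al=8, k="", e=""))
# -> cutlass3x_Xe20_tensorop_f16_f16_f32_256x256x32_1x1x1_tnn_auto_align8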

python/cutlass_cppgen/backend/library.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@
     OpcodeClass,
     TileSchedulerType
 )
-
+from cutlass_library.arch_constants import is_intel_xe_arch
 
 # The following block implements enum.auto() for Python 3.5 variants that don't include it such
 # as the default 3.5.2 on Ubuntu 16.04.
@@ -473,7 +473,7 @@ def api_version(arch, opclass, dtype):
     :return: API version to be used in code emission
     :rtype: ApiVersion
     """
-    if opclass == OpcodeClass.TensorOp and arch == 11:
+    if opclass == OpcodeClass.TensorOp and is_intel_xe_arch(arch):
         return ApiVersion.v3x
 
     if (arch >= 90 and

python/cutlass_cppgen/backend/utils/device.py

Lines changed: 2 additions & 2 deletions
@@ -81,8 +81,8 @@ def device_cc(device: int = -1) -> int:
         device = cutlass_cppgen.device_id()
 
     if cutlass_cppgen._use_sycl:
-        # Using '11' to encode Intel PVC as an integer in the expected format.
-        return 11
+        # Using '12' to encode Intel PVC as an integer in the expected format.
+        return 12
 
     deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
     major = str(deviceProp.major)

python/cutlass_cppgen/library_defaults.py

Lines changed: 33 additions & 9 deletions
@@ -40,14 +40,24 @@
 
 import cutlass_library
 from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
+from cutlass_library.arch_constants import (
+    INTEL_XE_ARCH_MIN,
+    INTEL_XE_ARCH_MAX,
+    INTEL_XE12,
+    INTEL_XE20,
+    INTEL_XE35,
+    is_intel_xe_arch
+)
 
 import cutlass_cppgen
 from cutlass_cppgen.utils.check import valid_stage_count
 from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
 
 
-# The value '11' is used to encode Intel PVC GPU in the expected format.
-_generator_ccs = [11, 50, 60, 61, 70, 75, 80, 90]
+# Intel Xe architectures and supported NVIDIA architectures
+# Intel Xe: 12 (PVC/Xe-HPC), 20 (BMG/Xe2), 30 (future)
+# NVIDIA architectures: 50, 60, 61, 70, 75, 80, 90
+_generator_ccs = [INTEL_XE12, INTEL_XE20] #50, 60, 61, 70, 75, 80, 90]
 
 class KernelsForDataType:
     """
@@ -261,7 +271,12 @@ def __init__(
 
         # Identify the method within CUTLASS generator script that generates kernel
         # descriptions for the target CC
-        generate_function_name = "GeneratePVC" if kernel_cc == 11 else "GenerateSM" + str(kernel_cc)
+        # Intel Xe architectures use GenerateIntelXe, NVIDIA uses GenerateSM{cc}
+        if is_intel_xe_arch(kernel_cc):
+            generate_function_name = "GenerateIntelXe"
+        else:
+            generate_function_name = "GenerateSM" + str(kernel_cc)
+
         if not hasattr(cutlass_library.generator, generate_function_name):
             cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
             return
@@ -273,13 +288,20 @@ def __init__(
             "--kernels=all",
             f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
         ]
-        if self.cc == 11:
-            args.append("--architectures=11")
+        # For Intel Xe architectures, specify the architecture number
+        if is_intel_xe_arch(kernel_cc):
+            args.append(f"--architectures={kernel_cc}")
 
         manifest_args = cutlass_library.generator.define_parser().parse_args(args)
         manifest = cutlass_library.manifest.Manifest(manifest_args)
-        generate_function(manifest, cutlass_cppgen._nvcc_version)
-
+
+        # For Intel Xe architectures, pass the architecture number to the generator
+        if is_intel_xe_arch(kernel_cc):
+            print(f"Calling {generate_function_name} with arch={kernel_cc}")
+            generate_function(manifest, cutlass_cppgen._nvcc_version, arch=kernel_cc)
+        else:
+            generate_function(manifest, cutlass_cppgen._nvcc_version)
+
         if operation_kind not in manifest.operations:
             # No kernels generated for this architecture, this could be because the CUDA
             # toolkit is insufficient to support operations in this CC
@@ -554,8 +576,10 @@ class OptionRegistry:
     def __init__(self, target_cc: int):
         self.registry = {}
 
-        if target_cc > 90:
-            raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to 90.")
+        # Intel Xe architectures: 12-20 (PVC, BMG, etc.)
+        # NVIDIA architectures: 50-90
+        if target_cc > 90 or (not is_intel_xe_arch(target_cc)):
+            raise Exception(f"Unsupported compute capability {target_cc}. Supported: NVIDIA SM 50-90, Intel Xe 12-20.")
 
         gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
         operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
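
The net effect of the middle two hunks: Intel Xe CCs resolve to a single GenerateIntelXe entry point that receives the CC via the arch keyword, while NVIDIA CCs keep their per-arch GenerateSM{cc} functions. A condensed sketch of that dispatch (the wrapper name run_generator is invented; the looked-up generator functions and the arch keyword are the ones the diff references):

from cutlass_library.arch_constants import is_intel_xe_arch
import cutlass_library.generator as generator

def run_generator(manifest, cuda_version, kernel_cc):
    # Invented wrapper condensing the dispatch added above.
    if is_intel_xe_arch(kernel_cc):
        # One GenerateIntelXe entry point serves all Xe CCs; the CC selects the target.
        generator.GenerateIntelXe(manifest, cuda_version, arch=kernel_cc)
    else:
        getattr(generator, f"GenerateSM{kernel_cc}")(manifest, cuda_version)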

python/cutlass_cppgen/utils/check.py

Lines changed: 13 additions & 6 deletions
@@ -37,7 +37,14 @@
 import ctypes
 
 from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC
-
+from cutlass_library.arch_constants import (
+    INTEL_XE_ARCH_MIN,
+    INTEL_XE_ARCH_MAX,
+    INTEL_XE12,
+    INTEL_XE20,
+    INTEL_XE35,
+    is_intel_xe_arch
+)
 import cutlass_cppgen
 from cutlass_cppgen.backend.library import TileDescription
 
@@ -117,16 +124,16 @@ def valid_stage_count(
         "result in compilation errors if the combination of tile shape, "
         "stage count, and shared memory requirement of the epilogue exceeds "
         "the available shared memory per SM.")
-
-    if kernel_cc == 11:
+    print(f"KernelCC: {kernel_cc}")
+    if is_intel_xe_arch(kernel_cc):
         if (td.stages is None or td.stages == 0):
-            # Support for Intel PVC GPU currently does not allow explicit
+            # Support for Intel Xe GPUs currently does not allow explicit
             # specification of the stage count. With None or 0, the
             # CollectiveBuilder automatically determines the stage count to use.
             return (True, "")
         elif verbose:
-            cutlass.logger.warning(
-                "Setting an explicit stage count for Intel PVC GPU is currently "
+            cutlass_cppgen.logger.warning(
+                "Setting an explicit stage count for Intel Xe GPUs is currently "
                 "not supported.")
 
     if td.stages <= 0:

python/cutlass_library/arch_constants.py

Lines changed: 38 additions & 0 deletions
@@ -45,3 +45,41 @@
 CUDA_ARCH_MIN = 50 # Minimum CUDA architecture (sm_50, sm_60, etc.)
 
 ###################################################################################################
+# Specific Intel Xe architecture constants
+###################################################################################################
+# Intel Xe12 - PVC (Ponte Vecchio) HPC architecture
+INTEL_XE12 = 12
+
+# Intel Xe20 - BMG (Battlemage) gaming architecture
+INTEL_XE20 = 20
+
+# Intel Xe35 - Future architecture placeholder
+INTEL_XE35 = 35
+
+###################################################################################################
+# Architecture validation helpers
+###################################################################################################
+def is_intel_xe_arch(arch):
+    """Check if the given architecture is an Intel Xe architecture."""
+    return INTEL_XE_ARCH_MIN <= arch < INTEL_XE_ARCH_MAX
+
+def is_cuda_arch(arch):
+    """Check if the given architecture is a CUDA architecture."""
+    return arch >= CUDA_ARCH_MIN
+
+def get_arch_name(arch):
+    """Get a human-readable name for the architecture."""
+    if arch == INTEL_XE12:
+        return "Intel Xe12 (PVC)"
+    elif arch == INTEL_XE20:
+        return "Intel Xe20 (BMG)"
+    elif arch == INTEL_XE35:
+        return "Intel Xe35 (CRI)"
+    elif is_intel_xe_arch(arch):
+        return f"Intel Xe{arch}"
+    elif is_cuda_arch(arch):
+        return f"CUDA SM{arch}"
+    else:
+        return f"Unknown({arch})"
+
+###################################################################################################
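
A quick usage sketch of the new helpers. It assumes INTEL_XE_ARCH_MIN and INTEL_XE_ARCH_MAX, defined earlier in this module but not shown in the diff, bound the Xe range so that 12 and 20 fall inside it and 80 does not:

from cutlass_library.arch_constants import is_intel_xe_arch, is_cuda_arch, get_arch_name

for cc in (12, 20, 80):
    print(cc, is_intel_xe_arch(cc), is_cuda_arch(cc), get_arch_name(cc))
# Expected, given the assumed bounds:
# 12 True False Intel Xe12 (PVC)
# 20 True False Intel Xe20 (BMG)
# 80 False True CUDA SM80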

python/cutlass_library/gemm_operation.py

Lines changed: 50 additions & 12 deletions
@@ -48,10 +48,16 @@
     if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
         raise ImportError("Disabling attempt to import cutlass_library")
     from cutlass_library.library import *
-    from cutlass_library.arch_constants import INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN
+    from cutlass_library.arch_constants import (
+        INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN,
+        INTEL_XE12, INTEL_XE20, INTEL_XE35
+    )
 except ImportError:
     from library import *
-    from arch_constants import INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN
+    from arch_constants import (
+        INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN,
+        INTEL_XE12, INTEL_XE20, INTEL_XE35
+    )
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -392,16 +398,48 @@ def _procedural_name(self):
                 l = self.layout_name(),
                 a = str(max(self.A.alignment, self.B.alignment)))
         else:
-            # Intel Xe architectures use xe{cc} naming (e.g., xe20 for BMG, xe12 for PVC)
-            threadblock = self.tile_description.procedural_name()
-            return "cutlass{p}_xe{ar}_{op}_{ex}_{tb}_{l}_align{a}".format(
-                p = self.prefix,
-                ar = self.arch,
-                op = opcode_class_name,
-                ex = self.extended_name(),
-                tb = threadblock,
-                l = self.layout_name(),
-                a = str(max(self.A.alignment, self.B.alignment)))
+            # Intel Xe architectures use xe{cc} naming with similar detail level as NVIDIA
+            # Format: cutlass{p}_xe{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}
+            if self.is_3x:
+                # Use 3x naming convention with full details like NVIDIA SM90+
+                tile_shape = self.get_collective_tile_shape()
+                extended = self.extended_name_3x()
+
+                # Add D type suffix if different from C type to distinguish mixed precision variants
+                if self.D.element != self.C.element:
+                    extended += f"_d{DataTypeNames[self.D.element]}"
+
+                kernel_name_template = "cutlass{p}_xe{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
+                return kernel_name_template.format(
+                    p = self.prefix,
+                    ar = self.arch,
+                    op = opcode_class_name,
+                    ex = extended,
+                    ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
+                    cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
+                    l = self.tile_description.stages,
+                    s = self.layout_name_3x(),
+                    al = str(max(self.A.alignment, self.B.alignment)),
+                    t = TileSchedulerSuffixes[self.tile_scheduler],
+                    k = self.kernel_schedule_name_3x(),
+                    e = self.epilogue_schedule_name_3x())
+            else:
+                # Legacy naming for non-3x Intel Xe operations
+                threadblock = self.tile_description.procedural_name()
+                extended = self.extended_name()
+
+                # Add D type suffix if different from C type to distinguish mixed precision variants
+                if self.D.element != self.C.element:
+                    extended += f"_d{DataTypeNames[self.D.element]}"
+
+                return "cutlass{p}_xe{ar}_{op}_{ex}_{tb}_{l}_align{a}".format(
+                    p = self.prefix,
+                    ar = self.arch,
+                    op = opcode_class_name,
+                    ex = extended,
+                    tb = threadblock,
+                    l = self.layout_name(),
+                    a = str(max(self.A.alignment, self.B.alignment)))
 
     #
     def configuration_name(self):
