4040
4141import cutlass_library
4242from cutlass_library .library import ConvKind , IteratorAlgorithm , StrideSupport , GroupMode
43+ from cutlass_library .arch_constants import (
44+ INTEL_XE_ARCH_MIN ,
45+ INTEL_XE_ARCH_MAX ,
46+ INTEL_XE12 ,
47+ INTEL_XE20 ,
48+ INTEL_XE35 ,
49+ is_intel_xe_arch
50+ )
4351
4452import cutlass_cppgen
4553from cutlass_cppgen .utils .check import valid_stage_count
4654from cutlass_cppgen .utils .datatypes import td_from_profiler_td , td_from_profiler_op
4755
4856
49- # The value '11' is used to encode Intel PVC GPU in the expected format.
50- _generator_ccs = [11 , 50 , 60 , 61 , 70 , 75 , 80 , 90 ]
57+ # Intel Xe and NVIDIA architecture numbers; only the Intel Xe entries are active in the list below.
58+ # Intel Xe: 12 (PVC/Xe-HPC), 20 (BMG/Xe2), 30 (future)
59+ # NVIDIA architectures: 50, 60, 61, 70, 75, 80, 90 (currently commented out of the list below)
60+ _generator_ccs = [INTEL_XE12 , INTEL_XE20 ] #50, 60, 61, 70, 75, 80, 90]
5161
5262class KernelsForDataType :
5363 """
@@ -261,7 +271,12 @@ def __init__(
261271
262272 # Identify the method within CUTLASS generator script that generates kernel
263273 # descriptions for the target CC
264- generate_function_name = "GeneratePVC" if kernel_cc == 11 else "GenerateSM" + str (kernel_cc )
274+ # Intel Xe architectures use GenerateIntelXe, NVIDIA uses GenerateSM{cc}
275+ if is_intel_xe_arch (kernel_cc ):
276+ generate_function_name = "GenerateIntelXe"
277+ else :
278+ generate_function_name = "GenerateSM" + str (kernel_cc )
279+
265280 if not hasattr (cutlass_library .generator , generate_function_name ):
266281 cutlass_cppgen .logger .warning (f"No generator found for architecture { kernel_cc } " )
267282 return
@@ -273,13 +288,20 @@ def __init__(
273288 "--kernels=all" ,
274289 f"--log-level={ logging .getLevelName (cutlass_cppgen .logger .level )} "
275290 ]
276- if self .cc == 11 :
277- args .append ("--architectures=11" )
291+ # For Intel Xe architectures, specify the architecture number
292+ if is_intel_xe_arch (kernel_cc ):
293+ args .append (f"--architectures={ kernel_cc } " )
278294
279295 manifest_args = cutlass_library .generator .define_parser ().parse_args (args )
280296 manifest = cutlass_library .manifest .Manifest (manifest_args )
281- generate_function (manifest , cutlass_cppgen ._nvcc_version )
282-
297+
298+ # For Intel Xe architectures, pass the architecture number to the generator
299+ if is_intel_xe_arch (kernel_cc ):
300+ print (f"Calling { generate_function_name } with arch={ kernel_cc } " )
301+ generate_function (manifest , cutlass_cppgen ._nvcc_version , arch = kernel_cc )
302+ else :
303+ generate_function (manifest , cutlass_cppgen ._nvcc_version )
304+
283305 if operation_kind not in manifest .operations :
284306 # No kernels generated for this architecture, this could be because the CUDA
285307 # toolkit is insufficient to support operations in this CC
@@ -554,8 +576,10 @@ class OptionRegistry:
554576 def __init__ (self , target_cc : int ):
555577 self .registry = {}
556578
557- if target_cc > 90 :
558- raise Exception (f"Unsupported compute capability { target_cc } . The CUTLASS Python interface only supports compute capabilities up to 90." )
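579+ # Intel Xe architectures: 12-20 (PVC, BMG, etc.)
580+ # NVIDIA architectures: 50-90. NOTE(review): the check below raises for every value that is not Intel Xe, so if is_intel_xe_arch is false for NVIDIA values, 50-90 are rejected too - confirm intent.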
581+ if target_cc > 90 or (not is_intel_xe_arch (target_cc )):
582+ raise Exception (f"Unsupported compute capability { target_cc } . Supported: NVIDIA SM 50-90, Intel Xe 12-20." )
559583
560584 gemm_kinds = [cutlass_library .GemmKind .Universal , cutlass_library .GemmKind .Universal3x ]
561585 operation_kinds = [cutlass_library .OperationKind .Gemm , cutlass_library .OperationKind .Conv2d ]
0 commit comments