@@ -318,7 +318,11 @@ def fp8e8m0_to_float32(scale):
 @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0]))
 def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device):
     if is_xpu():
-        pytest.skip("FIXME: Fail RuntimeError on XPU")
+        if (M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim,
+                NUM_WARPS) == (1024, 512, 256, 128, 64, 128, 1, 0,
+                               4) or (M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim,
+                                      NUM_WARPS) == (1024, 512, 256, 128, 64, 128, 3, 0, 4):
+            pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/3677")
     if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
         pytest.skip("Requires compute capability >= 10")
     elif is_hip():
@@ -347,9 +351,17 @@ def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
-    out = mxfp_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
-                            b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
-                            NUM_STAGES=NUM_STAGES, **kernel_kwargs, num_warps=NUM_WARPS)
+
+    try:
+        out = mxfp_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
+                                b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
+                                NUM_STAGES=NUM_STAGES, **kernel_kwargs, num_warps=NUM_WARPS)
+    except triton.runtime.errors.OutOfResources as err:
+        if is_xpu() and err.name == "shared memory":
+            pytest.skip(f"{err}")
+        else:
+            raise err
+
     a_scale_f32 = fp8e8m0_to_float32(a_scale)
     b_scale_f32 = fp8e8m0_to_float32(b_scale)
     a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1)
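For context on the helper named in the first hunk header: fp8e8m0_to_float32 expands the e8m0 block scales to float32 before the test builds its reference result. A minimal sketch of such a conversion, assuming the scales arrive as raw uint8 exponent bytes (the helper's actual body is not part of this diff):

    import torch

    def fp8e8m0_to_float32(scale):
        # e8m0 is an exponent-only format: each byte is a biased power-of-two
        # exponent, so the decoded value is 2 ** (byte - 127).
        return torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127)

The repeat_interleave(32, dim=1) calls that follow then broadcast each per-block scale across the 32 elements it covers.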