Commit 39664ab

[TEST] Enable test_matmul::test_matmul
1 parent: 867da11

File tree

1 file changed: +16 -4 lines changed

python/test/unit/language/test_matmul.py (+16 -4)

@@ -318,7 +318,11 @@ def fp8e8m0_to_float32(scale):
 @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0]))
 def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device):
     if is_xpu():
-        pytest.skip("FIXME: Fail RuntimeError on XPU")
+        if (M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim,
+                NUM_WARPS) == (1024, 512, 256, 128, 64, 128, 1, 0,
+                               4) or (M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim,
+                                      NUM_WARPS) == (1024, 512, 256, 128, 64, 128, 3, 0, 4):
+            pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/3677")
     if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
         pytest.skip("Requires compute capability >= 10")
     elif is_hip():
@@ -347,9 +351,17 @@ def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
-    out = mxfp_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
-                            b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
-                            NUM_STAGES=NUM_STAGES, **kernel_kwargs, num_warps=NUM_WARPS)
+
+    try:
+        out = mxfp_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
+                                b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
+                                NUM_STAGES=NUM_STAGES, **kernel_kwargs, num_warps=NUM_WARPS)
+    except triton.runtime.errors.OutOfResources as err:
+        if is_xpu() and err.name == "shared memory":
+            pytest.skip(f"{err}")
+        else:
+            raise err
+
     a_scale_f32 = fp8e8m0_to_float32(a_scale)
     b_scale_f32 = fp8e8m0_to_float32(b_scale)
     a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1)
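
The except branch in the diff relies on the name attribute of Triton's OutOfResources error to tell shared-memory exhaustion apart from other resource failures. Below is a minimal, standalone sketch of that skip-on-resource-exhaustion pattern, assuming triton and pytest are installed; launch and on_xpu are hypothetical stand-ins for the real kernel launch and backend check, not names from this commit.

# Sketch of the skip-on-OutOfResources pattern used in the diff above.
# Assumptions: triton and pytest are installed; `launch` and `on_xpu` are
# hypothetical placeholders for the actual kernel launch and backend check.
import pytest
import triton


def run_or_skip(launch, on_xpu):
    """Run a kernel launch; on XPU, turn shared-memory exhaustion into a test skip."""
    try:
        return launch()
    except triton.runtime.errors.OutOfResources as err:
        # err.name identifies the exhausted resource (e.g. "shared memory").
        if on_xpu and err.name == "shared memory":
            pytest.skip(f"{err}")
        raise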
