# test_cast_matmul.py
"""
Mixed precision tests for matmul (tl.dot) with cast (tl.to)

issue: https://github.com/triton-lang/triton/issues/2523

TODO: float8 types
"""

import warnings

import pytest
import torch

import triton
import triton.runtime as tr
import triton.language as tl
from triton._internal_testing import is_hip_mi300, is_cuda, is_hip

input_dtypes = ["bfloat16", "float16", "float32", "float64"]
if is_cuda():
    input_dtypes += ["int8", "float8_e5m2"]
    cc = torch.cuda.get_device_capability(0)
    if cc >= (8, 9):
        input_dtypes += ["float8_e4m3fn"]
elif is_hip_mi300():
    input_dtypes += [
        "int8",
        "float8_e5m2",
        # natively supported on mi300 (see CDNA3 ISA, section 7.2)
        "float8_e4m3fnuz",
    ]

out_dtypes = ["float16", "float32"]

@triton.jit
def matmul_kernel(A, B, C, M, N, K,  #
                  stride_am, stride_ak,  #
                  stride_bk, stride_bn,  #
                  stride_cm, stride_cn,  #
                  dot_out_dtype: tl.constexpr,  #
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
                  BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # compiler hints: these offsets are contiguous and aligned, enabling vectorized loads
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_remaining = K - k * BLOCK_K
        # out-of-bounds elements in the K tail are loaded as zero so they do not affect the dot product
        _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
        a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
        b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)

@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype",
                         [(M, K, N, BLOCK_K, BLOCK_M, w, x, o)  #
                          for BLOCK_K in [16, 32]  #
                          for BLOCK_M in [16, 64]  #
                          for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
                          for w in input_dtypes
                          for x in input_dtypes  #
                          for o in out_dtypes])
def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype, device):
    if x_dtype == w_dtype:
        pytest.skip("skip the same input dtype")
    if device == "xpu" and "float64" in (w_dtype,
                                         x_dtype) and not tr.driver.active.get_current_target().arch['has_fp64']:
        pytest.xfail("float64 not supported on current xpu hardware")
    if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
        pytest.skip("skip due to bug on HIP path")
    x_dtype: torch.dtype = getattr(torch, x_dtype)
    w_dtype: torch.dtype = getattr(torch, w_dtype)

    def init_tensor(dtype, shape):
        if dtype == torch.int8:
            return torch.randint(0, 2, shape, device=device, dtype=dtype)
        elif dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2):
            return torch.randn(shape, device=device, dtype=torch.float16).to(dtype)
        else:
            return torch.randn(shape, device=device, dtype=dtype)

    torch.manual_seed(42)
    a = init_tensor(w_dtype, (M, K))
    b = init_tensor(x_dtype, (K, N))

    torch_dtype = getattr(torch, out_dtype)
    triton_dtype = getattr(tl, out_dtype)  # <- force dot_out_dtype here
    out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
    out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

    # launch kernel
    block_m, block_n, block_k = BLOCK_M, 16, BLOCK_K
    grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)
    matmul_kernel[grid](
        a, b, out_triton, M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype,  #
        GROUP_M=8,  #
        BLOCK_M=block_m,  #
        BLOCK_N=block_n,  #
        BLOCK_K=block_k)
    # FIXME: for XPU tests, torch may compute the reference result on the CPU using fp32
    # arithmetic for the fp16 test; such a reference requires an increased tolerance for
    # large K values.
    if device == "xpu" and out_dtype == "float16" and K > 128:
        warnings.warn("FIXME: test case modified, increased tolerance")
        torch.testing.assert_close(out_torch, out_triton, atol=2, rtol=0.1)
    else:
        torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)
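
# A minimal usage sketch (not part of the parametrized suite): run a single
# configuration by hand. The hard-coded device string "cuda" and the chosen
# shapes/dtypes below are assumptions; under pytest, `device` is supplied by a fixture.
if __name__ == "__main__":
    test_cast_matmul(M=128, K=128, N=128, BLOCK_K=32, BLOCK_M=16,
                     w_dtype="float16", x_dtype="float32", out_dtype="float32",
                     device="cuda")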