
Commit e9c7aa7

[Runtime] solve the problem that llama frequently calls autotuner (#332)
* [operator] turn autotuner to heuristics
* [operator] heuristics for gather & index_select
* [runtime] libtuner for matmul
* [runtime] store config data in one db
* [bugfix] parse key as list instead of tuple
* [no ci] update var name
* [Muti_backend] muti_backend-part_1-framework-and-tune_config (#294)
* new feature, muti_backend
* update auto_tune_module
* update auto_tune_module
* update auto_tune_module
* update __init__
* rebase
* fix bug
* modifiy auto_tune_config
* fix bug
* fix bug
* update
* update
* update scatter&gather
* fix auto_tune
* add gen_torch_device_fn
* fix codestyle
* fix codestyle
* Modify code based on comments
* Modify gen_impl with loops instead of recursion
* Update code structure
* Polish code
* update
* Polish code
* Modify code based on comments
* modify based on comment
* Modify code based on comments
* update
* final fix
* [bugfix] update libtuner to be compatible with triton2
* [no ci] reformat
* [operator] update log_softmax
* [pretune] move pretune to ./examples for models
* [format] delete useless print
* [format] delete unused import
* [format] [no ci] remove useless print

---------

Co-authored-by: Galaxy1458 <[email protected]>
1 parent e437b36 commit e9c7aa7

14 files changed (+316 lines, -33 lines)


examples/pretune.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import argparse
+
+import torch
+
+import flag_gems
+
+device = flag_gems.device
+
+DTYPES = [
+    torch.float16,
+    torch.bfloat16,
+    torch.float32,
+]
+
+LLAMA_SHAPES = {
+    "mm": [
+        [1024, 4096],
+        [128256, 4096],
+        [14336, 4096],
+        [4096, 14336],
+        [4096, 4096],
+        [6144, 4096],
+        [28672, 4096],
+    ],
+}
+
+QWEN_SHAPES = {
+    "mm": [
+        [3584, 3584],
+        [18944, 3584],
+        [3584, 18944],
+        [152064, 3584],
+        [37888, 3584],
+    ],
+    "addmm": [
+        [3584, 3584],
+        [512, 3584],
+        [4608, 3584],
+    ],
+}
+
+
+MODEL_SHAPES = {
+    "llama": LLAMA_SHAPES,
+    "qwen": QWEN_SHAPES,
+}
+
+
+def pretune_mm(max_tokens, shapes):
+    for dtype in DTYPES:
+        for M in range(1, max_tokens + 1):
+            for N, K in shapes:
+                tensor_a = torch.randn([M, K], dtype=dtype, device=device)
+                tensor_b = torch.randn([K, N], dtype=dtype, device=device)
+                flag_gems.mm(tensor_a, tensor_b)
+
+
+def pretune_addmm(max_tokens, shapes):
+    for dtype in DTYPES:
+        for M in range(1, max_tokens + 1):
+            for N, K in shapes:
+                tensor_a = torch.randn([M, K], dtype=dtype, device=device)
+                tensor_b = torch.randn([K, N], dtype=dtype, device=device)
+                bias = torch.randn([M, N], dtype=dtype, device=device)
+                flag_gems.addmm(bias, tensor_a, tensor_b)
+
+
+OPERATORS = {
+    "mm": pretune_mm,
+    "addmm": pretune_addmm,
+}
+
+
+def args_parser():
+    parser = argparse.ArgumentParser(
+        description="pretune for gemm",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=False,
+        default="llama",
+        help="model name",
+    )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        required=False,
+        default=100,
+        help="max tokens",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = args_parser()
+    model = MODEL_SHAPES.get(args.model)
+    max_tokens = args.max_tokens
+    if not model:
+        exit(0)
+    for op, func in OPERATORS.items():
+        shapes = model.get(op)
+        if not shapes:
+            continue
+        func(max_tokens, shapes)
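Running this script ahead of serving sweeps every M from 1 to --max_tokens over the model's GEMM shapes, so the libtuner-decorated matmul paths for those sizes are exercised (and, per the commit summary, their configs stored) before real requests arrive. A typical invocation with the flags defined in args_parser above; an unrecognized --model simply exits:

python examples/pretune.py --model llama --max_tokens 100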

src/flag_gems/ops/argmax.py

Lines changed: 5 additions & 8 deletions
@@ -5,7 +5,6 @@
 import triton
 import triton.language as tl
 
-from .. import runtime
 from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
@@ -46,20 +45,18 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
     tl.store(out, out_val)
 
 
+def heur_block_m(args):
+    return 4 if args["M"] < 4096 else 8
+
+
 def heur_block_n(args):
     return min(4096, triton.next_power_of_2(args["N"]))
 
 
 @libentry()
-@triton.autotune(
-    configs=runtime.get_triton_config("argmax"),
-    key=[
-        "M",
-        "N",
-    ],
-)
 @triton.heuristics(
     {
+        "BLOCK_M": heur_block_m,
+        "BLOCK_N": heur_block_n,
     }
 )
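The switch from @triton.autotune to @triton.heuristics is what removes the repeated tuning: heuristics are plain Python functions evaluated from the launch arguments on every call, with no candidate-config benchmarking, so a workload that keeps presenting new M values (e.g. a growing token count) no longer stalls in the tuner. A small, illustrative check of how the two functions above resolve, assuming only triton.next_power_of_2:

import triton

def heur_block_m(args):
    return 4 if args["M"] < 4096 else 8

def heur_block_n(args):
    return min(4096, triton.next_power_of_2(args["N"]))

# Block sizes are decided instantly per launch, even as M changes every call.
for M, N in [(1, 4096), (512, 4096), (8192, 128256)]:
    meta = {"M": M, "N": N}
    print(M, N, "->", heur_block_m(meta), heur_block_n(meta))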

src/flag_gems/ops/gather.py

Lines changed: 20 additions & 6 deletions
@@ -5,7 +5,7 @@
 
 import torch
 
-from flag_gems.utils.code_cache import cache_dir
+from flag_gems.utils.code_cache import code_cache_dir
 from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
 from flag_gems.utils.shape_utils import restride_dim
 
@@ -32,16 +32,30 @@ def generate_gather_kernel(
     # make the inlined function visible in the context
     code.newline()
 
-    # the autotune function
+    code.writeline("def heur_block_m(args):")
+    with code.indent():
+        code.writeline(
+            "return min(4, triton.next_power_of_2(triton.cdiv(args['N'], 2048)))"
+        )
+
+    code.newline()
+    code.writeline("def heur_block_n(args):")
+    with code.indent():
+        code.writeline("return min(2048, triton.next_power_of_2(args['N']))")
 
     code.newline()
     code.newline()
 
     # the decorators
     code.writeline("@libentry()")
-    code.writeline(
-        '@triton.autotune(configs=runtime.get_triton_config("gather"), key=["M", "N"])'
-    )
+    code.writeline("@triton.heuristics(")
+    with code.indent():
+        code.writeline("{")
+        with code.indent():
+            code.writeline('"BLOCK_M": heur_block_m,')
+            code.writeline('"BLOCK_N": heur_block_n,')
+        code.writeline("}")
+    code.writeline(")")
     code.writeline("@triton.jit")
 
     # signature
@@ -217,7 +231,7 @@ def __call__(self, *args, **kwargs):
 
         file_name = f"gather_rank_{key}_pid_{self.pid}.py"
 
-        with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
+        with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())
 
        # load
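For reference, the first block of writelines above materializes these two helpers near the top of the generated gather_rank_*.py file (reconstructed from the generation calls; the @libentry() / @triton.heuristics({...}) / @triton.jit decorator stack emitted by the second block follows them):

def heur_block_m(args):
    return min(4, triton.next_power_of_2(triton.cdiv(args['N'], 2048)))

def heur_block_n(args):
    return min(2048, triton.next_power_of_2(args['N']))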

src/flag_gems/ops/index_select.py

Lines changed: 15 additions & 2 deletions
@@ -4,13 +4,26 @@
 import triton
 import triton.language as tl
 
-from .. import runtime
 from ..utils import dim_compress, libentry
 from ..utils import triton_lang_extension as tle
 
 
+def heur_block_m(args):
+    return min(4, triton.next_power_of_2(triton.cdiv(256, args["N"])))
+
+
+def heur_block_n(args):
+    m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512)
+    return max(m, 16)
+
+
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("index_select"), key=["M", "N"])
+@triton.heuristics(
+    {
+        "BLOCK_M": heur_block_m,
+        "BLOCK_N": heur_block_n,
+    }
+)
 @triton.jit
 def index_select_kernel(
     inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
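As a concrete illustration of what these heuristics pick: for N = 1000, BLOCK_M resolves to 1 and BLOCK_N to 64, since cdiv(256, 1000) = 1 while next_power_of_2(cdiv(1000, 16)) = 64. The arithmetic can be checked directly, using only triton.cdiv and triton.next_power_of_2:

import triton

args = {"N": 1000}
block_m = min(4, triton.next_power_of_2(triton.cdiv(256, args["N"])))             # 1
block_n = max(min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512), 16)   # 64
print(block_m, block_n)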

src/flag_gems/ops/log_softmax.py

Lines changed: 3 additions & 3 deletions
@@ -11,16 +11,15 @@
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("log_softmax"), key=["M", "N"])
 @triton.jit
 def log_softmax_kernel(
     output_ptr,
     input_ptr,
     M,
     N,
     K,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
+    BLOCK_M: tl.constexpr = 8,
+    BLOCK_N: tl.constexpr = 256,
 ):
     pid_m = tle.program_id(0)
     pid_k = tle.program_id(1)
@@ -122,6 +121,7 @@ def forward(ctx, x, dim, dtype):
             M,
             N,
             K,
+            num_warps=8,
         )
         ctx.save_for_backward(out)
         ctx.dim = dim

src/flag_gems/ops/mm.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 from .. import runtime
 from ..runtime import torch_device_fn
-from ..utils import libentry
+from ..utils import libentry, libtuner
 from ..utils import triton_lang_extension as tle
 
 
@@ -15,7 +15,7 @@ def heur_even_k(args):
 
 
 @libentry()
-@triton.autotune(
+@libtuner(
     configs=runtime.get_triton_config("mm"),
     key=["M", "N", "K"],
 )
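libtuner keeps triton.autotune's configs/key interface, but, per the commit summary, it persists tuning results in a single database so a given (M, N, K) combination is only benchmarked once across runs. The sketch below illustrates that caching idea only; it is not the flag_gems implementation, and the JSON file, config dicts, and timing loop are stand-ins for Triton's real benchmarking machinery:

import functools
import json
import pathlib
import time


def persistent_tuner(configs, key, db=pathlib.Path("tuning_cache.json")):
    """Illustrative only: time each config once per key tuple, persist the
    winner, and replay it on every later call instead of re-tuning."""
    cache = json.loads(db.read_text()) if db.exists() else {}

    def decorator(fn):
        @functools.wraps(fn)
        def launch(*args, **kwargs):
            k = ",".join(str(kwargs[name]) for name in key)  # e.g. "64,4096,4096"
            if k not in cache:
                timings = []
                for cfg in configs:  # stand-in for a real benchmarking pass
                    t0 = time.perf_counter()
                    fn(*args, **kwargs, **cfg)
                    timings.append((time.perf_counter() - t0, cfg))
                cache[k] = min(timings, key=lambda t: t[0])[1]
                db.write_text(json.dumps(cache))
            return fn(*args, **kwargs, **cache[k])

        return launch

    return decorator


@persistent_tuner(configs=[{"BLOCK": 64}, {"BLOCK": 128}], key=["M", "N", "K"])
def fake_mm(M, N, K, BLOCK):
    return BLOCK  # placeholder for a real kernel launch


fake_mm(M=64, N=4096, K=4096)  # first call times both configs and stores the pick
fake_mm(M=64, N=4096, K=4096)  # later calls reuse the stored config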

src/flag_gems/ops/pad.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 import torch
 
-from flag_gems.utils.code_cache import cache_dir
+from flag_gems.utils.code_cache import code_cache_dir
 from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
 
 
@@ -424,7 +424,7 @@ def __call__(self, *args, **kwargs):
 
         file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"
 
-        with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
+        with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())
 
        # load

src/flag_gems/ops/repeat.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 import torch
 
-from flag_gems.utils.code_cache import cache_dir
+from flag_gems.utils.code_cache import code_cache_dir
 from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
 
 
@@ -437,7 +437,7 @@ def __call__(self, x, sizes):
 
         file_name = f"repeat_rank_{key}_pid_{self.pid}.py"
 
-        with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
+        with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())
 
        # load

src/flag_gems/ops/scatter.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 import torch
 
-from flag_gems.utils.code_cache import cache_dir
+from flag_gems.utils.code_cache import code_cache_dir
 from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
 
 
@@ -248,7 +248,7 @@ def __call__(self, *args, **kwargs):
 
         file_name = f"scatter_rank_{key}_pid_{self.pid}.py"
 
-        with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
+        with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())
 
        # load

src/flag_gems/ops/tile.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 import torch
 
-from flag_gems.utils.code_cache import cache_dir
+from flag_gems.utils.code_cache import code_cache_dir
 from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
 
 
@@ -437,7 +437,7 @@ def __call__(self, x, dims):
 
         file_name = f"tile_rank_{key}_pid_{self.pid}.py"
 
-        with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
+        with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())
 
        # load
