[Runtime] solve the problem that llama frequently calls autotuner #332

Merged 14 commits on Dec 19, 2024

107 changes: 107 additions & 0 deletions examples/pretune.py
@@ -0,0 +1,107 @@
import argparse

import torch

import flag_gems

device = flag_gems.device

DTYPES = [
torch.float16,
torch.bfloat16,
torch.float32,
]

LLAMA_SHAPES = {
"mm": [
[1024, 4096],
[128256, 4096],
[14336, 4096],
[4096, 14336],
[4096, 4096],
[6144, 4096],
[28672, 4096],
],
}

QWEN_SHAPES = {
"mm": [
[3584, 3584],
[18944, 3584],
[3584, 18944],
[152064, 3584],
[37888, 3584],
],
"addmm": [
[3584, 3584],
[512, 3584],
[4608, 3584],
],
}


MODEL_SHAPES = {
"llama": LLAMA_SHAPES,
"qwen": QWEN_SHAPES,
}


def pretune_mm(max_tokens, shapes):
for dtype in DTYPES:
for M in range(1, max_tokens + 1):
for N, K in shapes:
tensor_a = torch.randn([M, K], dtype=dtype, device=device)
tensor_b = torch.randn([K, N], dtype=dtype, device=device)
flag_gems.mm(tensor_a, tensor_b)


def pretune_addmm(max_tokens, shapes):
for dtype in DTYPES:
for M in range(1, max_tokens + 1):
for N, K in shapes:
tensor_a = torch.randn([M, K], dtype=dtype, device=device)
tensor_b = torch.randn([K, N], dtype=dtype, device=device)
bias = torch.randn([M, N], dtype=dtype, device=device)
flag_gems.addmm(bias, tensor_a, tensor_b)


OPERATORS = {
"mm": pretune_mm,
"addmm": pretune_addmm,
}


def args_parser():
parser = argparse.ArgumentParser(
description="pretune for gemm",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--model",
type=str,
required=False,
default="llama",
help="model name",
)
parser.add_argument(
"--max_tokens",
type=int,
required=False,
default=100,
help="max tokens",
)
args = parser.parse_args()
return args


if __name__ == "__main__":
args = args_parser()
model = MODEL_SHAPES.get(args.model)
max_tokens = args.max_tokens
if not model:
exit(0)
for op, func in OPERATORS.items():
shapes = model.get(op)
if not shapes:
continue
func(max_tokens, shapes)
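
With the defaults above, running "python examples/pretune.py --model llama --max_tokens 100" calls flag_gems.mm for every decode-time M from 1 to 100, for each llama weight shape and dtype, so the tuning results are already in place before a serving process ever sees those shapes. A minimal sketch of the same warm-up done inline; the single shape, dtype, and token limit here are illustrative rather than taken from the script:

import torch

import flag_gems

# Warm up one assumed llama projection shape (N = 4096, K = 4096) in fp16:
# touching every expected decode-time M once means no later call has to tune.
K, N = 4096, 4096
for m in range(1, 65):  # assumed max_tokens of 64
    a = torch.randn([m, K], dtype=torch.float16, device=flag_gems.device)
    b = torch.randn([K, N], dtype=torch.float16, device=flag_gems.device)
    flag_gems.mm(a, b)
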
13 changes: 5 additions & 8 deletions src/flag_gems/ops/argmax.py
@@ -5,7 +5,6 @@
import triton
import triton.language as tl

from .. import runtime
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle
@@ -46,20 +45,18 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
tl.store(out, out_val)


def heur_block_m(args):
return 4 if args["M"] < 4096 else 8


def heur_block_n(args):
return min(4096, triton.next_power_of_2(args["N"]))


@libentry()
@triton.autotune(
configs=runtime.get_triton_config("argmax"),
key=[
"M",
"N",
],
)
@triton.heuristics(
{
"BLOCK_M": heur_block_m,
"BLOCK_N": heur_block_n,
}
)
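
The decorator change above is the heart of the fix for argmax: @triton.heuristics derives BLOCK_M and BLOCK_N from the kernel arguments at launch time, so an unseen (M, N) no longer triggers the benchmarking sweep that @triton.autotune with key=["M", "N"] used to run. A self-contained sketch of the pattern on a toy row-max kernel; the kernel itself is illustrative and not from this repository:

import torch
import triton
import triton.language as tl


def heur_block_n(args):
    # smallest power of two covering the row, capped at 4096
    return min(4096, triton.next_power_of_2(args["N"]))


@triton.heuristics({"BLOCK_N": heur_block_n})
@triton.jit
def row_max_kernel(x_ptr, out_ptr, N, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_N)
    x = tl.load(x_ptr + row * N + offs, mask=offs < N, other=float("-inf"))
    tl.store(out_ptr + row, tl.max(x, axis=0))


x = torch.randn(8, 1000, device="cuda")  # assumes a CUDA device
out = torch.empty(8, device="cuda")
row_max_kernel[(8,)](x, out, 1000)  # BLOCK_N is computed, never benchmarked
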
26 changes: 20 additions & 6 deletions src/flag_gems/ops/gather.py
@@ -5,7 +5,7 @@

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
from flag_gems.utils.shape_utils import restride_dim

@@ -32,16 +32,30 @@ def generate_gather_kernel(
# make the inlined function visible in the context
code.newline()

# the autotune function
code.writeline("def heur_block_m(args):")
with code.indent():
code.writeline(
"return min(4, triton.next_power_of_2(triton.cdiv(args['N'], 2048)))"
)

code.newline()
code.writeline("def heur_block_n(args):")
with code.indent():
code.writeline("return min(2048, triton.next_power_of_2(args['N']))")

code.newline()
code.newline()

# the decorators
code.writeline("@libentry()")
code.writeline(
'@triton.autotune(configs=runtime.get_triton_config("gather"), key=["M", "N"])'
)
code.writeline("@triton.heuristics(")
with code.indent():
code.writeline("{")
with code.indent():
code.writeline('"BLOCK_M": heur_block_m,')
code.writeline('"BLOCK_N": heur_block_n,')
code.writeline("}")
code.writeline(")")
code.writeline("@triton.jit")

# signature
@@ -217,7 +231,7 @@ def __call__(self, *args, **kwargs):

file_name = f"gather_rank_{key}_pid_{self.pid}.py"

with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
f.write(code.getvalue())

# load
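
Because the gather kernel is generated source, the heuristics are written into each cached file instead of being imported; with the writelines above, the decorator block of a generated kernel comes out roughly as follows (the @triton.jit signature and body that follow are omitted here):

def heur_block_m(args):
    return min(4, triton.next_power_of_2(triton.cdiv(args['N'], 2048)))


def heur_block_n(args):
    return min(2048, triton.next_power_of_2(args['N']))


@libentry()
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
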
17 changes: 15 additions & 2 deletions src/flag_gems/ops/index_select.py
@@ -4,13 +4,26 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle


def heur_block_m(args):
return min(4, triton.next_power_of_2(triton.cdiv(256, args["N"])))


def heur_block_n(args):
m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512)
return max(m, 16)


@libentry()
@triton.autotune(configs=runtime.get_triton_config("index_select"), key=["M", "N"])
@triton.heuristics(
{
"BLOCK_M": heur_block_m,
"BLOCK_N": heur_block_n,
}
)
@triton.jit
def index_select_kernel(
inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
6 changes: 3 additions & 3 deletions src/flag_gems/ops/log_softmax.py
@@ -11,16 +11,15 @@


@libentry()
@triton.autotune(configs=runtime.get_triton_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_kernel(
output_ptr,
input_ptr,
M,
N,
K,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_M: tl.constexpr = 8,
BLOCK_N: tl.constexpr = 256,
):
pid_m = tle.program_id(0)
pid_k = tle.program_id(1)
@@ -122,6 +121,7 @@ def forward(ctx, x, dim, dtype):
M,
N,
K,
num_warps=8,
)
ctx.save_for_backward(out)
ctx.dim = dim
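
log_softmax drops tuning altogether: BLOCK_M and BLOCK_N become constexpr parameters with fixed defaults (8 and 256), and the caller pins num_warps=8, so each shape signature compiles at most once. A minimal sketch of the same idiom, a defaulted constexpr plus an explicit num_warps at launch, on a toy kernel that is illustrative and not from this repository:

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, y_ptr, N, BLOCK_N: tl.constexpr = 256):
    pid = tl.program_id(0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)


x = torch.randn(1000, device="cuda")  # assumes a CUDA device
y = torch.empty_like(x)
scale_kernel[(triton.cdiv(1000, 256),)](x, y, 1000, num_warps=8)  # default BLOCK_N used as-is
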
4 changes: 2 additions & 2 deletions src/flag_gems/ops/mm.py
@@ -6,7 +6,7 @@

from .. import runtime
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import libentry, libtuner
from ..utils import triton_lang_extension as tle


@@ -15,7 +15,7 @@ def heur_even_k(args):


@libentry()
@triton.autotune(
@libtuner(
configs=runtime.get_triton_config("mm"),
key=["M", "N", "K"],
)
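
mm keeps a genuine tuning step, but through libtuner rather than triton.autotune. Its implementation lives in src/flag_gems/utils/libentry.py and is not shown in this section; judging from the name and the new config_cache_dir helper below, the intent is an autotuner whose winning configs persist across processes. A hypothetical, much-simplified sketch of that idea, where persistent_tuner, the timing loop, and the JSON layout are all invented for illustration and are not the flag_gems API:

import json
import time
from pathlib import Path


def persistent_tuner(configs, key, cache_file: Path):
    """Benchmark once per key tuple, then reuse the stored winner forever."""
    cache = json.loads(cache_file.read_text()) if cache_file.exists() else {}

    def decorator(fn):
        def wrapper(*args, **kwargs):
            # key values must be passed as keyword arguments in this sketch
            k = ",".join(str(kwargs[name]) for name in key)
            if k not in cache:
                best, best_t = None, float("inf")
                for i, cfg in enumerate(configs):
                    t0 = time.perf_counter()
                    fn(*args, **kwargs, **cfg)  # one timed trial per candidate
                    elapsed = time.perf_counter() - t0
                    if elapsed < best_t:
                        best, best_t = i, elapsed
                cache[k] = best
                cache_file.write_text(json.dumps(cache))  # survives the process
            return fn(*args, **kwargs, **configs[cache[k]])

        return wrapper

    return decorator

A real tuner would also warm up and repeat each trial and key the cache on the kernel itself; the point is only that the expensive search runs once per shape key per machine instead of once per process.
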
4 changes: 2 additions & 2 deletions src/flag_gems/ops/pad.py
@@ -5,7 +5,7 @@

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace


@@ -424,7 +424,7 @@ def __call__(self, *args, **kwargs):

file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"

with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
f.write(code.getvalue())

# load
4 changes: 2 additions & 2 deletions src/flag_gems/ops/repeat.py
@@ -5,7 +5,7 @@

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace


@@ -437,7 +437,7 @@ def __call__(self, x, sizes):

file_name = f"repeat_rank_{key}_pid_{self.pid}.py"

with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
f.write(code.getvalue())

# load
4 changes: 2 additions & 2 deletions src/flag_gems/ops/scatter.py
@@ -5,7 +5,7 @@

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace


@@ -248,7 +248,7 @@ def __call__(self, *args, **kwargs):

file_name = f"scatter_rank_{key}_pid_{self.pid}.py"

with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
f.write(code.getvalue())

# load
4 changes: 2 additions & 2 deletions src/flag_gems/ops/tile.py
@@ -5,7 +5,7 @@

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace


@@ -437,7 +437,7 @@ def __call__(self, x, dims):

file_name = f"tile_rank_{key}_pid_{self.pid}.py"

with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
f.write(code.getvalue())

# load
3 changes: 2 additions & 1 deletion src/flag_gems/utils/__init__.py
@@ -1,4 +1,4 @@
from .libentry import libentry
from .libentry import libentry, libtuner
from .pointwise_dynamic import pointwise_dynamic
from .shape_utils import (
broadcastable,
Expand All @@ -10,6 +10,7 @@

__all__ = [
"libentry",
"libtuner",
"pointwise_dynamic",
"dim_compress",
"restride_dim",
12 changes: 12 additions & 0 deletions src/flag_gems/utils/code_cache.py
@@ -22,6 +22,18 @@ def cache_dir() -> Path:
return _cache_dir


def code_cache_dir() -> Path:
_code_cache_dir = cache_dir() / "code_cache"
os.makedirs(_code_cache_dir, exist_ok=True)
return _code_cache_dir


def config_cache_dir() -> Path:
_config_cache_dir = cache_dir() / "config_cache"
(review comment) Collaborator: Suggestion for name: "tunning_cache"

(review comment) Author: what's the difference?

os.makedirs(_config_cache_dir, exist_ok=True)
return _config_cache_dir


def clear_cache():
"""Clear the cache directory for code cache."""
_cache_dir = cache_dir_path()
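
Taken together, the flag_gems cache now has two purpose-specific subdirectories: generated kernel sources go under code_cache and, presumably, persisted tuning results under config_cache. A short usage sketch of the two helpers added above:

from flag_gems.utils.code_cache import code_cache_dir, config_cache_dir

print(code_cache_dir())    # <flag_gems cache dir>/code_cache, created on first use
print(config_cache_dir())  # <flag_gems cache dir>/config_cache, created on first use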