-
Notifications
You must be signed in to change notification settings - Fork 107
[Runtime] solve the problem that llama frequently calls autotuner #332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
06bfd2a
[operator] turn autotuner to heuristics
StrongSpoon 9713eb8
[operator] heuristics for gather & index_select
StrongSpoon 0bfaaba
[runtime] libtuner for matmul
StrongSpoon e3846ff
[runtime] store config data in one db
StrongSpoon ac07373
[bugfix] parse key as list instead of tuple
StrongSpoon 916a572
[no ci] update var name
StrongSpoon c584824
[Muti_backend] muti_backend-part_1-framework-and-tune_config (#294)
Galaxy1458 c79c54b
[bugfix] update libtuner to be compatible with triton2
StrongSpoon 40a5df3
[no ci]reformat
StrongSpoon e8eb82e
[operator] update log_softmax
StrongSpoon 85764a4
[pretune] move pretune to ./examples for models
StrongSpoon 3f59476
[format] delete useless print
StrongSpoon e1bb20f
[format] delete unused import
StrongSpoon 11a8479
[format] [no ci] remove useless print
StrongSpoon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import argparse | ||
|
||
import torch | ||
|
||
import flag_gems | ||
|
||
device = flag_gems.device | ||
|
||
DTYPES = [ | ||
torch.float16, | ||
torch.bfloat16, | ||
torch.float32, | ||
] | ||
|
||
LLAMA_SHAPES = { | ||
"mm": [ | ||
[1024, 4096], | ||
[128256, 4096], | ||
[14336, 4096], | ||
[4096, 14336], | ||
[4096, 4096], | ||
[6144, 4096], | ||
[28672, 4096], | ||
], | ||
} | ||
|
||
QWEN_SHAPES = { | ||
"mm": [ | ||
[3584, 3584], | ||
[18944, 3584], | ||
[3584, 18944], | ||
[152064, 3584], | ||
[37888, 3584], | ||
], | ||
"addmm": [ | ||
[3584, 3584], | ||
[512, 3584], | ||
[4608, 3584], | ||
], | ||
} | ||
|
||
|
||
MODEL_SHAPES = { | ||
"llama": LLAMA_SHAPES, | ||
"qwen": QWEN_SHAPES, | ||
} | ||
|
||
|
||
def pretune_mm(max_tokens, shapes): | ||
for dtype in DTYPES: | ||
for M in range(1, max_tokens + 1): | ||
for N, K in shapes: | ||
tensor_a = torch.randn([M, K], dtype=dtype, device=device) | ||
tensor_b = torch.randn([K, N], dtype=dtype, device=device) | ||
flag_gems.mm(tensor_a, tensor_b) | ||
|
||
|
||
def pretune_addmm(max_tokens, shapes): | ||
for dtype in DTYPES: | ||
for M in range(1, max_tokens + 1): | ||
for N, K in shapes: | ||
tensor_a = torch.randn([M, K], dtype=dtype, device=device) | ||
tensor_b = torch.randn([K, N], dtype=dtype, device=device) | ||
bias = torch.randn([M, N], dtype=dtype, device=device) | ||
flag_gems.addmm(bias, tensor_a, tensor_b) | ||
|
||
|
||
OPERATORS = { | ||
"mm": pretune_mm, | ||
"addmm": pretune_addmm, | ||
} | ||
|
||
|
||
def args_parser(): | ||
parser = argparse.ArgumentParser( | ||
description="pretune for gemm", | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||
) | ||
parser.add_argument( | ||
"--model", | ||
type=str, | ||
required=False, | ||
default="llama", | ||
help="model name", | ||
) | ||
parser.add_argument( | ||
"--max_tokens", | ||
type=int, | ||
required=False, | ||
default=100, | ||
help="max tokens", | ||
) | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
args = args_parser() | ||
model = MODEL_SHAPES.get(args.model) | ||
max_tokens = args.max_tokens | ||
if not model: | ||
exit(0) | ||
for op, func in OPERATORS.items(): | ||
shapes = model.get(op) | ||
if not shapes: | ||
continue | ||
func(max_tokens, shapes) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion for name: "tunning_cache"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference?