Add more flex attention cases to benchmark. #3928
Conversation
@liangan1 Please help comment on the expected configuration for the flex attention benchmark.
Note the benchmark failure: "TypeError: benchmark() got an unexpected keyword argument 'B'".
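That TypeError typically means the names in `x_names` of a `triton.testing.Benchmark` config don't match the parameters of the decorated benchmark function. A minimal sketch of the mismatch, assuming the benchmark is driven by `triton.testing.perf_report` (the config values here are illustrative, not the PR's actual ones):

```python
import triton.testing

configs = [
    triton.testing.Benchmark(
        x_names=['B', 'H', 'N_CTX', 'D_HEAD'],  # declares 'B'...
        x_vals=[(4, 32, 1024, 64)],             # illustrative shape
        line_arg='provider',
        line_vals=['triton'],
        line_names=['Triton'],
        ylabel='GB/s',
        plot_name='flex-attention',
        args={},
    )
]

@triton.testing.perf_report(configs)
def benchmark(Z, H, N_CTX, D_HEAD, provider):  # ...but accepts 'Z',
    ...  # so perf_report's call benchmark(B=4, ...) raises the TypeError
```

The fix is to make the function signature and `x_names` agree on the same parameter names.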
Force-pushed cc449b5 to 7dca5bc.
Force-pushed 76d3c3b to 01b9091.
```python
q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v
gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16: 2 bytes
```
Only the GEMM computation and its inputs are considered when calculating the TFLOPS and GB/s.
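As a worked sketch of that accounting, assuming the two attention GEMMs (Q·Kᵀ and P·V) are the only FLOPs counted and only the Q/K/V inputs are counted as traffic (variable names follow the snippet above; `mean_ms` is the mean runtime in milliseconds; the shape values are hypothetical):

```python
# Hypothetical shapes for illustration; the real ones come from the benchmark config.
Z, H_q, H_kv = 4, 32, 8            # batch, query heads, key/value heads
N_CTX_q, N_CTX_kv = 1024, 1024     # query / key-value sequence lengths
D_HEAD_qk, D_HEAD_v = 64, 64       # head dims for Q/K and for V

q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v

# Memory traffic: Q, K, V inputs only; float16 is 2 bytes per element.
gbps = lambda mean_ms: Z * (q_elems + k_elems + v_elems) * 2 * 1e-9 / (mean_ms * 1e-3)

# Compute: 2*M*N*K FLOPs per GEMM, summed over Q@K^T and P@V.
flops = 2 * Z * H_q * N_CTX_q * N_CTX_kv * (D_HEAD_qk + D_HEAD_v)
tflops = lambda mean_ms: flops * 1e-12 / (mean_ms * 1e-3)

print(gbps(1.0), tflops(1.0))  # e.g., at a 1 ms mean runtime
```

Softmax and masking FLOPs, and the output write-back, are deliberately excluded under this convention.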
Force-pushed b5216d6 to 5abd839.
There is an accuracy issue caused by a regression in flex decoding. We need to solve the flex decoding regression first.
@chengjunlu the test failed due to a diff in the results. Did you forget to push a local change?
Force-pushed bf4c520 to 9d64d49.
Similar to the other problematic shapes, how about we comment out the decode shape in this PR and fix it in another PR? (93882ff)
Force-pushed 9d64d49 to 93882ff.
Sounds good to me. Let's add the benchmark first and use a separate issue to track the decoding regression.
Force-pushed 55afaa4 to f582a67.
The accuracy issue has been fixed by #3999. I will rebase this PR after #3999 is merged.
Force-pushed f582a67 to 04d6f4f.
Signed-off-by: Lu,Chengjun <[email protected]>
Add the flex attention shapes used by real models to the benchmark for tracking performance.
I commented out 4 cases for now. We will investigate the first issue on the Triton side later; see the sketch below for the kind of case the benchmark covers.
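For context, a minimal sketch of the kind of case this benchmark exercises, using PyTorch's `torch.nn.attention.flex_attention` API with a causal mask (the shapes below are illustrative, not the ones added by this PR):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Illustrative shapes only; the PR's real-model shapes live in
# benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py.
Z, H, N_CTX, D_HEAD = 4, 32, 1024, 64

def causal_mask(b, h, q_idx, kv_idx):
    # Each query position attends only to key positions at or before it.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=Z, H=H, Q_LEN=N_CTX, KV_LEN=N_CTX)

# Use device='xpu' on Intel GPUs; 'cuda' shown here for generality.
q = torch.randn(Z, H, N_CTX, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(Z, H, N_CTX, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(Z, H, N_CTX, D_HEAD, device='cuda', dtype=torch.float16)

out = flex_attention(q, k, v, block_mask=block_mask)
```

In practice the benchmark would wrap `flex_attention` in `torch.compile` so that Triton kernels are generated and measured.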