-
Notifications
You must be signed in to change notification settings - Fork 78
Add more flex attention cases to benchmark. #3928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,18 +3,19 @@ | |
| import os | ||
| from torch.nn.attention.flex_attention import ( | ||
| create_block_mask, | ||
| create_mask, | ||
| flex_attention, | ||
| ) | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| import triton_kernels_benchmark as benchmark_suit | ||
| from triton_kernels_benchmark import xetla_kernel | ||
|
|
||
| torch._dynamo.config.recompile_limit = 100 # pylint: disable=protected-access | ||
|
|
||
| # Compile the flex_attention function | ||
| flex_attention = torch.compile(flex_attention, dynamic=False) | ||
| compiled_flex_attention = torch.compile(flex_attention, dynamic=False) | ||
|
|
||
|
|
||
| @lru_cache | ||
|
|
@@ -27,112 +28,127 @@ def causal_mask(_, __, q_idx, kv_idx): | |
| return q_idx >= kv_idx | ||
|
|
||
|
|
||
| throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1' | ||
| batch_sizes = [16, 32, 64] if throughput_test else [1] | ||
|
|
||
|
|
||
| # Kernel profiling for Backward mode is not working as expected: | ||
| # For details: https://github.com/pytorch/pytorch/issues/144778 | ||
| @benchmark_suit.perf_report( | ||
| benchmark_suit.Benchmark( | ||
| x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'], | ||
| x_vals=[[z, h, 16384 // z, dhead, causal, mode] | ||
| for z in [1, 2, 4, 8, 16, 32] | ||
| for (h, dhead) in [(16, 128), (32, 64)] | ||
| for causal in [True] | ||
| for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] # | ||
| + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] # | ||
| + [[z, h, 1024, dhead, True, mode] | ||
| for z in [1, 2, 4, 8, 16, 32, 64] | ||
| for (h, dhead) in [(8, 128), (32, 96), (4, 128)] | ||
| for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]], | ||
| x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'], | ||
| x_vals= | ||
| # Multi-head attention. H_q equals H_kv | ||
| # Prefill shapes of Phi3-mini-3.8B | ||
| [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] + | ||
| # Prefill shapes of Deepseek-v3 | ||
| [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] + | ||
| # Append shapes of Phi3-mini-3.8B | ||
| [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] + | ||
|
|
||
| # Multi-query attention. H_kv equals 1. | ||
| # Append shapes of Deepseek-v3 (Nope) | ||
| [ | ||
| # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072. | ||
| # [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes | ||
| ] + | ||
| # Append shapes of Deepseek-v3 (Rope) | ||
| [] + | ||
|
|
||
| # Grouped-query attention. H_q / H_kv > 1 | ||
| # Prefill shapes of Llama-3.1-8B | ||
| [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] + | ||
| # Prefill shapes of Qwen2-7B | ||
| [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] + | ||
| # Append shapes of Llama-3.1-8B | ||
| [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] + | ||
| # Append shapes of Qwen2-7B | ||
| [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] + | ||
|
|
||
| # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k | ||
| # Decode shapes of Llama-3.1-8B | ||
| [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] + | ||
| # Decode shapes of Phi3-mini-3.8B | ||
| [ | ||
| # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) | ||
whitneywhtsang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # ValueError: Shape element 2 must be a power of 2 | ||
| # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes | ||
| ] + | ||
| # Decode shapes of Qwen2-7B | ||
| [ | ||
| # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2. | ||
| # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes | ||
| ] + | ||
| # Decode shapes of Deepseek-v3 (Nope) | ||
| [ | ||
| # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072. | ||
| # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes | ||
| ] + | ||
| # Decode shapes of Deepseek-v3 (Rope) | ||
| [], | ||
whitneywhtsang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| line_arg='provider', | ||
| line_vals=['triton', 'xetla'], | ||
| line_names=['Triton', 'XeTLA'], | ||
| line_vals=['triton'], | ||
| line_names=['Triton'], | ||
| styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], | ||
| ylabel=['GB/s', 'TFlops'], | ||
| plot_name='flexAttnCausal-performance', | ||
| args={}, | ||
| )) | ||
| def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider): | ||
| assert MODE in ['fwd', 'bwd'] | ||
| assert CAUSAL | ||
| def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider): | ||
| assert MODE in ['fwd'] | ||
whitneywhtsang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| dtype = torch.float16 | ||
| q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) | ||
| k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) | ||
| v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True) | ||
| q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd') | ||
| k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd') | ||
| v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd') | ||
chengjunlu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| sm_scale = 0.125 | ||
| if MODE == 'bwd': | ||
| sm_scale = 1.3 | ||
|
|
||
| quantiles = [0.5, 0.0, 1.0] | ||
| if provider == 'triton': | ||
| kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True} | ||
| block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device) | ||
| triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options | ||
| ) | ||
| kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True} | ||
| block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu') | ||
| triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=( | ||
| not H_q == H_kv), kernel_options=kernel_options) | ||
| torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv) | ||
| if MODE == 'bwd': | ||
| triton_o = triton_fn() | ||
| triton_do = torch.randn_like(triton_o) | ||
| triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True) | ||
| torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to( | ||
| torch.float32) | ||
| if MODE == 'bwd': | ||
| torch_o = torch_fn() | ||
| torch_do = torch.randn_like(torch_o) | ||
| torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True) | ||
| if MODE == 'fwd': | ||
| atol = 1e-1 if N_CTX == 16384 else 1e-2 | ||
| benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch') | ||
| else: | ||
| benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch') | ||
|
|
||
| benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch') | ||
| _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) | ||
whitneywhtsang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| elif provider == 'xetla': | ||
| xetla_fn = None | ||
| if MODE == 'fwd': | ||
| module_name = 'flash_attn_causal_True'.lower() | ||
| func = getattr(xetla_kernel, module_name) | ||
| out = torch.empty_like(q, device='xpu', dtype=dtype) | ||
| size_score = Z * H * N_CTX * N_CTX | ||
| size_attn_mask = Z * N_CTX * N_CTX | ||
| dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8) | ||
| bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype) | ||
| size_ml = Z * H * N_CTX | ||
| m = torch.empty((size_ml, ), device='xpu', dtype=torch.float) | ||
| l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) | ||
| xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) | ||
| if MODE == 'bwd': | ||
| module_name = 'flash_attn_bwd_causal_True'.lower() | ||
| func = getattr(xetla_kernel, module_name) | ||
| grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) | ||
| bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) | ||
| dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8) | ||
| out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) | ||
| log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) | ||
| workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) | ||
| grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True) | ||
| alpha = sm_scale | ||
| dropout_prob = 0 | ||
| grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True) | ||
| grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True) | ||
| grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True) | ||
| grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True) | ||
| bias_strideB = -1 | ||
| bias_strideN = -1 | ||
| bias_strideF = -1 | ||
| attn_mask_padding = 0 | ||
|
|
||
| xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, | ||
| dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, | ||
| N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding) | ||
| _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) | ||
| elif provider == 'onednn': | ||
| # OneDNN only supports MHA. | ||
whitneywhtsang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if H_q == H_kv: | ||
| mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device) | ||
| xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask) | ||
| if MODE == 'bwd': | ||
| xformers_o = xformers_fn() | ||
| xformers_do = torch.randn_like(xformers_o) | ||
| xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True) | ||
| _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10, | ||
| quantiles=quantiles) | ||
| else: | ||
| _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan') | ||
|
|
||
| else: | ||
| raise NotImplementedError(f'Unsupported provider {provider}') | ||
|
|
||
| tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3) | ||
| gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3) | ||
| qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2 # mul + add | ||
| pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2 # mul + add | ||
| tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3) | ||
|
|
||
| q_elems = H_q * N_CTX_q * D_HEAD_qk | ||
| k_elems = H_kv * N_CTX_kv * D_HEAD_qk | ||
| v_elems = H_kv * N_CTX_kv * D_HEAD_v | ||
| gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3) # float16 2 bytes | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only the GEEM computation and inputs are considered for calculating the tflops and gbps. |
||
|
|
||
| if MODE == 'bwd': | ||
| tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3) | ||
| gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3) | ||
| tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3) | ||
| gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3 | ||
| ) | ||
|
|
||
| return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv | ||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.