diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index b7f967bdb8..f940da2724 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -281,8 +281,7 @@ jobs:
           python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS

           source ../../scripts/capture-hw-details.sh
-          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-xetla-report.csv --benchmark flexAttnCausal --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

       - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
index 23b4052474..ce1a4b953a 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
@@ -3,18 +3,19 @@
 import os

 from torch.nn.attention.flex_attention import (
     create_block_mask,
+    create_mask,
     flex_attention,
 )
 import torch
 import torch.nn.functional as F
+
 import triton_kernels_benchmark as benchmark_suit
-from triton_kernels_benchmark import xetla_kernel

 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access

 # Compile the flex_attention function
-flex_attention = torch.compile(flex_attention, dynamic=False)
+compiled_flex_attention = torch.compile(flex_attention, dynamic=False)


 @lru_cache
@@ -27,112 +28,127 @@ def causal_mask(_, __, q_idx, kv_idx):
     return q_idx >= kv_idx


+throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
+batch_sizes = [16, 32, 64] if throughput_test else [1]
+
+
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[z, h, 1024, dhead, True, mode]
-           for z in [1, 2, 4, 8, 16, 32, 64]
-           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        # Prefill shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-v3
+        [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+
+        # Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072.
+            # [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Append shapes of Deepseek-v3 (Rope)
+        [] +
+
+        # Grouped-query attention. H_q / H_kv > 1
+        # Prefill shapes of Llama-3.1-8B
+        [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Qwen2-7B
+        [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Llama-3.1-8B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Qwen2-7B
+        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        # Decode shapes of Llama-3.1-8B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-3.8B
+        [
+            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+            # ValueError: Shape element 2 must be a power of 2
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Qwen2-7B
+        [
+            # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
+            # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072.
+            # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Rope)
+        [],
         line_arg='provider',
-        line_vals=['triton', 'xetla'],
-        line_names=['Triton', 'XeTLA'],
+        line_vals=['triton'],
+        line_names=['Triton'],
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],
         plot_name='flexAttnCausal-performance',
         args={},
     ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
-    assert MODE in ['fwd', 'bwd']
-    assert CAUSAL
+def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    assert MODE in ['fwd']
     dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
     if MODE == 'bwd':
         sm_scale = 1.3
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
-        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                           )
+        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+        triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
+            not H_q == H_kv), kernel_options=kernel_options)
+        torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-            torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-        if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

-    elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = 'flash_attn_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        if MODE == 'bwd':
-            module_name = 'flash_attn_bwd_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+    elif provider == 'onednn':
+        # OneDNN only supports MHA.
+        if H_q == H_kv:
+            mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
+            xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+            if MODE == 'bwd':
+                xformers_o = xformers_fn()
+                xformers_do = torch.randn_like(xformers_o)
+                xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+        else:
+            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # mul + add
+    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2  # mul + add
+    tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+
+    q_elems = H_q * N_CTX_q * D_HEAD_qk
+    k_elems = H_kv * N_CTX_kv * D_HEAD_qk
+    v_elems = H_kv * N_CTX_kv * D_HEAD_v
+    gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes

     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
+                                                                                                             )

     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
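
Reviewer sketch (not part of the patch): the forward-pass report columns reduce to simple arithmetic over the benchmark shapes via the qk_flops/pv_flops and q_elems/k_elems/v_elems expressions above. A minimal standalone check, assuming float16 operands, a hypothetical mean kernel time of 0.5 ms, and the Llama-3.1-8B prefill shape from x_vals (Z=1, H_q=32, H_kv=8, N_CTX_q=N_CTX_kv=1024, D_HEAD_qk=D_HEAD_v=128):

    # Mirrors the fwd-path tflops/gbps lambdas in flex_attention_benchmark_causal_mask.py.
    Z, H_q, H_kv = 1, 32, 8
    N_CTX_q, N_CTX_kv = 1024, 1024
    D_HEAD_qk, D_HEAD_v = 128, 128
    mean_ms = 0.5  # hypothetical mean latency in milliseconds, for illustration only

    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # QK^T matmul: mul + add per MAC
    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2   # PV matmul: mul + add per MAC
    tflops = Z * (qk_flops + pv_flops) * 1e-12 / (mean_ms * 1e-3)

    q_elems = H_q * N_CTX_q * D_HEAD_qk
    k_elems = H_kv * N_CTX_kv * D_HEAD_qk
    v_elems = H_kv * N_CTX_kv * D_HEAD_v
    gbps = Z * (q_elems + k_elems + v_elems) * 2 * 1e-9 / (mean_ms * 1e-3)  # 2 bytes per float16 element

    print(f'{tflops:.2f} TFLOPS, {gbps:.2f} GB/s')

As in the benchmark's fwd path, only the two attention matmuls are counted for TFLOPS and only Q/K/V traffic for GB/s; softmax work and the output write are not included.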