3 changes: 1 addition & 2 deletions .github/workflows/triton-benchmarks.yml
@@ -281,8 +281,7 @@ jobs:
python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS

source ../../scripts/capture-hw-details.sh
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-xetla-report.csv --benchmark flexAttnCausal --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

- name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
flex_attention_benchmark_causal_mask.py
@@ -3,18 +3,19 @@
import os
from torch.nn.attention.flex_attention import (
create_block_mask,
create_mask,
flex_attention,
)

import torch
import torch.nn.functional as F

import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark import xetla_kernel

torch._dynamo.config.recompile_limit = 100 # pylint: disable=protected-access

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)


@lru_cache
@@ -27,112 +28,127 @@ def causal_mask(_, __, q_idx, kv_idx):
return q_idx >= kv_idx


throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
batch_sizes = [16, 32, 64] if throughput_test else [1]


# Kernel profiling for Backward mode is not working as expected:
# For details: https://github.com/pytorch/pytorch/issues/144778
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
x_vals=[[z, h, 16384 // z, dhead, causal, mode]
for z in [1, 2, 4, 8, 16, 32]
for (h, dhead) in [(16, 128), (32, 64)]
for causal in [True]
for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+ [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+ [[z, h, 1024, dhead, True, mode]
for z in [1, 2, 4, 8, 16, 32, 64]
for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
x_vals=
# Multi-head attention. H_q equals H_kv
# Prefill shapes of Phi3-mini-3.8B
[[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
# Prefill shapes of Deepseek-v3
[[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
# Append shapes of Phi3-mini-3.8B
[[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +

# Multi-query attention. H_kv equals 1.
# Append shapes of Deepseek-v3 (Nope)
[
# RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072.
# [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes
] +
# Append shapes of Deepseek-v3 (Rope)
[] +

# Grouped-query attention. H_q / H_kv > 1
# Prefill shapes of Llama-3.1-8B
[[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
# Prefill shapes of Qwen2-7B
[[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
# Append shapes of Llama-3.1-8B
[[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
# Append shapes of Qwen2-7B
[[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +

# FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
# Decode shapes of Llama-3.1-8B
[[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
# Decode shapes of Phi3-mini-3.8B
[
# acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
# ValueError: Shape element 2 must be a power of 2
# [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
] +
# Decode shapes of Qwen2-7B
[
# torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
# [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
] +
# Decode shapes of Deepseek-v3 (Nope)
[
# RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072.
# [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
] +
# Decode shapes of Deepseek-v3 (Rope)
[],
line_arg='provider',
line_vals=['triton', 'xetla'],
line_names=['Triton', 'XeTLA'],
line_vals=['triton'],
line_names=['Triton'],
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel=['GB/s', 'TFlops'],
plot_name='flexAttnCausal-performance',
args={},
))
def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
assert MODE in ['fwd', 'bwd']
assert CAUSAL
def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
assert MODE in ['fwd']
dtype = torch.float16
q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
sm_scale = 0.125
if MODE == 'bwd':
sm_scale = 1.3

quantiles = [0.5, 0.0, 1.0]
if provider == 'triton':
kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
)
kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
not H_q == H_kv), kernel_options=kernel_options)
torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
if MODE == 'bwd':
triton_o = triton_fn()
triton_do = torch.randn_like(triton_o)
triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
torch.float32)
if MODE == 'bwd':
torch_o = torch_fn()
torch_do = torch.randn_like(torch_o)
torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
if MODE == 'fwd':
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
else:
benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')

benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

elif provider == 'xetla':
xetla_fn = None
if MODE == 'fwd':
module_name = 'flash_attn_causal_True'.lower()
func = getattr(xetla_kernel, module_name)
out = torch.empty_like(q, device='xpu', dtype=dtype)
size_score = Z * H * N_CTX * N_CTX
size_attn_mask = Z * N_CTX * N_CTX
dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
size_ml = Z * H * N_CTX
m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
if MODE == 'bwd':
module_name = 'flash_attn_bwd_causal_True'.lower()
func = getattr(xetla_kernel, module_name)
grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
alpha = sm_scale
dropout_prob = 0
grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
bias_strideB = -1
bias_strideN = -1
bias_strideF = -1
attn_mask_padding = 0

xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
elif provider == 'onednn':
# OneDNN only supports MHA.
if H_q == H_kv:
mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
if MODE == 'bwd':
xformers_o = xformers_fn()
xformers_do = torch.randn_like(xformers_o)
xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles)
else:
_, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')

else:
raise NotImplementedError(f'Unsupported provider {provider}')

tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2 # mul + add
pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2 # mul + add
tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)

q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v
gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3) # float16 2 bytes
Contributor Author: Only the GEMM computation and inputs are considered for calculating the tflops and gbps.
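
For reference, below is a minimal standalone sketch (not part of the diff) that mirrors the forward-mode counters above for one of the Phi3-mini-3.8B prefill shapes listed in x_vals; the 0.1 ms mean timing is a made-up placeholder used purely for illustration.

# Standalone sketch of the forward-mode tflops/gbps counters (GEMM FLOPs and fp16 input bytes only).
def flex_attn_fwd_counters(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, mean_ms):
    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # Q @ K^T, mul + add per output element
    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2   # P @ V, mul + add per output element
    tflops = Z * (qk_flops + pv_flops) * 1e-12 / (mean_ms * 1e-3)
    q_elems = H_q * N_CTX_q * D_HEAD_qk
    k_elems = H_kv * N_CTX_kv * D_HEAD_qk
    v_elems = H_kv * N_CTX_kv * D_HEAD_v
    gbps = Z * (q_elems + k_elems + v_elems) * 2 * 1e-9 / (mean_ms * 1e-3)  # 2 bytes per fp16 element
    return tflops, gbps

# Phi3-mini-3.8B prefill shape from x_vals; the 0.1 ms mean is a placeholder, not a measurement.
print(flex_attn_fwd_counters(1, 32, 32, 1024, 1024, 96, 96, mean_ms=0.1))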


if MODE == 'bwd':
tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
)

return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
