diff --git a/.github/workflows/pr-perfbench-bot.yml b/.github/workflows/pr-perfbench-bot.yml deleted file mode 100644 index e6954bcc4..000000000 --- a/.github/workflows/pr-perfbench-bot.yml +++ /dev/null @@ -1,88 +0,0 @@ -name: Performance Benchmark Bot - -on: - issue_comment: - types: - - created - -permissions: - contents: read - -concurrency: - group: "${{ github.workflow }}-${{ github.ref }}" - cancel-in-progress: true # always cancel in-progress - -env: - PYTHONDEVMODE: "1" - PYTHONUNBUFFERED: "1" - PYTHONPATH: "" # explicit cleanup - PIP_USER: "" # explicit cleanup - COLUMNS: "100" - FORCE_COLOR: "1" - CLICOLOR_FORCE: "1" - XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated - PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated - -jobs: - perfbench: - name: Benchmark between PR and main - if: | - github.repository_owner == 'tile-ai' && - github.event.issue.pull_request && - (contains(github.event.comment.body, '/performance-report') || contains(github.event.comment.body, '/perf')) - runs-on: [self-hosted, nvidia] - steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - ref: refs/pull/${{ github.event.issue.number }}/merge - fetch-depth: 0 - submodules: recursive - - - name: Setup Python - uses: actions/setup-python@v6 - with: - python-version: "3.12" - update-environment: true - cache: pip - cache-dependency-path: | - pyproject.toml - requirements*.txt - - - name: Install merged version - run: | - python -m venv tll - source tll/bin/activate - pip install -r requirements-test.txt - pip install . - - - name: Install original version - run: | - echo "Check files to be deleted!" - git clean -dxf -e tll/ - echo "Delete files completed!" - git checkout main - python -m venv tl - source tl/bin/activate - pip install -r requirements-test.txt - pip install . 
- - - name: Run performance test - id: perfbench - run: | - source tl/bin/activate - python maint/scripts/ci_performance.py - - - name: Post test results as PR comment - uses: actions/github-script@v8 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: '📊 ​**Performance Test Results** (triggered by @' + context.payload.comment.user.login + '):\n\n' + - 'Run listed here: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\n\n' + - "${{ steps.perfbench.outputs.stdout }}" - }) diff --git a/.github/workflows/pr-regression-test-bot.yml b/.github/workflows/pr-regression-test-bot.yml new file mode 100644 index 000000000..cc17d9750 --- /dev/null +++ b/.github/workflows/pr-regression-test-bot.yml @@ -0,0 +1,132 @@ +name: Performance Regression Bot + +on: + issue_comment: + types: + - created + +permissions: + contents: read + issues: write + +concurrency: + group: "${{ github.workflow }}-${{ github.ref }}" + cancel-in-progress: true # always cancel in-progress + +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + +jobs: + pr-regression: + name: Performance regression test between PR and main + if: | + github.repository_owner == 'tile-ai' && + github.event.issue.pull_request && + (contains(github.event.comment.body, '/performance-report') || contains(github.event.comment.body, '/perf')) + runs-on: [self-hosted, nvidia] + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ref: refs/pull/${{ github.event.issue.number }}/merge + fetch-depth: 0 + submodules: recursive + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + update-environment: true + cache: pip + cache-dependency-path: | + pyproject.toml + requirements*.txt + + - name: Clean pip environment + run: | + echo "PIP_CONFIG_FILE=/dev/null" >> $GITHUB_ENV + echo "PIP_NO_USER=1" >> $GITHUB_ENV + echo "PYTHONUSERBASE=" >> $GITHUB_ENV + echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> $GITHUB_ENV + echo "PIP_CACHE_DIR=$(mktemp -d)" >> $GITHUB_ENV + + - name: Install PR version (new) + run: | + python -m venv new + source new/bin/activate + pip install --no-user -r requirements-test.txt + pip install --no-user . + + - name: Clean pip environment + run: | + echo "PIP_CONFIG_FILE=/dev/null" >> $GITHUB_ENV + echo "PIP_NO_USER=1" >> $GITHUB_ENV + echo "PYTHONUSERBASE=" >> $GITHUB_ENV + echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> $GITHUB_ENV + echo "PIP_CACHE_DIR=$(mktemp -d)" >> $GITHUB_ENV + + - name: Install main version (old) + run: | + echo "Check files to be deleted!" + git clean -dxf -e new/ + echo "Delete files completed!" + git checkout main + python -m venv old + source old/bin/activate + pip install --no-user -r requirements-test.txt + pip install --no-user . 
+ + - name: Run performance regression test + run: | + source new/bin/activate + OLD_PYTHON=./old/bin/python NEW_PYTHON=./new/bin/python \ + PERF_REGRESSION_MD=regression_result.md PERF_REGRESSION_PNG=regression_result.png \ + python ./maint/scripts/test_perf_regression.py + + - name: Read markdown table + id: read_md + run: | + echo "content<<EOF" >> $GITHUB_OUTPUT + cat regression_result.md >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + - name: Upload benchmark image as artifact + uses: actions/upload-artifact@v4 + with: + name: perfbench-${{ github.run_id }} + path: regression_result.png + + - name: Post test results as PR comment + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const md = `${{ steps.read_md.outputs.content }}`; + const runUrl = + `${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}`; + const body = + 'Performance Benchmark Report\n' + + '============================\n\n' + + `Triggered by: @${context.payload.comment.user.login}\n` + + `Workflow run: ${runUrl}\n\n` + + 'Results\n' + + '-------\n\n' + + md + '\n\n' + + 'Artifacts\n' + + '---------\n\n' + + '- regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.\n'; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body + }); \ No newline at end of file diff --git a/bench.md b/bench.md new file mode 100644 index 000000000..84824f4ca --- /dev/null +++ b/bench.md @@ -0,0 +1,72 @@ +| File | Original Latency (ms) | Current Latency (ms) | Speedup | +|----------------------------------------------------------|--------------------|-------------------|-----------| +| example_dequant_gemv_fp16xint4 | 0.0052611 | 0.00556904 | 0.944705 | +| example_mha_fwd_bhsd | 0.009312 | 0.009664 | 0.963576 | +| example_mha_fwd_bshd_wgmma_pipelined | 0.014688 | 0.015104 | 0.972458 | +| example_per_token_cast_to_fp8 | 0.00885861 | 0.0091083 | 0.972587 | +| example_dequant_gemm_bf16_mxfp4_hopper | 0.011872 | 0.012096 | 0.981481 | +| example_gqa_sink_fwd_bhsd_wgmma_pipelined | 0.013135 | 0.0133685 | 0.982534 | +| example_tilelang_nsa_decode | 0.00833827 | 0.00848195 | 0.98306 | +| example_mha_bwd_bshd | 0.0323905 | 0.032925 | 0.983766 | +| example_dequant_gemm_fp4_hopper | 0.010944 | 0.011104 | 0.985591 | +| example_mha_sink_fwd_bhsd | 0.0112312 | 0.0113914 | 0.985937 | +| example_mha_sink_bwd_bhsd_sliding_window | 0.0250769 | 0.0253728 | 0.988338 | +| example_mha_fwd_bhsd_wgmma_pipelined | 0.015296 | 0.015456 | 0.989648 | +| example_warp_specialize_gemm_copy_0_gemm_1 | 0.03088 | 0.031168 | 0.99076 | +| example_gemm_intrinsics | 0.027712 | 0.027904 | 0.993119 | +| sparse_mla_fwd_pipelined | 0.0984533 | 0.0987996 | 0.996495 | +| example_vertical_slash_sparse_attn | 0.00148107 | 0.00148577 | 0.996837 | +| example_tilelang_gemm_fp8_intrinsic | 0.005984 | 0.006 | 0.997333 | +| example_mha_sink_fwd_bhsd_wgmma_pipelined | 0.0149988 | 0.0150285 | 0.998024 | +| example_gemm | 0.017281 | 0.017312 | 0.998209 | +| example_fusedmoe_tilelang | 0.217889 | 0.218217 | 0.998497 | +| example_tilelang_gemm_fp8 | 0.01376 | 0.013776 | 0.998839 | +| example_mha_sink_fwd_bhsd_sliding_window | 0.0113744 | 0.0113844 | 0.999122 | +| example_mha_sink_bwd_bhsd | 0.0410323 | 0.0410627 | 0.99926 | +| example_gqa_bwd | 0.0362193 | 0.0362422 | 0.999368 | +| example_group_per_split_token_cast_to_fp8 | 0.0100313 | 0.0100362 | 0.999512 | +| example_gemv | 
0.0481227 | 0.048128 | 0.99989 | +| example_linear_attn_bwd | 0.114526 | 0.114538 | 0.999895 | +| sparse_mla_bwd | 0.247906 | 0.247916 | 0.99996 | +| example_gemm_autotune | 0.020544 | 0.020544 | 1 | +| example_dequant_gemm_w4a8 | 0.00624 | 0.00624 | 1 | +| block_sparse_attn_tilelang | 0.0112801 | 0.0112754 | 1.00042 | +| topk_selector | 0.0449657 | 0.044939 | 1.00059 | +| example_tilelang_gemm_splitk_vectorize_atomicadd | 0.0413707 | 0.0413452 | 1.00062 | +| tilelang_example_sparse_tensorcore | 0.0133828 | 0.0133715 | 1.00085 | +| example_convolution_autotune | 0.69616 | 0.695552 | 1.00087 | +| example_linear_attn_fwd | 0.0277008 | 0.0276757 | 1.00091 | +| example_mha_inference | 0.0652448 | 0.0651648 | 1.00123 | +| example_tilelang_gemm_splitk | 0.0402086 | 0.0401445 | 1.0016 | +| example_dequant_groupedgemm_bf16_mxfp4_hopper | 0.0145726 | 0.0145491 | 1.00162 | +| example_mha_fwd_varlen | 1.43004 | 1.42759 | 1.00172 | +| example_tilelang_nsa_fwd | 0.00870593 | 0.00868398 | 1.00253 | +| example_dequant_gemm_bf16_mxfp4_hopper_tma | 0.01168 | 0.011648 | 1.00275 | +| example_topk | 0.010048 | 0.010016 | 1.00319 | +| example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window | 0.0147675 | 0.0147171 | 1.00342 | +| example_gqa_bwd_wgmma_pipelined | 0.0461091 | 0.0459494 | 1.00348 | +| example_mha_bwd_bhsd | 0.0310782 | 0.03097 | 1.00349 | +| example_gqa_sink_bwd_bhsd | 0.0315594 | 0.0314257 | 1.00425 | +| example_blocksparse_gemm | 0.020372 | 0.0202743 | 1.00482 | +| example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window | 0.0132417 | 0.0131672 | 1.00566 | +| example_mla_decode | 0.327616 | 0.32576 | 1.0057 | +| example_gqa_bwd_tma_reduce_varlen | 0.0477759 | 0.0474919 | 1.00598 | +| example_mha_fwd_bshd | 0.023712 | 0.023552 | 1.00679 | +| fp8_lighting_indexer | 0.0293298 | 0.0291277 | 1.00694 | +| example_gqa_decode | 0.043808 | 0.043488 | 1.00736 | +| example_tilelang_gemm_fp8_2xAcc | 0.135392 | 0.134336 | 1.00786 | +| example_gqa_fwd_bshd | 0.056896 | 0.056448 | 1.00794 | +| example_warp_specialize_gemm_barrierpipe_stage2 | 0.031456 | 0.031136 | 1.01028 | +| example_warp_specialize_gemm_copy_1_gemm_0 | 0.030368 | 0.030048 | 1.01065 | +| example_gemm_schedule | 0.0276354 | 0.0273229 | 1.01144 | +| example_convolution | 0.94832 | 0.937536 | 1.0115 | +| example_tilelang_sparse_gqa_decode_varlen_mask | 0.0211596 | 0.0208851 | 1.01314 | +| example_elementwise_add | 0.0193343 | 0.0190617 | 1.0143 | +| example_dynamic | 0.023712 | 0.02336 | 1.01507 | +| example_mha_bwd_bshd_wgmma_pipelined | 0.0221314 | 0.0218026 | 1.01508 | +| sparse_mla_fwd | 0.358282 | 0.351806 | 1.01841 | +| example_gqa_sink_bwd_bhsd_sliding_window | 0.0203378 | 0.0199598 | 1.01894 | +| example_dequant_gemm_bf16_fp4_hopper | 0.010816 | 0.010592 | 1.02115 | +| example_gqa_fwd_bshd_wgmma_pipelined | 0.046112 | 0.04512 | 1.02199 | +| example_tilelang_block_sparse_attn | 0.00847182 | 0.0081466 | 1.03992 | +| example_tilelang_sparse_gqa_decode_varlen_indice | 0.0162253 | 0.013691 | 1.18511 | diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index 5af787a12..00520c9b7 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -498,6 +498,49 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = 
"float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + V = torch.randn_like(K) + sinks = torch.randn(H, dtype=torch_dtype, device="cuda") + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + q_shape = (BATCH, H, N_CTX, D_HEAD) + head_kv = H // groups + kv_shape = (BATCH, head_kv, N_CTX, D_HEAD) + dq = torch.zeros(q_shape, dtype=torch.float32, device="cuda") + dk = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + dv = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, warmup=500, rep=10000, backend="cupti") + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="Batch size") diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index feb5844f7..d559bc8be 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -316,6 +316,41 @@ def main( print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000, backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="batch size") diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index 155c488e6..8fa9a8b1d 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -492,6 +492,46 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + BATCH: int = 1, + H: int = 1, + N_CTX: int = 512, + D_HEAD: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": 
torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn_like(Q) + V = torch.randn_like(Q) + sinks = torch.randn(H, dtype=torch_dtype, device=Q.device) + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size=window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + shape = (BATCH, H, N_CTX, D_HEAD) + dq = torch.zeros(shape, dtype=torch.float32, device=Q.device) + dk = torch.empty(shape, dtype=torch_dtype, device=Q.device) + dv = torch.empty(shape, dtype=torch_dtype, device=Q.device) + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, warmup=500, rep=10000, backend="cupti") + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="Batch size") diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 78ac443b2..156b50add 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -300,6 +300,28 @@ def main( print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, heads, seq_q, seq_kv, dim, window_size, block_M=block_M, block_N=block_N, num_stages=num_stages, threads=threads, dtype=dtype + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000, backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index decdc8f4f..c27ad69a2 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -307,6 +307,29 @@ def main( print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, heads, seq_q, seq_kv, dim, window_size, block_M=block_M, block_N=block_N, num_stages=num_stages, 
threads=threads, dtype=dtype + ) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000, backend="cupti") + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 179483634..b94e602f6 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -1,7 +1,6 @@ # ruff: noqa: E712 import math import torch - import triton import triton.language as tl import torch.nn.functional as F diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index afb4cc888..59196bd89 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -1,8 +1,8 @@ import math import torch - import tilelang import tilelang.language as T +from tilelang.profiler import do_bench import torch.nn.functional as F @@ -217,5 +217,26 @@ def main(): test_topk_sparse_attention() +def run_regression_perf(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + + def run_kernel_only(): + kernel(q, k, v, block_mask) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 99418d5fd..75488d7f0 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -8,6 +8,7 @@ import argparse import time import math +from tilelang.profiler import do_bench from heuristic import num_splits_heuristic @@ -535,6 +536,129 @@ def main(args): print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x") +def run_regression_perf(args): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) + sparse_ratio = args.sparse_ratio + block_N = args.block_N + page_block_size = args.page_block_size + num_blocks = args.num_pages + max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") 
+ V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + available_blocks = list(range(total_blocks_needed)) + import random + + random.seed(42) + random.shuffle(available_blocks) + block_assignment = {} + block_idx_counter = 0 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = available_blocks[block_idx_counter] + block_table[seq_idx, block_idx] = physical_block_idx + block_assignment[(seq_idx, block_idx)] = physical_block_idx + block_idx_counter += 1 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = block_assignment[(seq_idx, block_idx)] + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_tile = int(math.ceil(seq_len / block_N)) + if sparse_ratio == 0.0: + selected_blocks = min(num_tile, max_selected_blocks) + for head_idx in range(heads_kv): + for i in range(selected_blocks): + block_indices[seq_idx, head_idx, i] = num_tile - 1 - i + for i in range(selected_blocks, max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + else: + num_selected = int(num_tile * (1.0 - sparse_ratio)) + num_selected = max(1, min(num_selected, max_selected_blocks)) + all_blocks = list(range(num_tile)) + for head_idx in range(heads_kv): + selected_blocks = [] + recent_blocks = 1 + selected_blocks.append(num_tile - 1) + if num_selected > recent_blocks: + remaining_blocks = [b for b in all_blocks if b not in selected_blocks] + if remaining_blocks: + import random + + random.seed(42) + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) + selected_blocks.extend(additional_blocks) + + selected_blocks.sort(reverse=True) + + for i in range(len(selected_blocks)): + block_indices[seq_idx, head_idx, i] = selected_blocks[i] + for i in range(len(selected_blocks), max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + kernel = sparse_attn.kernel + batch = sparse_attn.batch + heads = sparse_attn.heads + heads_kv = sparse_attn.heads_kv + dim_v = sparse_attn.dim_v + dim = sparse_attn.dim + block_size = sparse_attn.block_N + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_attn.block_H - 1) // 
sparse_attn.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = sparse_attn.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + def run_kernel_only(): + kernel( + Q, + K_cache, + V_cache, + block_indices, + cache_seqlens, + block_table, + glse, + output_partial, + ) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index 8b5cde38d..6b9862f80 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -7,6 +7,7 @@ import time import math from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench def flashattn(batch, heads, heads_kv, dim, dim_v): @@ -421,6 +422,58 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 print("sparse time: ", (time.time() - start) / 100 * 1000) +def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + if max_valid_block > 0: + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + block_indices, _ = block_indices.sort(dim=-1, descending=True) + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv 
* num_m_blocks + num_sm = sparse_kernel.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = sparse_kernel.kernel + + def run_kernel_only(): + kernel(Q, K, V, block_indices, cache_seqlens, glse, output_partial) + + return do_bench(run_kernel_only, warmup=100, rep=1000, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 0d759211a..8e66de4ac 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -5,10 +5,10 @@ import tilelang.language as T from einops import rearrange, einsum import argparse - import time import math from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench def flashattn(batch, heads, heads_kv, dim, dim_v): @@ -406,6 +406,63 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 print("sparse time: ", (time.time() - start) / 100 * 1000) +def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + random_index = torch.randint(0, batch, (1,), device="cuda").item() + cache_seqlens[random_index] = max_cache_seqlen + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + valid_num_block = valid_num_blocks[b].item() + if valid_num_block > 0: + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + + model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = model.batch + heads = model.heads + heads_kv = model.heads_kv + dim_v = model.dim_v + dim = model.dim + block_size = model.block_size + block_H = model.block_H + max_cache_seqlen = K.shape[1] + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + 
total_mblocks = batch * heads_kv * num_m_blocks + num_sm = model.num_sm + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = model.kernel + + def run_kernel_only(): + kernel(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index b61d52fa0..01695742b 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -5,10 +5,10 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench @triton.autotune( diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index c05b37779..232bcacaf 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -4,7 +4,6 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic diff --git a/examples/blocksparse_attention/regression_example_blocksparse_attention.py b/examples/blocksparse_attention/regression_example_blocksparse_attention.py new file mode 100644 index 000000000..477df0b12 --- /dev/null +++ b/examples/blocksparse_attention/regression_example_blocksparse_attention.py @@ -0,0 +1,24 @@ +import tilelang.testing +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask + + +def regression_example_tilelang_block_sparse_attn(): + tilelang.testing.process_func(example_tilelang_block_sparse_attn.run_regression_perf) + + +def regression_example_tilelang_sparse_gqa_decode_varlen_indice(): + tilelang.testing.process_func( + example_tilelang_sparse_gqa_decode_varlen_indice.run_regression_perf, batch=1, max_cache_seqlen=2048 + ) + + +def regression_example_tilelang_sparse_gqa_decode_varlen_mask(): + tilelang.testing.process_func( + example_tilelang_sparse_gqa_decode_varlen_mask.run_regression_perf, batch=1, max_cache_seqlen=2048 + ) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 0cbef5e0c..bf58ba47a 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -6,6 +6,7 @@ from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType import torch from typing import List +from tilelang.profiler import do_bench DEFAULT_BLOCK_M = 128 DEFAULT_BLOCK_N = 128 @@ 
-175,5 +176,32 @@ def main(): print(e) +def run_regression_perf(): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + kernel = blocksparse_matmul( + M, + N, + K, + block_M=DEFAULT_BLOCK_M, + block_N=DEFAULT_BLOCK_N, + block_K=DEFAULT_BLOCK_K, + num_stages=DEFAULT_NUM_STAGES, + thread_num=DEFAULT_THREAD_NUM, + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) + block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + def run_kernel_only(): + kernel(a, b, block_mask) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py b/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py new file mode 100644 index 000000000..81900a00c --- /dev/null +++ b/examples/blocksparse_gemm/regression_example_blocksparse_gemm.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_blocksparse_gemm + + +def regression_example_blocksparse_gemm(): + tilelang.testing.process_func(example_blocksparse_gemm.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index ec15b292e..48370dea7 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -205,5 +205,35 @@ def run_torch(): print("Torch: {:.2f} ms".format(latency)) +def run_regression_perf(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == "float": + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + elif dtype == "float16": + x = torch.randn(M, N, device="cuda", dtype=torch.float16) + elif dtype == "bfloat16": + x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) + M_max = int(ceil_div(batch_sizes.max(), 128) * 128) + + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) + + x_fp8, x_amax = kernel(x, batch_sizes) + x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + + from tilelang.profiler import do_bench + + def run_tilelang(): + kernel(x, batch_sizes) + + return do_bench(run_tilelang, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 45281ab14..c0f31ebc4 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -109,5 +109,16 @@ def run_triton(): print("Triton: {:.2f} ms".format(latency)) +def run_regression_perf(M=8192, N=8192, blk_m=8): + kernel = per_token_cast_to_fp8(M, N, blk_m) + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(x) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/cast/regression_example_cast.py 
b/examples/cast/regression_example_cast.py new file mode 100644 index 000000000..4bdfb99e7 --- /dev/null +++ b/examples/cast/regression_example_cast.py @@ -0,0 +1,17 @@ +import tilelang.testing +import example_group_per_split_token_cast_to_fp8 +import example_per_token_cast_to_fp8 + + +def regression_example_group_per_split_token_cast_to_fp8(): + tilelang.testing.process_func( + example_group_per_split_token_cast_to_fp8.run_regression_perf, M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896] + ) + + +def regression_example_per_token_cast_to_fp8(): + tilelang.testing.process_func(example_per_token_cast_to_fp8.run_regression_perf, M=2048, N=512, blk_m=8) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index a84e5878a..125069848 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -107,5 +107,30 @@ def main(argv=None): print("All checks passed.✅") +def run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 600b608a3..9a156a020 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -160,6 +160,26 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): + N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p + config = get_heuristic_config() + kernel = convolution(N, C, H, W, F, K, S, D, P, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--n", type=int, default=128, help="n") diff --git a/examples/convolution/regression_example_convolution.py b/examples/convolution/regression_example_convolution.py new file mode 100644 index 000000000..0c80862ac --- /dev/null +++ b/examples/convolution/regression_example_convolution.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_convolution +import example_convolution_autotune + + +def regression_example_convolution(): + 
tilelang.testing.process_func(example_convolution.run_regression_perf) + + +def regression_example_convolution_autotune(): + tilelang.testing.process_func(example_convolution_autotune.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 733ae3c46..5cbb105bf 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -288,6 +288,25 @@ def main( print(f"TFlops: {total_flops / latency * 1e-9} TFlops") +def run_regression_perf( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=132, help="batch size") diff --git a/examples/deepseek_mla/regression_example_mla_decode.py b/examples/deepseek_mla/regression_example_mla_decode.py new file mode 100644 index 000000000..64e1c436a --- /dev/null +++ b/examples/deepseek_mla/regression_example_mla_decode.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_mla_decode + + +def regression_example_mla_decode(): + tilelang.testing.process_func(example_mla_decode.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 38fc51a9f..48b1589d5 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -172,5 +172,38 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def run_regression_perf(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 + groups = HQ // H + SEQ_LEN_Q = 1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + ) + + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN_Q): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index a8dd26b63..99ffa427b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -171,5 +171,43 @@ def main(): 
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def run_regression_perf(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + ) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/regression_example_tilelang_nsa.py b/examples/deepseek_nsa/regression_example_tilelang_nsa.py new file mode 100644 index 000000000..cef71354a --- /dev/null +++ b/examples/deepseek_nsa/regression_example_tilelang_nsa.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_tilelang_nsa_fwd +import example_tilelang_nsa_decode + + +def regression_example_tilelang_nsa_fwd(): + tilelang.testing.process_func(example_tilelang_nsa_fwd.run_regression_perf) + + +def regression_example_tilelang_nsa_fwd_decode(): + tilelang.testing.process_func(example_tilelang_nsa_decode.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index 305e2afc4..6543405d6 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -280,5 +280,35 @@ def logits_fn(): print(f"cost_ref: {cost_ref}") +def run_regression_perf(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + torch.manual_seed(0) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) + + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = 
validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + + from tilelang.profiler import do_bench + + def logits_fn(): + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + logits_fn() + + print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) + + return do_bench(logits_fn, warmup=100, rep=100, backend="cupti") + + if __name__ == "__main__": test_fp8_lighting_indexer() diff --git a/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py new file mode 100644 index 000000000..97fc121a7 --- /dev/null +++ b/examples/deepseek_v32/regression_tilelang_example_deepseek_v32.py @@ -0,0 +1,31 @@ +import tilelang.testing +import fp8_lighting_indexer +import sparse_mla_bwd +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import topk_selector + + +def regression_topk_selector(): + tilelang.testing.process_func(topk_selector.run_regression_perf) + + +def regression_fp8_lighting_indexer(): + tilelang.testing.process_func(fp8_lighting_indexer.run_regression_perf, S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) + + +def regression_sparse_mla_fwd(): + tilelang.testing.process_func(sparse_mla_fwd.run_regression_perf, S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256) + +def regression_sparse_mla_fwd_pipelined(): + tilelang.testing.process_func( + sparse_mla_fwd_pipelined.run_regression_perf, S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256 + ) + + +def regression_sparse_mla_bwd(): + tilelang.testing.process_func(sparse_mla_bwd.run_regression_perf, S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index 1266e70ed..436fe55fc 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -337,5 +337,40 @@ def fn(): print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) +def run_regression_perf(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + D = 512 + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, None, True) + delta = preprocess_kernel(tl_out, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + return bwd_kernel(q, kv, do, indices, tl_lse, delta, dkv) + + return do_bench(run_kernel_only, rep=1000, warmup=250, backend="cupti") + + 
if __name__ == "__main__": test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index 3b963c751..f54c6cbd5 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -278,6 +278,40 @@ def fn(): print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) +def run_regression_perf( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, block_I=64, num_stages=2, threads=256 +): + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + is_casual = True + _, _, heads, dim_plus_tail_dim = q.shape + _, _, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + _, _, _, topk = indices.shape + kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, None, is_casual, block_I=block_I, num_stages=num_stages, threads=threads) + + def run_kernel_only(): + kernel(q, kv, indices) + + from tilelang.profiler import do_bench + + return do_bench( + run_kernel_only, + rep=100, + warmup=250, + ) + + if __name__ == "__main__": test_sparse_mla_fwd( B=1, diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 972160c99..b731ba912 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -427,6 +427,41 @@ def fn(): print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) +def run_regression_perf(B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + CP0 = q_start_s_index == 0 + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, KV_stride, kv_group, None, True, CP0) + + def run_kernel_only(): + kernel(q, kv, indices, torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")) + + from tilelang.profiler import do_bench + + return do_bench( + run_kernel_only, + rep=100, + warmup=10, + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--test_correctness", action="store_true") diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index cf87f526d..53e662e06 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -240,5 +240,35 @@ def 
test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms") +def run_regression_perf(batch=64, seq_len=32 * 1024, topk=2048): + batch = 64 + seq_len = 32 * 1024 + topk = 2048 + torch.manual_seed(1) + input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() + starts = torch.zeros(batch, dtype=torch.int32).cuda() + ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len + + indexes = tl_topk(input, starts, ends, topk) + + indexes_ref = torch.topk(input, topk, dim=-1)[1] + + for i in range(batch): + ref_np = indexes_ref[i].cpu().to(torch.int32).numpy() + trt_np = indexes[i].cpu().to(torch.int32).numpy() + + set_ref = set(ref_np) + set_trt = set(trt_np) + intersection = set_ref & set_trt + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + tl_topk(input, starts, ends, topk) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": test_topk_selector() diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index ba3e0b4a7..8c9135633 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -438,6 +438,27 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=256, n=256, k=256, fast_dequant=True): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + fast_dequant=fast_dequant, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": main(256, 256, 256, True) main(256, 256, 256, False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index 1091306c6..3878025a6 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -538,6 +538,29 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": M, N, K = 256, 256, 256 scale_size = 32 diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py index 12395df0a..90b573187 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -554,6 +554,29 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, 
with_bias=False, print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": M, N, K = 256, 256, 256 scale_size = 32 diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index 352637de5..e80b2c0b8 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -273,6 +273,14 @@ def main(m=256, n=256, k=256, tune=False): print(f"Best config: {best_config}") +def run_regression_perf(m=256, n=256, k=256): + kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=False)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + return profiler.do_bench(warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--m", type=int, default=256, help="M") diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 3ff726738..1ce4c2aca 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -185,6 +185,14 @@ def main(m=128, n=256, k=256, tune=False): print(f"Best tflops: {total_flops / best_latency * 1e-9}") +def run_regression_perf(m=128, n=256, k=256): + kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=False)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--m", type=int, default=512, help="Matrix dimension M") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index 3f1214670..afb557d0c 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -217,5 +217,62 @@ def main() -> None: torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) +def run_regression_perf(): + M = 1 + N = 1024 + K = 1024 + in_dtype = "float16" + out_dtype = "float16" + accum_dtype = "float16" + num_bits = 4 + storage_dtype = "int8" + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // 
num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A, qB, C) + + return do_bench(run_kernel_only, warmup=100, rep=1000, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index 098f814c2..dd547538e 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -508,6 +508,61 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi print("All checks pass. ✅") +def run_regression_perf(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): + block_M, block_N, block_K = 128, 256, 128 + num_stages = 1 + threads = 512 + split = 1 + num_bits = 4 + num_elems_per_byte = 8 // num_bits + qk = k // num_elems_per_byte + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) + + if tune: + with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + else: + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, + ) + + return tilelang.profiler.do_bench( + lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100, backend="cupti" + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm diff --git a/examples/dequantize_gemm/regression_example_dequantize_gemm.py b/examples/dequantize_gemm/regression_example_dequantize_gemm.py new file mode 100644 index 000000000..b59bd5d52 --- /dev/null +++ b/examples/dequantize_gemm/regression_example_dequantize_gemm.py @@ -0,0 +1,37 @@ +import tilelang.testing +import example_dequant_gemm_bf16_fp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper_tma +import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_w4a8 +import example_dequant_gemv_fp16xint4 +import example_dequant_groupedgemm_bf16_mxfp4_hopper + + +def regression_example_dequant_gemv_fp16xint4(): + tilelang.testing.process_func(example_dequant_gemv_fp16xint4.run_regression_perf) + + +def regression_example_dequant_gemm_fp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_fp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_bf16_fp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_bf16_fp4_hopper.run_regression_perf) + +def regression_example_dequant_gemm_bf16_mxfp4_hopper(): + tilelang.testing.process_func(example_dequant_gemm_bf16_mxfp4_hopper.run_regression_perf) + 
+def regression_example_dequant_gemm_bf16_mxfp4_hopper_tma(): + tilelang.testing.process_func(example_dequant_gemm_bf16_mxfp4_hopper_tma.run_regression_perf) + + +def regression_example_dequant_groupedgemm_bf16_mxfp4_hopper(): + tilelang.testing.process_func(example_dequant_groupedgemm_bf16_mxfp4_hopper.run_regression_perf) + + +def regression_example_dequant_gemm_w4a8(): + tilelang.testing.process_func(example_dequant_gemm_w4a8.run_regression_perf) + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index 97ce7d9b3..88a53e59d 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -105,5 +105,28 @@ def main(M=16384, N=16384, K=16384): matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) +def run_regression_perf(M, N, K): + block_M, block_N, block_K = 128, 128, 32 + trans_A, trans_B = False, False + in_dtype, out_dtype = "float16", "float16" + accum_dtype = "float32" + num_stages = 3 + threads = 128 + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) + import torch + + if trans_A: + A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + if trans_B: + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + else: + B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(input_tensors=[A, B, C], backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/dynamic_shape/regression_example_dynamic.py b/examples/dynamic_shape/regression_example_dynamic.py new file mode 100644 index 000000000..3e1603a22 --- /dev/null +++ b/examples/dynamic_shape/regression_example_dynamic.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_dynamic + + +def regression_example_dynamic(): + tilelang.testing.process_func(example_dynamic.run_regression_perf, M=1024, N=1024, K=1024) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 464312ced..8f5b34bf4 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -4,6 +4,7 @@ import tilelang import tilelang.language as T from tilelang.autotuner import AutoTuner +from tilelang.profiler import do_bench def ref_program(x, y): @@ -80,5 +81,22 @@ def main(): torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + + def run_kernel_only(): + kernel(a, b) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if 
__name__ == "__main__": main() diff --git a/examples/elementwise/regression_example_elementwise.py b/examples/elementwise/regression_example_elementwise.py new file mode 100644 index 000000000..de4a082b2 --- /dev/null +++ b/examples/elementwise/regression_example_elementwise.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_elementwise_add + + +def regression_example_elementwise_add(): + tilelang.testing.process_func(example_elementwise_add.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index d1f5843e3..7c950a57e 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -489,6 +489,50 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + head_kv = H // groups + Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="Batch size") diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index 3501df1d7..d9ab6666d 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -700,6 +700,58 @@ def run1(): ) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + total_q = BATCH * N_CTX + total_kv = BATCH * N_CTX + head_kv = H // groups + Q = torch.randn(total_q, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + cu_seqlens_q = torch.arange(0, (BATCH + 1) * N_CTX, N_CTX, device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = 
N_CTX + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, max_seqlen_q, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + total_q, + total_kv, + N_CTX, + H, + max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO, cu_seqlens_q) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, cu_seqlens_q, cu_seqlens_k, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": arch = nvcc.get_target_compute_version() print(f"Detected GPU compute capability: {arch}") diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index adb7e06a8..4b6ff4218 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -339,6 +339,40 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + head_kv = H // groups + Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + kernel = flashattn_bwd( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M=128, block_N=32, threads=256, num_stages=2, groups=groups + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.float32) + dV = torch.zeros(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.float32) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="Batch size") diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 408d6e507..9f79dd4fc 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -243,6 +243,19 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + kernel = 
flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="batch size") diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 3492be764..8b4d59eb5 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -230,6 +230,26 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="batch size") diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 81eb6d1e5..d9051cede 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -352,6 +352,37 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="Batch size") diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py index 427a0f694..b8fb134b6 100644 --- a/examples/flash_attention/example_mha_bwd_bshd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -343,6 +343,37 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + 
torch.manual_seed(42) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=100, rep=1000, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="Batch size") diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index 813f379ca..f02837b05 100644 --- a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -320,6 +320,38 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 128 + block_N = 128 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros_like(Q, dtype=torch.float16) + dV = torch.zeros_like(Q, dtype=torch.float16) + Delta = mod_prep(O, dO) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="Batch size") diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index 7fa5549d0..77dd8cd9c 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -207,6 +207,25 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = 
argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="batch size") diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index 440a2cd74..27dcbaade 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -211,6 +211,24 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index 888914c9b..7c7fe8667 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -193,6 +193,17 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 8, heads: int = 32, seq_len: int = 4096, dim: int = 128, is_causal: bool = False): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index b54d3e626..078cda7c1 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -199,6 +199,12 @@ def main( print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 8, heads: int = 32, seq_len: int = 4096, dim: int = 128, is_causal: bool = False): + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index f7bb36f71..6920fab3f 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -277,6 +277,49 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): print("All checks passed.✅") +def run_regression_perf(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 
* flops_per_matmul + tilelang.testing.set_random_seed(0) + causal = False + if causal: + total_flops *= 0.5 + dtype = torch.float16 + device = torch.device("cuda") + window_size = (-1, -1) + q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + UQ = q_unpad.shape[0] + UK = k_unpad.shape[0] + UKV = k_unpad.shape[0] + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=8, help="batch size") diff --git a/examples/flash_attention/regression_example_flash_attention.py b/examples/flash_attention/regression_example_flash_attention.py new file mode 100644 index 000000000..caf7c8670 --- /dev/null +++ b/examples/flash_attention/regression_example_flash_attention.py @@ -0,0 +1,72 @@ +import tilelang.testing +import example_gqa_fwd_bshd +import example_gqa_fwd_bshd_wgmma_pipelined +import example_mha_fwd_bhsd +import example_mha_fwd_bhsd_wgmma_pipelined +import example_mha_fwd_bshd +import example_mha_fwd_bshd_wgmma_pipelined +import example_mha_fwd_varlen +import example_gqa_bwd_tma_reduce_varlen +import example_gqa_bwd +import example_gqa_bwd_wgmma_pipelined +import example_mha_bwd_bshd +import example_mha_bwd_bhsd +import example_mha_bwd_bshd_wgmma_pipelined + + +def regression_example_gqa_bwd_tma_reduce_varlen(): + tilelang.testing.process_func(example_gqa_bwd_tma_reduce_varlen.run_regression_perf, name="example_gqa_bwd_tma_reduce_varlen") + + +def regression_example_gqa_bwd(): + tilelang.testing.process_func(example_gqa_bwd.run_regression_perf, name="example_gqa_bwd") + + +def regression_example_gqa_bwd_wgmma_pipelined(): + tilelang.testing.process_func(example_gqa_bwd_wgmma_pipelined.run_regression_perf, name="example_gqa_bwd_wgmma_pipelined") + +def regression_example_mha_bwd_bshd(): + tilelang.testing.process_func(example_mha_bwd_bshd.run_regression_perf, name="example_mha_bwd_bshd") + +def regression_example_mha_bwd_bhsd(): + tilelang.testing.process_func(example_mha_bwd_bhsd.run_regression_perf, name="example_mha_bwd_bhsd") + + +def regression_example_mha_bwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func( + example_mha_bwd_bshd_wgmma_pipelined.run_regression_perf, name="example_mha_bwd_bshd_wgmma_pipelined" + ) + + +def regression_example_gqa_fwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func( + example_gqa_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 + ) + + +def regression_example_gqa_fwd_bshd(): + tilelang.testing.process_func( + example_gqa_fwd_bshd.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 + ) 
+ + +def regression_example_mha_fwd_bhsd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_fwd_bhsd_wgmma_pipelined.run_regression_perf) + + +def regression_example_mha_fwd_bhsd(): + tilelang.testing.process_func(example_mha_fwd_bhsd.run_regression_perf) + + +def regression_example_mha_fwd_bshd_wgmma_pipelined(): + tilelang.testing.process_func(example_mha_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=32, seq_len=256) + +def regression_example_mha_fwd_bshd(): + tilelang.testing.process_func(example_mha_fwd_bshd.run_regression_perf, batch=1, seq_len=256) + +def regression_example_mha_fwd_varlen(): + tilelang.testing.process_func(example_mha_fwd_varlen.run_regression_perf, batch=4, heads=16, seq_len=512, dim=64) + + +if __name__ == "__main__": + tilelang.testing.regression() diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 136a51292..193a4f353 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -483,6 +483,14 @@ def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192 print(f"Ref latency: {ref_latency}") +def run_regression_perf(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128): + batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim + config, _ = get_heuristic_config() + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch", type=int, default=1, help="batch size") diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index d0381bc4a..d958373fd 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -318,5 +318,13 @@ def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) +def run_regression_perf(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): + BLOCK_M = 128 + BLOCK_N = 64 + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(n_warmup=10, n_repeat=10, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/flash_decoding/regression_example_flash_decoding.py b/examples/flash_decoding/regression_example_flash_decoding.py new file mode 100644 index 000000000..aed934d80 --- /dev/null +++ b/examples/flash_decoding/regression_example_flash_decoding.py @@ -0,0 +1,17 @@ +import tilelang.testing +import example_gqa_decode +import example_mha_inference + + +def regression_example_gqa_decode(): + tilelang.testing.process_func(example_gqa_decode.run_regression_perf) + + +def regression_example_mha_inference(): + tilelang.testing.process_func( + example_mha_inference.run_regression_perf, BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False + ) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index b737f30aa..ce2611240 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ 
b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -520,5 +520,121 @@ def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n print("✅ Tilelang and Torch match") +def run_regression_perf( + d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192 +): + config = { + "dhidden": d_hidden, + "dexpert": d_expert, + "nroutedexperts": n_routed_experts, + "nsharedexperts": n_shared_experts, + "nexpertspertoken": n_experts_per_token, + "bs": batch_size, + "seqlen": seq_len, + "seed": 81394, + } + from tilelang.profiler import do_bench + + data = generate_input(**config) + + x, weights, config = data + + dtype_str = "float16" + + shared_kernel = moe_forward_tilelang_shared( + config["d_hidden"], + config["d_expert"], + config["n_shared_experts"], + dtype=dtype_str, + num_tokens=config["batch_size"] * config["seq_len"], + ) + routed_kernel = moe_forward_tilelang_routed( + config["d_hidden"], + config["d_expert"], + config["n_routed_experts"], + dtype=dtype_str, + group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], + group_count=config["n_routed_experts"], + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=2, + ) + + moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) + batch_size, seq_len, hidden_dim = x.shape + expert_indices, expert_scores = moe.gating_network(x) + flat_expert_indices = expert_indices.view(-1) + flat_expert_weights = expert_scores.view(-1) + x_flat = x.view(-1, hidden_dim) + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + tokens_per_expert = counts.cumsum() + num_per_tok = moe.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x_flat[exp_token_idxs] + moe.stacked_expert_tokens[start_idx:end_idx] = expert_tokens + moe.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs + moe.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] + group_sizes = torch.tensor(counts, dtype=torch.int32, device=moe.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=moe.device) + group_padded_offsets = [0 for _ in range(len(group_sizes))] + for i in range(1, len(group_sizes)): + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / moe.padding_M) * moe.padding_M + block_token = 128 + M = ( + math.ceil(moe.config["batch_size"] * moe.config["seq_len"] * moe.config["n_experts_per_token"] / block_token) + + moe.config["n_routed_experts"] + ) + group_idx_for_bx = [0 for _ in range(M)] + for bx in range(M): + m_start_padded = bx * block_token + for i in range(moe.config["n_routed_experts"]): + if m_start_padded >= group_padded_offsets[i]: + group_idx_for_bx[bx] = i + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=moe.device) + group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=moe.device) + + def run_shared_kernel_only(): + moe.routed_kernel( + moe.stacked_expert_tokens, + moe.stacked_expert_w_gate, + moe.stacked_expert_w_up, + moe.stacked_expert_w_down, + moe.stacked_expert_weights, + group_sizes, + group_offset, + 
group_padded_offsets, + group_idx_for_bx, + moe.up_logits_routed, + moe.expert_output_routed, + ) + + def run_routed_kernel_only(): + moe.routed_kernel( + moe.stacked_expert_tokens, + moe.stacked_expert_w_gate, + moe.stacked_expert_w_up, + moe.stacked_expert_w_down, + moe.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + moe.up_logits_routed, + moe.expert_output_routed, + ) + + return do_bench(run_routed_kernel_only, warmup=100, rep=1000, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/fusedmoe/regression_example_fusedmoe.py b/examples/fusedmoe/regression_example_fusedmoe.py new file mode 100644 index 000000000..a473cc901 --- /dev/null +++ b/examples/fusedmoe/regression_example_fusedmoe.py @@ -0,0 +1,19 @@ +import tilelang.testing +import example_fusedmoe_tilelang + + +def regression_example_fusedmoe_tilelang(): + tilelang.testing.process_func( + example_fusedmoe_tilelang.run_regression_perf, + d_hidden=1024, + d_expert=256, + n_routed_experts=8, + n_shared_experts=1, + n_experts_per_token=4, + batch_size=1, + seq_len=1024, + ) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index 2c234d122..827c23373 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -57,5 +57,11 @@ def main(): print(f"tilelang Latency: {latency}ms") +def run_regression_perf(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index badc33402..3e8b0633e 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -228,6 +228,13 @@ def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") +def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 488e5bf6b..2e8ebb114 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -181,5 +181,12 @@ def main(M=4096, N=4096, K=4096): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) +def run_regression_perf(M=4096, N=4096, K=4096): + in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler = kernel.get_profiler() + return profiler.do_bench(profiler.func, warmup=25, backend="cupti") + + if __name__ == "__main__": main(M=4096, N=4096, K=4096) diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index 6fc0e5aac..094bd7cf5 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -126,6 +126,17 @@ def main(M=4096, N=4096, K=4096): print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}") +def 
run_regression_perf(M=4096, N=4096, K=4096): + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 64 + threads = 256 + num_stages = 3 + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + return persistent_profiler.do_bench(warmup=500, backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--M", type=int, default=8192, help="M dimension") diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index d1eb11df5..885a8669e 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -64,5 +64,19 @@ def main(): print(kernel.get_kernel_source()) +def run_regression_perf(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/regression_example_gemm.py b/examples/gemm/regression_example_gemm.py new file mode 100644 index 000000000..992d899f2 --- /dev/null +++ b/examples/gemm/regression_example_gemm.py @@ -0,0 +1,23 @@ +import tilelang.testing +import example_gemm +import example_gemm_autotune +import example_gemm_intrinsics +import example_gemm_schedule + + +def regression_example_gemm_autotune(): + tilelang.testing.process_func(example_gemm_autotune.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_gemm_intrinsics(): + tilelang.testing.process_func(example_gemm_intrinsics.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_gemm_schedule(): + tilelang.testing.process_func(example_gemm_schedule.run_regression_perf) + +def regression_example_gemm(): + tilelang.testing.process_func(example_gemm.run_regression_perf) + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index 1ecd344bc..3daa532b3 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -60,5 +60,18 @@ def main(): test_gemm_fp8(1024, 1024, 1024, "float8_e5m2") +def run_regression_perf(): + M, N, K = 1024, 1024, 1024 + dtype = "float8_e4m3" + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(warmup=25, backend="cupti") + dtype = "float8_e5m2" + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(warmup=25, backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 3af4c3d6d..ab0afa435 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -78,5 +78,18 @@ def main(): test_gemm_fp8(1024, 1024, 8192, "float8_e5m2") +def run_regression_perf(): + M, N, K = 1024, 1024, 8192 + dtype = "float8_e4m3" + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + 
profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(warmup=25, backend="cupti") + dtype = "float8_e5m2" + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(warmup=25, backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 6e2d41be8..81aafce36 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -224,5 +224,19 @@ def main(): assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") +def run_regression_perf(): + M, N, K = 128, 128, 128 + out_dtype, accum_dtype = "float32", "float32" + in_dtype = "float8_e4m3" + kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(warmup=25, backend="cupti") + in_dtype = "float8_e5m2" + kernel_e5m2 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(warmup=25, backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/gemm_fp8/regression_example_gemm_fp8.py b/examples/gemm_fp8/regression_example_gemm_fp8.py new file mode 100644 index 000000000..8b711560f --- /dev/null +++ b/examples/gemm_fp8/regression_example_gemm_fp8.py @@ -0,0 +1,19 @@ +import tilelang.testing +import example_tilelang_gemm_fp8 +import example_tilelang_gemm_fp8_2xAcc +import example_tilelang_gemm_fp8_intrinsic + + +def regression_example_tilelang_gemm_fp8_2xAcc(): + tilelang.testing.process_func(example_tilelang_gemm_fp8_2xAcc.run_regression_perf) + + +def regression_example_tilelang_gemm_fp8_intrinsic(): + tilelang.testing.process_func(example_tilelang_gemm_fp8_intrinsic.run_regression_perf) + + +def regression_example_tilelang_gemm_fp8(): + tilelang.testing.process_func(example_tilelang_gemm_fp8.run_regression_perf) + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index 320a699c5..99c584d5c 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -56,5 +56,28 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index dfd847101..5860a7fb2 100644 --- 
a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -55,5 +55,29 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def run_regression_perf(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/regression_example_gemm_splitk.py b/examples/gemm_splitk/regression_example_gemm_splitk.py new file mode 100644 index 000000000..42fa192dd --- /dev/null +++ b/examples/gemm_splitk/regression_example_gemm_splitk.py @@ -0,0 +1,14 @@ +import tilelang.testing +import example_tilelang_gemm_splitk +import example_tilelang_gemm_splitk_vectorize_atomicadd + + +def regression_example_tilelang_gemm_splitk(): + tilelang.testing.process_func(example_tilelang_gemm_splitk.run_regression_perf) + + +def regression_example_tilelang_gemm_splitk_vectorize_atomicadd(): + tilelang.testing.process_func(example_tilelang_gemm_splitk_vectorize_atomicadd.run_regression_perf) + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 2d83586e5..ecc545f93 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -199,5 +199,32 @@ def main(): torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) +def run_regression_perf(): + kernel = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + "float16", + "float16", + "float32", + 2, + 64, + ) + b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + kernel(A, B, b_c) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A, B, b_c) + + return do_bench(run_kernel_only, warmup=10, rep=100, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm_streamk/regression_example_tilelang_gemm_splitk.py b/examples/gemm_streamk/regression_example_tilelang_gemm_splitk.py new file mode 100644 index 000000000..c2baa5533 --- /dev/null +++ b/examples/gemm_streamk/regression_example_tilelang_gemm_splitk.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_tilelang_gemm_streamk + + +def regression_example_tilelang_gemm_streamk(): + tilelang.testing.process_func(example_tilelang_gemm_streamk.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 00cbac067..350074f43 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -362,5 +362,23 @@ def main(do_bench: bool = True): print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n") +def run_regression_perf(): + N, K = 1024, 1024 + latency = 0.0 + kernel_list = [ + naive_gemv(N, K, 128, 128), + naive_splitk_gemv(N, K, 32, 32), + splitk_gemv(N, K, 32, 32, 32), + splitk_gemv_vectorized(N, K, 2, 
32), + splitk_gemv_vectorized_tvm(N, K, 2, 32), + gemv_alloc_reducer(N, K, block_M=128, block_N=128), + ] + for kernel in kernel_list: + profiler = kernel.get_profiler() + # Benchmark the TileLang kernel itself, not the PyTorch reference. + latency += profiler.do_bench(warmup=50, backend="cupti") + return latency / len(kernel_list) + + if __name__ == "__main__": main() diff --git a/examples/gemv/regression_example_gemv.py b/examples/gemv/regression_example_gemv.py new file mode 100644 index 000000000..59f92f0dc --- /dev/null +++ b/examples/gemv/regression_example_gemv.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_gemv + + +def regression_example_gemv(): + tilelang.testing.process_func(example_gemv.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 7cbfc465a..fa21ad44e 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -192,6 +192,21 @@ def main(B=1, S=1024, H=16, D=128): print(f"Speedup: {t1 / t2:.3f}x") +def run_regression_perf(B=1, S=1024, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + q = l2norm_fwd(q)[0].requires_grad_(True) + k = l2norm_fwd(k)[0].requires_grad_(True) + kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D) + dQ = torch.zeros_like(q, dtype=torch.float32) + dK = torch.zeros_like(k, dtype=torch.float32) + dV = torch.zeros_like(v, dtype=torch.float32) + kernel(q, k, v, do, dQ, dK, dV) + return do_bench(lambda: kernel(q, k, v, do, dQ, dK, dV), backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--B", type=int, default=8, help="Batch size") diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 3d28f92b0..66e5d632a 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -138,6 +138,18 @@ def main(B=1, S=512, H=16, D=128): print(f"Speedup: {t1 / t2:.3f}x") +def run_regression_perf(B=1, S=512, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + B, S, H, D = q.shape + kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) + return do_bench(lambda: kernel(q, k, v, o), backend="cupti") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--B", type=int, default=8, help="Batch size") diff --git a/examples/linear_attention/regression_linear_attn.py b/examples/linear_attention/regression_linear_attn.py new file mode 100644 index 000000000..b51759576 --- /dev/null +++ b/examples/linear_attention/regression_linear_attn.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_linear_attn_bwd +import example_linear_attn_fwd + + +def regression_example_linear_attn_fwd(): + 
tilelang.testing.process_func(example_linear_attn_fwd.run_regression_perf) + + +def regression_example_linear_attn_bwd(): + tilelang.testing.process_func(example_linear_attn_bwd.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index 6600bb5ed..73656b974 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -619,5 +619,78 @@ def main(argv=None): print(f"speedup: {triton_time / tilelang_time:.2f}x") +def run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=16384) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--vertical_size", type=int, default=1000) + parser.add_argument("--slash_size", type=int, default=200) + args = parser.parse_args(argv) + BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim + vertical_size, slash_size = args.vertical_size, args.slash_size + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + q_len = SEQ_LEN + vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) + last_q = 64 + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) + arange = torch.arange(last_q, device="cuda") + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) + qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + vertical_topk = torch.topk(vertical, vertical_size, -1).indices + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1] + slash[..., -30:] = torch.inf + slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices + block_size_M = 64 + block_size_N = 64 + query, key, value = q, k, v + v_idx, s_idx = vertical_topk, slash + batch_size, num_heads, context_size, head_dim = query.shape + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + from torch.utils.cpp_extension import load + import os + + current_dir = os.path.dirname(os.path.abspath(__file__)) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) + convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes + batch_size, num_heads, context_size, head_dim = query.shape + pad = (block_size_M - context_size) & (block_size_M - 1) + if pad == block_size_M: + pad = 0 + query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim + query = 
torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) + block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes( + seqlens, + v_idx, + s_idx, + context_size, + block_size_M, + block_size_N, + ) + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, vertical_topk.shape[-1], slash.shape[-1]) + + def run_kernel_only(): + tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/minference/regression_vs_sparse_attn.py b/examples/minference/regression_vs_sparse_attn.py new file mode 100644 index 000000000..6c2bf8317 --- /dev/null +++ b/examples/minference/regression_vs_sparse_attn.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_vertical_slash_sparse_attn + + +def regression_example_vertical_slash_sparse_attn(): + tilelang.testing.process_func(example_vertical_slash_sparse_attn.run_regression_perf, argv=[]) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index f5f7fe7ba..414cee7db 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -250,5 +250,56 @@ def main(): test_topk_sparse_attention_qlen_lt_klen() +def run_regression_perf(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q, k, v, block_mask.to(torch.int8)) + + latency_1 = do_bench(run_kernel_only, backend="cupti") + + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 + TOPK = 1 + BLOCK = 64 + torch.manual_seed(0) + + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = 
blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + print(kernel.get_kernel_source()) + + def run_kernel_only2(): + kernel(q, k, v, block_mask.to(torch.int8)) + + latency_2 = do_bench(run_kernel_only2, backend="cupti") + + return (latency_1 + latency_2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/seer_attention/regression_block_sparse_attn_tilelang.py b/examples/seer_attention/regression_block_sparse_attn_tilelang.py new file mode 100644 index 000000000..3b6cf8701 --- /dev/null +++ b/examples/seer_attention/regression_block_sparse_attn_tilelang.py @@ -0,0 +1,10 @@ +import tilelang.testing +import block_sparse_attn_tilelang + + +def regression_block_sparse_attn_tilelang(): + tilelang.testing.process_func(block_sparse_attn_tilelang.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/sparse_tensorcore/regression_example_sparse_tensorcore.py b/examples/sparse_tensorcore/regression_example_sparse_tensorcore.py new file mode 100644 index 000000000..dfaa8b18e --- /dev/null +++ b/examples/sparse_tensorcore/regression_example_sparse_tensorcore.py @@ -0,0 +1,11 @@ +import tilelang.testing +import tilelang +import tilelang_example_sparse_tensorcore + + +def regression_example_sparse_tensorcore(): + tilelang.testing.process_func(tilelang_example_sparse_tensorcore.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 6c37dc09c..5f5b3671d 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -114,5 +114,44 @@ def main(): run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) +def run_regression_perf(): + M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, num_threads = ( + 512, + 1024, + 768, + 128, + 128, + 128, + "float16", + "float16", + "float32", + 2, + 128, + ) + kernel = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + ) + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") + A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A_sparse, E, B) + + return do_bench(run_kernel_only, backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py index c0cf09bc0..f4afceac9 100644 --- a/examples/topk/example_topk.py +++ b/examples/topk/example_topk.py @@ -89,5 +89,29 @@ def main(argv=None): print(f"Tilelang latency: {tilelang_latency}") +def run_regression_perf(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=320, help="num_tokens") + parser.add_argument("--N", type=int, default=128, help="num_experts") + parser.add_argument("--topk", type=int, default=6, help="topk") + parser.add_argument("--blk_m", type=int, default=64, help="blk_m") + # In benchmark mode, ignore process-wide sys.argv unless an explicit argv is provided. 
+ args = parser.parse_args(argv or []) + M, N, topk, blk_m = args.M, args.N, args.topk, args.blk_m + + logits = torch.rand((M, N), device="cuda", dtype=torch.float32) + + kernel = tl_topk(M=M, N=N, topk=topk, blk_m=blk_m) + tl_gates, tl_indices = kernel(logits) + + torch_gates, torch_indices = ref_program(logits, topk) + + torch.testing.assert_close(tl_gates, torch_gates) + torch.testing.assert_close(tl_indices, torch_indices) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/topk/regression_topk_tilelang.py b/examples/topk/regression_topk_tilelang.py new file mode 100644 index 000000000..2c0dea1bd --- /dev/null +++ b/examples/topk/regression_topk_tilelang.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_topk + + +def regression_example_topk(): + tilelang.testing.process_func(example_topk.run_regression_perf) + + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py index 5d438b5de..719e2bb06 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -82,5 +82,15 @@ def main(M=16384, N=16384, K=16384): print(f"Latency: {latency} ms") +def run_regression_perf(M=16384, N=16384, K=16384): + tilelang.disable_cache() + block_M = 128 + block_N = 128 + block_K = 64 + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py index 03ddf8122..f7f597627 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -74,5 +74,27 @@ def main(M=1024, N=1024, K=1024): print(f"Latency: {latency} ms") +def run_regression_perf(M=1024, N=1024, K=1024): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + c = jit_kernel(a, b) + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py index 63aed2bed..7f6ba55e4 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -75,5 +75,28 @@ def main(M=16384, N=16384, K=16384): print(f"Latency: {latency} ms") +def run_regression_perf(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) + + import torch + + 
a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + c = jit_kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py index f3f8a665b..bc20ea742 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -78,5 +78,28 @@ def main(M=16384, N=16384, K=16384): print(f"Latency: {latency} ms") +def run_regression_perf(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + c = jit_kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/warp_specialize/regression_example_warp_specialize.py b/examples/warp_specialize/regression_example_warp_specialize.py new file mode 100644 index 000000000..68b4d913e --- /dev/null +++ b/examples/warp_specialize/regression_example_warp_specialize.py @@ -0,0 +1,23 @@ +import tilelang.testing +import example_warp_specialize_gemm_barrierpipe_stage2 +import example_warp_specialize_gemm_copy_0_gemm_1 +import example_warp_specialize_gemm_copy_1_gemm_0 +import example_warp_specialize_gemm_softpipe_stage2 + + +def regression_example_warp_specialize_gemm_barrierpipe_stage2(): + tilelang.testing.process_func(example_warp_specialize_gemm_barrierpipe_stage2.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_warp_specialize_gemm_copy_0_gemm_1(): + tilelang.testing.process_func(example_warp_specialize_gemm_copy_0_gemm_1.run_regression_perf, M=1024, N=1024, K=1024) + +def regression_example_warp_specialize_gemm_copy_1_gemm_0(): + tilelang.testing.process_func(example_warp_specialize_gemm_copy_1_gemm_0.run_regression_perf, M=1024, N=1024, K=1024) + + +def regression_example_warp_specialize_gemm_softpipe_stage2(): + tilelang.testing.process_func(example_warp_specialize_gemm_softpipe_stage2.run_regression_perf, M=1024, N=1024, K=1024) + +if __name__ == "__main__": + tilelang.testing.regression() \ No newline at end of file diff --git a/maint/scripts/ci_performance.py b/maint/scripts/ci_performance.py deleted file mode 100644 index 8a353c0a9..000000000 --- a/maint/scripts/ci_performance.py +++ /dev/null @@ -1,42 +0,0 @@ -import subprocess -import re -from tabulate import tabulate - -import os - -env = os.environ.copy() -env["TILELANG_CLEAR_CACHE"] = "1" - - -def parse_output(output): - data = {} - for line in output.split("\n"): - line = line.strip() - if line.startswith("Latency:"): - match = re.search(r"Latency: ([\d.]+)", line) - data["latency"] = match.group(1) if match else "N/A" - elif line.startswith("TFlops:"): - match = re.search(r"TFlops: ([\d.]+)", line) - data["best_tflops"] = match.group(1) if match else "N/A" - elif line.startswith("Config:"): - 
data["config"] = line.split("Config: ")[-1] - elif line.startswith("Reference TFlops:"): - match = re.search(r"Reference TFlops: ([\d.]+)", line) - data["ref_tflops"] = match.group(1) if match else "N/A" - return data - - -output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout -data_v1 = parse_output(output_v1) - -output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout -data_v2 = parse_output(output_v2) - -table = [ - ["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]], - ["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]], -] - -headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"] - -print(tabulate(table, headers=headers, tablefmt="github", stralign="left", numalign="decimal")) diff --git a/maint/scripts/performance.py b/maint/scripts/performance.py deleted file mode 100644 index 849bcf362..000000000 --- a/maint/scripts/performance.py +++ /dev/null @@ -1,95 +0,0 @@ -import argparse -import tilelang.language as T -from tilelang.autotuner import AutoTuner - - -def ref_program(A, B): - return A @ B.T - - -def get_configs(): - configs = [ - { - "block_M": 128, - "block_N": 128, - "block_K": 64, - "num_stages": 2, - "thread_num": 256, - "enable_rasteration": True, # keep param name for backward-compat - } - ] - return configs - - -def run(M, N, K): - def kernel( - block_M=None, - block_N=None, - block_K=None, - num_stages=None, - thread_num=None, - enable_rasteration=None, - ): - dtype = "float16" - accum_dtype = "float" - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - T.use_swizzle(panel_size=10, enable=enable_rasteration) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - autotuner = ( - AutoTuner.from_kernel(kernel=kernel, configs=get_configs()) - .set_compile_args( - out_idx=[-1], - target="auto", - ) - .set_profile_args( - ref_prog=ref_program, - ) - ) - return autotuner.run(warmup=3, rep=20) - - -if __name__ == "__main__": - # Parse command-line arguments for matrix dimensions - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - args = parser.parse_args() - - M, N, K = args.m, args.n, args.k - - # Compute total floating-point operations to measure throughput - total_flops = 2 * M * N * K - - result = run(M, N, K) - - print(f"Latency: {result.latency}") - print(f"TFlops: {total_flops / result.latency * 1e-9:.3f}") - print(f"Config: {result.config}") - - 
print(f"Reference TFlops: {total_flops / result.ref_latency * 1e-9:.3f}")
diff --git a/maint/scripts/test_perf_regression.py b/maint/scripts/test_perf_regression.py
new file mode 100644
index 000000000..0ff930d5c
--- /dev/null
+++ b/maint/scripts/test_perf_regression.py
@@ -0,0 +1,88 @@
+import subprocess
+import re
+import os
+from tabulate import tabulate
+import pandas as pd
+
+try:
+    import tilelang
+
+    tilelang.disable_cache()
+except Exception:
+    tilelang = None
+
+OLD_PYTHON = os.environ.get("OLD_PYTHON", "./old/bin/python")
+NEW_PYTHON = os.environ.get("NEW_PYTHON", "./new/bin/python")
+OUT_MD = os.environ.get("PERF_REGRESSION_MD", "regression_result.md")
+OUT_PNG = os.environ.get("PERF_REGRESSION_PNG", "regression_result.png")
+
+def parse_output(output):
+    data = {}
+    for line in output.split("\n"):
+        line = line.strip()
+        m = re.match(r"\|\s*([^\|]+)\s*\|\s*([0-9\.]+)\s*\|", line)
+        if m is not None:
+            data[m.group(1).strip()] = float(m.group(2))
+    return data
+
+
+def run_cmd(cmd):
+    p = subprocess.run(cmd, capture_output=True, text=True)
+    if p.returncode != 0:
+        raise RuntimeError(f"Command failed: {' '.join(cmd)}\nSTDOUT:\n{p.stdout}\nSTDERR:\n{p.stderr}")
+    return p.stdout
+
+
+def draw(df):
+    import matplotlib.pyplot as plt
+    import seaborn as sns
+
+    plt.figure(figsize=(max(len(df) * 2.2, 6), 20))
+    sns.set_theme(style="whitegrid", font_scale=0.9)
+    top3_idx = df.nlargest(min(3, len(df)), "Speedup").index
+    bot3_idx = df.nsmallest(min(3, len(df)), "Speedup").index
+    label_idx = set(top3_idx.tolist() + bot3_idx.tolist())
+    plt.bar(range(len(df)), df["Speedup"], color="steelblue")  # bars carrying the speedup values annotated below
+
+    for i, val in enumerate(df["Speedup"]):
+        if i in label_idx:
+            plt.text(i, val + 0.02, f"{val:.2f}x", ha="center", va="bottom", color="red", fontsize=8, fontweight="bold")
+
+    plt.xticks(range(len(df)), df["File"], rotation=70, ha="right", fontsize=12)
+    plt.ylabel("Current Speedup vs Original", fontsize=14)
+    plt.title("Current Speedup vs Original", fontsize=14, fontweight="bold")
+    plt.ylim(0, max(df["Speedup"]) * 1.2)
+    sns.despine()
+    plt.tight_layout()
+    plt.savefig(OUT_PNG, dpi=300)
+
+
+output_v1 = run_cmd([OLD_PYTHON, "-c", "import tilelang.testing.perf_regression as pr; pr.regression_all()"])
+output_v2 = run_cmd([NEW_PYTHON, "-c", "import tilelang.testing.perf_regression as pr; pr.regression_all()"])
+
+data_v1 = parse_output(output_v1)
+data_v2 = parse_output(output_v2)
+
+common_keys = sorted(set(data_v1) & set(data_v2))
+if not common_keys:
+    raise RuntimeError("No common entries between old and new versions")
+
+table = []
+for key in common_keys:
+    speedup = data_v1[key] / data_v2[key]
+    table.append([key, data_v1[key], data_v2[key], speedup])
+
+if not table:
+    raise RuntimeError("All results are invalid (<= 0)")
+
+table.sort(key=lambda x: x[-1])
+
+headers = ["File", "Original Latency", "Current Latency", "Speedup"]
+
+with open(OUT_MD, "w") as f:
+    f.write(tabulate(table, headers=headers, tablefmt="github", stralign="left", numalign="decimal"))
+    f.write("\n")
+
+df = pd.DataFrame(table, columns=headers)
+df = df.sort_values("Speedup", ascending=False).reset_index(drop=True)
+draw(df)
diff --git a/requirements-test.txt b/requirements-test.txt
index 38bdf2d7b..938b3d034 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -30,3 +30,5 @@ scipy
 tabulate
 tornado
 wheel
+matplotlib
+seaborn
\ No newline at end of file
diff --git a/testing/test_perf_regression_runner.py b/testing/test_perf_regression_runner.py
new file mode 100644
index 000000000..46a74e41a
--- /dev/null
+++ b/testing/test_perf_regression_runner.py
@@ -0,0 +1,78 @@
+from __future__ import annotations + +import importlib.util +import runpy +import sys +import types +from pathlib import Path + + +def _load_perf_module(monkeypatch): + """Load perf_regression directly while stubbing the tilelang package to avoid heavy test dependencies.""" + module_path = Path(__file__).resolve().parents[1] / "tilelang/testing/perf_regression.py" + spec = importlib.util.spec_from_file_location("tilelang.testing.perf_regression", module_path) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules["tilelang.testing.perf_regression"] = module + spec.loader.exec_module(module) + + tilelang_pkg = types.ModuleType("tilelang") + testing_pkg = types.ModuleType("tilelang.testing") + testing_pkg.process_func = module.process_func # type: ignore[attr-defined] + testing_pkg.regression = module.regression # type: ignore[attr-defined] + testing_pkg.perf_regression = module # type: ignore[attr-defined] + tilelang_pkg.testing = testing_pkg # type: ignore[attr-defined] + + monkeypatch.setitem(sys.modules, "tilelang", tilelang_pkg) + monkeypatch.setitem(sys.modules, "tilelang.testing", testing_pkg) + monkeypatch.setitem(sys.modules, "tilelang.testing.perf_regression", module) + + return module + + +def test_run_bench_file_executes_regressions(monkeypatch, tmp_path): + perf = _load_perf_module(monkeypatch) + bench_file = tmp_path / "regression_sample.py" + bench_file.write_text( + "import tilelang.testing\n" + "\n" + "def regression_sample():\n" + " tilelang.testing.process_func(lambda: 1.0, 'sample')\n", + encoding="utf-8", + ) + + perf._reset_results() + perf._run_bench_file(bench_file) + + assert perf._results_to_jsonable() == [{"name": "sample", "latency": 1.0}] + + +def test_regression_all_uses_pytest_wrapper(monkeypatch, tmp_path): + perf = _load_perf_module(monkeypatch) + bench_file = tmp_path / "regression_sample.py" + bench_file.write_text( + "import tilelang.testing\n" + "\n" + "def regression_sample():\n" + " tilelang.testing.process_func(lambda: 2.5, 'sample')\n", + encoding="utf-8", + ) + + calls: dict[str, list[str]] = {} + + def fake_pytest_main(args, _plugins=None): + # _plugins unused in mock; kept for signature compatibility with pytest.main + calls["args"] = args + module_vars = runpy.run_path(args[0]) + for name, fn in module_vars.items(): + if name.startswith("test_perf_regression_") and callable(fn): + fn() + return 0 + + monkeypatch.setitem(sys.modules, "pytest", types.SimpleNamespace(main=fake_pytest_main)) + + perf._reset_results() + perf.regression_all(examples_root=tmp_path) + + assert Path(calls["args"][0]).name.startswith("test_perf_regression_wrapper") + assert perf._results_to_jsonable() == [{"name": "sample", "latency": 2.5}] diff --git a/tilelang/testing/__init__.py b/tilelang/testing/__init__.py index 635fad365..634dc9430 100644 --- a/tilelang/testing/__init__.py +++ b/tilelang/testing/__init__.py @@ -8,6 +8,7 @@ from tvm.testing.utils import requires_cuda, requires_package, requires_llvm, requires_metal, requires_rocm, _compose from tilelang.utils.tensor import torch_assert_close as torch_assert_close +from .perf_regression import regression_all, process_func, regression __all__ = [ "requires_package", @@ -118,4 +119,4 @@ def requires_cuda_compute_version_lt(major_version, minor_version=0): def requires_cuda_compute_version_le(major_version, minor_version=0): - return requires_cuda_compute_version(major_version, minor_version, mode="le") + return requires_cuda_compute_version(major_version, 
minor_version, mode="le") \ No newline at end of file diff --git a/tilelang/testing/perf_regression.py b/tilelang/testing/perf_regression.py new file mode 100644 index 000000000..3cc4b3aed --- /dev/null +++ b/tilelang/testing/perf_regression.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +import contextlib +import importlib.util +import hashlib +import inspect +import json +import os +import sys +from dataclasses import dataclass +from pathlib import Path +import tempfile +from typing import Any, Callable, Iterable, Sequence + +try: + from tabulate import tabulate +except Exception: # pragma: no cover + tabulate = None # type: ignore + +@dataclass(frozen=True) +class PerfResult: + name: str + latency: float + + +_RESULTS: list[PerfResult] = [] + + +_RESULTS_JSON_PREFIX = "__TILELANG_PERF_RESULTS_JSON__=" + + +def _results_to_jsonable() -> list[dict[str, float | str]]: + return [{"name": r.name, "latency": r.latency} for r in _RESULTS] + + +def _emit_results() -> None: + """Emit results for parent collectors. + + Default output remains the historical text format. Set + `TL_PERF_REGRESSION_FORMAT=json` to emit a single JSON marker line which is + robust against extra prints from benchmark code. + """ + fmt = os.environ.get("TL_PERF_REGRESSION_FORMAT", "text").strip().lower() + if fmt == "json": + print(_RESULTS_JSON_PREFIX + json.dumps(_results_to_jsonable(), separators=(",", ":"))) + return + # Fallback (human-readable): one result per line. + for r in _RESULTS: + print(f"{r.name}: {r.latency}") + + +def _reset_results() -> None: + _RESULTS.clear() + + +@contextlib.contextmanager +def _pushd(path: Path) -> Iterable[None]: + """Temporarily change working directory (process-wide; avoid in concurrent contexts).""" + cwd = Path.cwd() + os.chdir(path) + try: + yield + finally: + os.chdir(cwd) + + +@contextlib.contextmanager +def _prepend_sys_path(path: Path) -> Iterable[None]: + orig = list(sys.path) + sys.path.insert(0, str(path)) + try: + yield + finally: + sys.path[:] = orig + + +def _iter_regression_functions(namespace: dict[str, Any], prefixes: Sequence[str]) -> Iterable[tuple[str, Callable[..., Any]]]: + for k, v in namespace.items(): + if not callable(v): + continue + if any(k.startswith(p) for p in prefixes): + yield k, v + + +def _run_bench_file(bench_file: Path, *, prefixes: Sequence[str] = ("regression_",)) -> None: + bench_file = bench_file.resolve() + if not bench_file.is_file(): + raise FileNotFoundError(f"Benchmark driver not found: {bench_file}") + + with _pushd(bench_file.parent), _prepend_sys_path(bench_file.parent): + module_tag = hashlib.sha256(str(bench_file).encode("utf-8")).hexdigest()[:12] + parent_stem = bench_file.parent.name.replace("-", "_") or "root" + stem = bench_file.stem.replace("-", "_") + module_name = f"tilelang.testing.perf_regression.bench_{parent_stem}_{stem}_{module_tag}" + spec = importlib.util.spec_from_file_location(module_name, bench_file) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot import benchmark driver: {bench_file}") + module = importlib.util.module_from_spec(spec) + prev = sys.modules.get(module_name) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + + for _, fn in sorted(_iter_regression_functions(module.__dict__, prefixes), key=lambda kv: kv[0]): + fn() + finally: + if prev is None: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = prev + + +def _build_pytest_wrapper(bench_files: Sequence[Path]) -> str: + lines = [ + "from pathlib import Path", + "import 
tilelang.testing.perf_regression as _pr", + "", + "def _make_test(path_str):", + " path = Path(path_str)", + " def _inner():", + " _pr._run_bench_file(path)", + " return _inner", + "", + ] + + for idx, bench in enumerate(bench_files): + lines.append(f"test_perf_regression_{idx} = _make_test({str(bench)!r})") + + lines.append("") + return "\n".join(lines) + + +def process_func(func: Callable[..., float], name: str | None = None, /, **kwargs: Any) -> float: + """Execute a single perf function and record its latency. + + `func` is expected to return a positive latency scalar (seconds or ms; we + treat it as an opaque number, only ratios matter for regression). + """ + result_name = getattr(func, "__module__", "") if name is None else name + if result_name.startswith("regression_"): + result_name = result_name[len("regression_") :] + latency = float(func(**kwargs)) + if not (latency > 0.0): + raise ValueError(f"Invalid latency from {result_name}: {latency}") + _RESULTS.append(PerfResult(name=result_name, latency=latency)) + return latency + + +def regression(prefixes: Sequence[str] = ("regression_",)) -> None: + """Run entrypoints in the caller module and print a markdown table. + + This is invoked by many example scripts. + """ + + caller_globals = inspect.currentframe().f_back.f_globals # type: ignore[union-attr] + + _reset_results() + functions: list[tuple[str, Callable[[], Any]]] = [] + for k, v in list(caller_globals.items()): + if not callable(v): + continue + if any(k.startswith(p) for p in prefixes): + functions.append((k, v)) + + for _, fn in sorted(functions, key=lambda kv: kv[0]): + fn() + + _emit_results() + +def _parse_table(output: str) -> dict[str, float]: + # Prefer a single JSON marker line if present. + for line in reversed(output.splitlines()): + if line.startswith(_RESULTS_JSON_PREFIX): + payload = line[len(_RESULTS_JSON_PREFIX) :].strip() + items = json.loads(payload) + data: dict[str, float] = {} + for item in items: + name = str(item["name"]).strip() + latency = float(item["latency"]) + data[name] = latency + return data + + # Backward-compatible text parsing (best-effort). + data = {} + for line in output.splitlines(): + line = line.strip() + if not line or ":" not in line: + continue + name, _, val = line.partition(":") + name = name.strip() + val = val.strip() + if not name: + continue + try: + data[name] = float(val) + except ValueError: + # Ignore unrelated prints/logs. + continue + return data + + +def _examples_root() -> Path: + # repo_root/tilelang/testing/perf_regression.py -> repo_root + return Path(__file__).resolve().parents[2] / "examples" + + +def _discover_bench_files(examples_root: Path) -> list[Path]: + patterns = ("regression_*.py",) + files: list[Path] = [] + for pat in patterns: + files.extend(examples_root.rglob(pat)) + # Avoid picking up things like __pycache__ etc. + return sorted({p for p in files if p.is_file() and p.name != "__init__.py"}) + + +def regression_all(examples_root: str | os.PathLike[str] | None = None, *, pytest_args: Sequence[str] | None = None) -> None: + """Run all example benchmark drivers and print a consolidated table. + + Intended usage (CI): `python -c "import tilelang.testing.perf_regression as pr; pr.regression_all()"` + Additional pytest arguments can be passed via `pytest_args`. 
+ """ + + root = Path(examples_root) if examples_root is not None else _examples_root() + if not root.exists(): + raise FileNotFoundError(f"Examples root not found: {root}") + + bench_files = [p.resolve() for p in _discover_bench_files(root)] + if not bench_files: + raise RuntimeError(f"No benchmark drivers found under: {root}") + + _reset_results() + wrapper_source = _build_pytest_wrapper(bench_files) + merged: dict[str, float] = {} + with tempfile.TemporaryDirectory() as td: + wrapper = Path(td) / "test_perf_regression_wrapper.py" + wrapper.write_text(wrapper_source, encoding="utf-8") + + try: + import pytest # type: ignore + except ImportError as exc: # pragma: no cover - tested via stubbed import + raise RuntimeError("pytest is required to run perf regression suite. Install with: pip install pytest") from exc + + # Disable output capturing so benchmark progress remains visible. + args = [str(wrapper), "-s"] + if pytest_args: + args.extend(pytest_args) + + exit_code = pytest.main(args) + + for res in _RESULTS: + if res.name not in merged: + merged[res.name] = res.latency + + if not merged: + if exit_code != 0: + raise RuntimeError("All benchmark drivers failed") + raise RuntimeError("No benchmark results collected") + if exit_code != 0: + # Don't hard-fail if we have some results; pytest already reported details. + print("# Some benchmark drivers failed (partial results)") + + rows = [[k, merged[k]] for k in sorted(merged.keys())] + headers = ["File", "Latency"] + if tabulate is None: + print(f"| {headers[0]} | {headers[1]} |") + print("|---|---|") + for name, latency in rows: + print(f"| {name} | {latency} |") + else: + print(tabulate(rows, headers=headers, tablefmt="github", stralign="left", numalign="decimal"))
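For a local spot check outside CI, the suite added by this patch can be driven the same way maint/scripts/test_perf_regression.py drives each interpreter. The sketch below is not part of the patch: it calls the new regression_all() entry point with the current interpreter and parses the github-style latency table it prints, mirroring parse_output() from the CI script; the variable names here are arbitrary.

# Minimal local sketch (assumes this repo's tilelang is importable in the active environment).
import re
import subprocess

out = subprocess.run(
    ["python", "-c", "import tilelang.testing.perf_regression as pr; pr.regression_all()"],
    capture_output=True,
    text=True,
    check=True,
).stdout

# Collect {example name: latency} from the "| File | Latency |" rows printed by regression_all().
latencies = {}
for line in out.splitlines():
    m = re.match(r"\|\s*([^|]+?)\s*\|\s*([0-9.]+)\s*\|", line.strip())
    if m:
        latencies[m.group(1)] = float(m.group(2))

print(latencies)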