# Autotuning

TileLang includes a built‑in autotuner that searches configuration spaces
for the best-performing kernel, compiles candidates in parallel, validates
correctness, benchmarks them, and caches the best result for reuse.

This guide covers two workflows:
- Decorator‑based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit`
- Programmatic: `AutoTuner.from_kernel(...).set_*().run()`

It also explains input tensor supply, validation, caching, and environment
variables that affect parallelism and cache behavior.

## 1) Decorator‑based Autotune

Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as
function arguments with defaults. The autotuner overrides these parameters with
values from your config space.

```python
import tilelang
import tilelang.language as T

def matmul_configs(M, N, K):
    # Example space — tailor to your target
    tiles = [64, 128]
    stages = [2, 3]
    threads = [128, 256]
    return [
        dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH)
        for BM in tiles
        for BN in tiles
        for BK in [32, 64]
        for S in stages
        for TH in threads
    ]

@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60)
@tilelang.jit  # no out_idx here: C is passed in and written in place, matching the calls below
def matmul(M: int, N: int, K: int,
           block_M: int = 128, block_N: int = 128, block_K: int = 32,
           threads: int = 128, num_stages: int = 3,
           dtype: str = 'float16', accum_dtype: str = 'float32'):

    @T.prim_func
    def kernel(A: T.Tensor((M, K), dtype),
               B: T.Tensor((K, N), dtype),
               C: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_s = T.alloc_shared((block_M, block_K), dtype)
            B_s = T.alloc_shared((block_K, block_N), dtype)
            C_f = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_f)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, ko * block_K], A_s)
                T.copy(B[ko * block_K, bx * block_N], B_s)
                T.gemm(A_s, B_s, C_f)

            T.copy(C_f, C[by * block_M, bx * block_N])

    return kernel

# Usage
# Provide inputs via context (recommended for reproducibility across configs)
import torch
M = N = K = 1024
A = torch.randn(M, K, device='cuda', dtype=torch.float16)
B = torch.randn(K, N, device='cuda', dtype=torch.float16)
C = torch.empty(M, N, device='cuda', dtype=torch.float16)

from tilelang.autotuner import set_autotune_inputs
with set_autotune_inputs(A, B, C):
    tuned_kernel = matmul(M, N, K)  # compiles, tunes, returns best kernel
tuned_kernel(A, B, C)  # run best kernel
```

Notes
- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each
  dict’s keys must match the tunable function arguments (e.g., `block_M`). A
  static-list sketch follows below.
- The decorator returns a callable that runs autotune once per argument tuple
  and caches the resulting best kernel in‑process.
- For explicit input control during tuning, wrap the call with
  `set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used.
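
For quick experiments, a plain list of dicts works just as well as a callable; a minimal sketch that enumerates the same space as `matmul_configs` with `itertools.product` (values are illustrative):

```python
import itertools

# Static config list equivalent in spirit to matmul_configs above.
static_configs = [
    dict(block_M=bm, block_N=bn, block_K=bk, num_stages=s, threads=th)
    for bm, bn, bk, s, th in itertools.product(
        [64, 128],   # block_M
        [64, 128],   # block_N
        [32, 64],    # block_K
        [2, 3],      # num_stages
        [128, 256],  # threads
    )
]
```

Pass it directly as `@tilelang.autotune(configs=static_configs, ...)`.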

## 2) Programmatic Autotune

Use the `AutoTuner` class to manage configs and arguments more explicitly.

```python
from tilelang.autotuner import AutoTuner
import functools

# Bind the problem size so only the tunable parameters remain free. Use the
# @tilelang.jit-decorated builder here (drop @tilelang.autotune for this path).
kernel_factory = functools.partial(matmul, M, N, K)
tuner = AutoTuner.from_kernel(kernel_factory, configs=matmul_configs(M, N, K))

tuner.set_profile_args(
    warmup=25, rep=100, timeout=60,
    supply_type=tilelang.TensorSupplyType.Auto,  # or provide supply_prog/ref_prog
    ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2),
)

tuner.set_compile_args(
    target='auto',             # or 'cuda'/'hip'/'metal'
    execution_backend='auto',  # resolves per-target
    # out_idx=[-1],            # optionally return outputs instead of writing into C
    pass_configs={             # optional TVM passes/flags
        # tilelang.PassConfigKey.EXAMPLE_KEY: value,
    },
)

artifact = tuner.run()  # compiles + runs + validates all configs
best_kernel = artifact.kernel  # JITKernel
best_latency = artifact.latency
best_config = artifact.config

# Reuse best kernel
best_kernel(A, B, C)
```

### Example Gallery (in repo)
- examples/gdn/example_chunk_delta_h.py:101 — uses `@autotune` to sweep configs
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 — uses `@tilelang.autotune`
- examples/quickstart.py:84 — profiles a tuned kernel with `get_profiler`
- examples/hadamard_transform/example_hadamard.py:152 — profiler with custom warmup
- examples/dynamic_shape/example_dynamic.py:94 — profiler for dynamic shapes
- examples/gemm/example_gemm_persistent.py:135 — compare persistent vs non‑persistent

Open any path to read the code and compare patterns.

## Input Tensor Supply

The tuner needs inputs to compile and benchmark kernels. Provide them in one of
three ways (priority order):

1) Context manager (fixed inputs across configs)
```python
with set_autotune_inputs(A, B, C):
    tuned = matmul(M, N, K)
```

2) Custom supplier program (a fuller sketch follows after the notes below)
```python
def supply_prog(signature):
    # signature holds KernelParam objects describing shapes/dtypes
    # Return a list of torch tensors matching the kernel’s arguments
    return [A, B, C]

tuner.set_profile_args(supply_prog=supply_prog)
```

3) Built‑in generators via `supply_type`
- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges)
- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One`

Important
- Built‑in generators require static shapes; if your PrimFunc uses symbolic
  dimensions (T.dyn), supply concrete inputs via (1) or (2).
- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support.
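
Building on option (2), a minimal sketch of a supplier factory that pre-builds concrete tensors once and closes over them, so every config is benchmarked on identical data; the shapes, dtypes, and the `make_matmul_supply` helper name are illustrative, not part of the TileLang API:

```python
import torch

def make_matmul_supply(M, N, K):
    # Build concrete inputs once; reuse them for every candidate config.
    A = torch.randn(M, K, device='cuda', dtype=torch.float16)
    B = torch.randn(K, N, device='cuda', dtype=torch.float16)
    C = torch.empty(M, N, device='cuda', dtype=torch.float16)

    def supply_prog(signature):
        # `signature` describes the kernel parameters; here we simply return
        # tensors in the kernel's (A, B, C) argument order.
        return [A, B, C]

    return supply_prog

tuner.set_profile_args(supply_prog=make_matmul_supply(1024, 1024, 1024))
```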

## Correctness Checking and Tolerances

Use one of the following validation methods:
- `ref_prog`: Provide a reference program that receives the same inputs and
  checks results. You can return a boolean or raise on mismatch.
- `manual_check_prog`: A callable that inspects outputs and raises on mismatch.
- `skip_check=True`: Skip correctness checks (faster, use with caution).

Control numeric drift via:
- `rtol` and `atol` (defaults 1e‑2)
- `max_mismatched_ratio` (default 1%)
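
A minimal sketch of wiring these knobs into the programmatic path, assuming they are accepted by `set_profile_args` alongside `ref_prog` as shown earlier:

```python
tuner.set_profile_args(
    ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2),
    rtol=1e-2,                  # relative tolerance
    atol=1e-2,                  # absolute tolerance
    max_mismatched_ratio=0.01,  # tolerate up to 1% mismatched elements
    # skip_check=True,          # uncomment for a timing-only sweep (use with caution)
)
```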

## Configuration Spaces and Best Practices

What to tune
- Tile sizes: `block_M`, `block_N`, `block_K`
- Software pipelining: `num_stages`
- Threads per block: `threads` (or an (x, y) tuple)
- Optional: dtype variants, epilogues, small scheduling knobs

Tips
- Start from a working baseline. Tune a small, meaningful space first.
- Respect hardware limits (shared memory bytes, registers per thread/block,
  max threads per block). Eliminate impossible configs up‑front (see the
  pruning sketch after this list).
- Keep block sizes multiples of vector widths and warp sizes when relevant.
- Use `set_autotune_inputs` to ensure each config is measured on identical data.
- Record your best configs and bake them as defaults when stable.
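
As an example of the up-front pruning mentioned in the tips above, the sketch below drops configs whose estimated shared-memory footprint exceeds a budget; the 48 KiB limit, 2-byte fp16 elements, and the stage-multiplied estimate are rough assumptions to adapt to your GPU:

```python
def prune_configs(configs, smem_limit_bytes=48 * 1024, elem_bytes=2):
    # Rough estimate: A and B tiles in shared memory, multi-buffered by num_stages.
    kept = []
    for c in configs:
        tile_bytes = (c['block_M'] * c['block_K'] + c['block_K'] * c['block_N']) * elem_bytes
        if tile_bytes * c['num_stages'] <= smem_limit_bytes:
            kept.append(c)
    return kept

configs = prune_configs(matmul_configs(1024, 1024, 1024))
```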

## Parallel Compilation/Benchmarking and Timeouts

The tuner compiles configurations in parallel using a thread pool and benchmarks
them with a per‑config timeout. On CUDA, each worker sets the current device to
avoid context issues.

Notes
- `timeout` uses POSIX signals; on non‑Unix systems, it may not take effect.
- Logs are written to `autotuner.log` in the working directory.

## Caching

The autotuner caches best artifacts both in‑memory (per process) and on disk under
`$TILELANG_CACHE_DIR/autotuner`. The cache key includes:
- TileLang version, function source, closure free‑vars
- Config list, compile args, profile args

Disk cache contents (per key)
- Best config and latency: `best_config.json`, `latency.json`
- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend)
- Function and params: `function.pkl`, `params.pkl`

Control via env vars (tilelang.env)
- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`)
- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`)
- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1`
- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1`

CPU worker control
- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9)
- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto)
- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited)
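
A minimal sketch of setting these from Python rather than the shell, assuming the variables are exported before the first `tilelang` import so that `tilelang.env` picks them up:

```python
import os

# Configure the environment before importing tilelang (assumed to be read at import time).
os.environ["TILELANG_AUTO_TUNING_MAX_CPU_COUNT"] = "8"    # cap compile workers
os.environ["TILELANG_CACHE_DIR"] = "/tmp/tilelang_cache"  # redirect the cache

import tilelang  # import after the environment is set
```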

Backend notes
- NVRTC backend persists `.cubin` and a Python launcher.
- Torch/DLPack backend may not save artifacts to disk; in this case, only
  in‑memory caching applies and a warning is logged.

## Alternative: Manual Sweeps with par_compile

If you prefer manual control, use `JITImpl.par_compile` to compile a batch of
configs and drive your own benchmarking:

```python
@tilelang.jit
def factory(M, N, K, block_M=128, block_N=128, block_K=32):
    @T.prim_func
    def k(A: T.Tensor((M, K), 'float16'),
          B: T.Tensor((K, N), 'float16'),
          C: T.Tensor((M, N), 'float16')):
        ...
    return k

impl = factory  # JITImpl
cfgs = [
    dict(block_M=64, block_N=128, block_K=32),
    dict(block_M=128, block_N=128, block_K=64),
]
kernels = impl.par_compile(cfgs, num_workers=4)
# Now benchmark kernels[i](A, B, C) yourself
```
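
A minimal benchmarking sketch using plain CUDA events (no TileLang-specific profiler API assumed), reusing the `A`, `B`, `C` tensors from earlier:

```python
import torch

def time_kernel(kernel, *args, warmup=10, rep=50):
    # Warm up, then time `rep` launches with CUDA events; returns ms per call.
    for _ in range(warmup):
        kernel(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(rep):
        kernel(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / rep

latencies = [time_kernel(k, A, B, C) for k in kernels]
best_idx = min(range(len(kernels)), key=lambda i: latencies[i])
best_kernel = kernels[best_idx]
```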

## Recording and Reusing Best Configs

The programmatic path returns an `AutotuneResult` that can be saved and later
reloaded. This is useful for CI, multi‑host workflows, or shipping tuned configs.

```python
artifact = tuner.run()  # AutotuneResult

# Save to disk
from pathlib import Path
save_dir = Path('out/best/matmul_1024')
artifact.save_to_disk(save_dir, verbose=True)

# Reload later
from tilelang.autotuner.param import AutotuneResult, CompileArgs
restored = AutotuneResult.load_from_disk(save_dir, CompileArgs())
best = restored.kernel
best(A, B, C)
```

Notes
- The DLPack/Torch execution backend may not persist compiled binaries; in that
  case, re‑compilation is needed on load, or a different backend should be used.
- The directory contains human‑readable JSONs (best config/latency) and sources.
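
If you only need the winning parameters rather than the compiled binary, one option is to read the human-readable `best_config.json` and bake the values into an explicit call; this sketch assumes the file stores the tuned parameters as a flat dict and that a jit-only variant of `matmul` (without `@tilelang.autotune`) accepts them as keyword arguments:

```python
import json
from pathlib import Path

# Assumption: best_config.json holds {"block_M": ..., "block_N": ..., ...}.
best_cfg = json.loads((Path('out/best/matmul_1024') / 'best_config.json').read_text())
kernel = matmul(M, N, K, **best_cfg)  # plain @tilelang.jit call with tuned params
kernel(A, B, C)
```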

## Advanced: Config Space Callables

Derive config spaces from problem sizes to keep searches targeted and legal:

```python
def matmul_configs(M, N, K):
    large = min(M, N, K) >= 1024
    tiles = [128] if large else [64, 128]
    for BM in tiles:
        for BN in tiles:
            for BK in [32, 64]:
                for S in [2, 3]:
                    for TH in [128, 256]:
                        yield dict(block_M=BM, block_N=BN, block_K=BK,
                                   num_stages=S, threads=TH)
```
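
Note that this variant is a generator; if the tuner expects a concrete `list[dict]` (as the notes in section 1 suggest), wrap it, e.g. `configs=list(matmul_configs(M, N, K))`.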

## Device and Backend Selection

Tune compile‑time options explicitly:
- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target)
- `execution_backend='auto'|'tvm_ffi'|'ctypes'|'cython'|'nvrtc'|'torch'`
- `pass_configs={...}` to toggle TileLang/TVM passes for experiments

On CUDA with multiple GPUs, the tuner sets the current device per worker thread
to avoid context mixups.

## Troubleshooting
- “No configurations to tune”: Ensure `configs` is a non‑empty list or callable.
- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that
  your reference check isn’t the bottleneck.
- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom
  `supply_prog`.
- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend.