
Commit 9c942d6

Fixed vLLM compatibility, added more perf improvements
1 parent: ad39966

File tree: 4 files changed, +25 / -30 lines

torch/_inductor/codegen/triton.py
Lines changed: 1 addition & 1 deletion

@@ -3222,7 +3222,7 @@ def codegen_body(self):
                 "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
             )
             self.body.writeline(
-                f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages = 2):"
             )
             with self.body.indent(offset=level + 1):
                 self.iteration_ranges_codegen_header(tree, self.body)
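
For context on the change above: tl.range is Triton's explicit loop-iteration helper, and its num_stages argument asks the compiler to software-pipeline the loop, overlapping the loads for the next iteration with the compute of the current one. The hand-written reduction kernel below is only a minimal illustration of that tl.range usage on a recent Triton release; it is not the code Inductor actually emits.

import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, rnumel, RBLOCK: tl.constexpr):
    row = tl.program_id(0)
    acc = tl.zeros([RBLOCK], dtype=tl.float32)
    # num_stages=2 hints the compiler to pipeline this loop, as in the patched codegen.
    for roffset in tl.range(0, rnumel, RBLOCK, num_stages=2):
        rindex = roffset + tl.arange(0, RBLOCK)
        mask = rindex < rnumel
        acc += tl.load(x_ptr + row * rnumel + rindex, mask=mask, other=0.0)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))


x = torch.randn(128, 4096, device="cuda")
out = torch.empty(128, device="cuda")
row_sum_kernel[(128,)](x, out, x.shape[1], RBLOCK=256)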

torch/_inductor/config.py
Lines changed: 1 addition & 1 deletion

@@ -1138,7 +1138,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
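
As an aside on how this knob is consumed: spill_threshold flows from this config class into inductor_meta, where the autotuner compares it against the register-spill count of each benchmarked config (see the triton_heuristics.py hunk below). A minimal sketch of overriding it at runtime instead of patching the source, assuming the torch._inductor.config.triton namespace shown in this diff:

import torch
import torch._inductor.config as inductor_config

# Allow autotuning configs with up to 32 spilled registers before they are skipped.
inductor_config.triton.spill_threshold = 32


@torch.compile
def sin_cos(x):
    # sin/cos kernels are the register-spill case cited in the config comment above.
    return torch.sin(x) * torch.cos(x)


y = sin_cos(torch.randn(4096, 4096, device="cuda"))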

torch/_inductor/kernel/mm_scaled.py
Lines changed: 3 additions & 1 deletion

@@ -469,7 +469,9 @@ def scaled_mm_options(  # type: ignore[no-untyped-def]
         f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
     )
     return dict(
-        GROUP_M=8,
+        # this change is incompatible with vllm, can't make it into our release
+        # should be fixed by them
+        # GROUP_M=8,
         EVEN_K=even_k_symbolic,
         ACC_TYPE="tl.float32",
         USE_FAST_ACCUM=use_fast_accum,
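
For readers unfamiliar with the knob being commented out: GROUP_M is the tile-grouping factor that Triton matmul templates use to reorder program IDs so that a group of GROUP_M output-tile rows is finished before moving on, which keeps operand tiles hot in L2. The pure-Python model below shows the standard swizzle such a parameter controls; the names and grid sizes are illustrative and this is not Inductor's template code.

def swizzled_tile(pid: int, grid_m: int, grid_n: int, group_m: int) -> tuple[int, int]:
    """Map a linear program id to (tile_row, tile_col) using grouped ordering."""
    width = group_m * grid_n                      # program ids covered by one row-group
    group_id = pid // width
    first_m = group_id * group_m
    group_size = min(grid_m - first_m, group_m)   # last group may have fewer rows
    pid_m = first_m + (pid % width) % group_size
    pid_n = (pid % width) // group_size
    return pid_m, pid_n


# With group_m=8, consecutive program ids stay within 8 tile-rows for one tile-column
# before advancing, so the corresponding operand tiles are likely reused from L2.
print([swizzled_tile(p, grid_m=16, grid_n=16, group_m=8) for p in range(4)])
# [(0, 0), (1, 0), (2, 0), (3, 0)]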

torch/_inductor/runtime/triton_heuristics.py
Lines changed: 20 additions & 27 deletions

@@ -577,7 +577,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -1874,11 +1874,8 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
-            # triton_config_with_settings(
-            #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
-            # ),
             triton_config_with_settings(
-                size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                size_hints, TRITON_MAX_BLOCK["X"]
             ),
             *hinted_configs,
         ]
@@ -1975,14 +1972,14 @@ def _reduction_configs(
     if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
         pass  # skip all these cases
     elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
+        result_configs = [contiguous_config]
     elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
+        result_configs = [outer_config]
     elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
+        result_configs = [tiny_config]
     if disable_pointwise_autotuning(inductor_meta):
-        return [triton_config_reduction(size_hints, 32, 128)]
-    return [
+        result_configs = [triton_config_reduction(size_hints, 32, 128)]
+    result_configs = [
         contiguous_config,
         outer_config,
         tiny_config,

@@ -1994,6 +1991,19 @@ def _reduction_configs(
         triton_config_reduction(size_hints, 64, 4, num_warps=8),
     ]
 
+    # Additional reduction configs appended for ROCm builds
+    if torch.version.hip:
+        # New config
+        result_configs.append(triton_config_reduction(
+            size_hints,
+            1024,
+            8,
+            num_warps=4,
+            num_stages=1
+        ))
+
+    return result_configs
+
 
 def reduction(
     size_hints,
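
Taken together, the two hunks above stop _reduction_configs from returning early: candidates are now collected in result_configs and, on ROCm builds, an extra XBLOCK=1024 / RBLOCK=8 config is appended before the single return. (Note that, as the diff reads, the unconditional rebuild of result_configs overwrites the hint-specific single-config lists assigned just above it.) A standalone, simplified model of the resulting flow, with illustrative stand-in configs rather than Inductor's real helpers:

def triton_config_reduction(size_hints, xblock, rblock, num_warps=None, num_stages=1):
    # Stand-in for the real helper, which builds a triton.Config; here we just record the shape.
    return {"XBLOCK": xblock, "RBLOCK": rblock, "num_warps": num_warps, "num_stages": num_stages}


def reduction_configs(size_hints, is_hip):
    # Base candidate list (values here are illustrative, not Inductor's defaults).
    result_configs = [
        triton_config_reduction(size_hints, 1, 2048),   # contiguous-style reduction
        triton_config_reduction(size_hints, 64, 8),     # outer-style reduction
        triton_config_reduction(size_hints, 64, 4, num_warps=8),
    ]
    # ROCm-only candidate appended by this patch.
    if is_hip:
        result_configs.append(
            triton_config_reduction(size_hints, 1024, 8, num_warps=4, num_stages=1)
        )
    return result_configs


print(len(reduction_configs((8192, 8192), is_hip=True)))  # 4 candidates on ROCm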
@@ -2012,23 +2022,6 @@ def reduction(
 
     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
 
-    # Additional tuning confirgs for ROCm builds
-    # Add checks for reduction autotuning bools
-    # if torch.version.hip and inductor_meta.get("max_autotune"):
-    #     configs = [
-    #         triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
-    #         triton_config_with_settings(
-    #             size_hints, bs // 2, num_elements_per_warp=64
-    #         ),
-    #         # triton_config_with_settings(
-    #         #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
-    #         # ),
-    #         triton_config_with_settings(
-    #             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
-    #         ),
-    #         *hinted_configs,
-    #     ]
-
     return cached_autotune(
         size_hints,
         configs=configs,
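
A small end-to-end way to exercise the reduction autotuning path touched in this file: compiling a reduction with max-autotune makes Inductor benchmark the candidates assembled by _reduction_configs, including the ROCm-only one when torch.version.hip is set. This is a sketch only; the shapes are arbitrary.

import torch


@torch.compile(mode="max-autotune")
def row_logsumexp(x):
    # A fused reduction; Inductor autotunes it over the configs assembled above.
    return torch.logsumexp(x, dim=-1)


y = row_logsumexp(torch.randn(4096, 4096, device="cuda"))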
