@@ -577,7 +577,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
577577 # for some (complicated) custom Triton kernels, a register-spilling
578578 # config may yield the best latency.
579579 if not self .custom_kernel and launcher .n_spills > self .inductor_meta .get (
580- "spill_threshold" , 16
580+ "spill_threshold" , 32
581581 ):
582582 log .debug (
583583 "Skip config %s because of register spilling: %d" ,
@@ -1874,11 +1874,8 @@ def pointwise(
18741874 triton_config_with_settings (
18751875 size_hints , bs // 2 , num_elements_per_warp = 64
18761876 ),
1877- # triton_config_with_settings(
1878- # size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
1879- # ),
18801877 triton_config_with_settings (
1881- size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
1878+ size_hints , TRITON_MAX_BLOCK ["X" ]
18821879 ),
18831880 * hinted_configs ,
18841881 ]
@@ -1975,14 +1972,18 @@ def _reduction_configs(
1972+ # Track the chosen configs in a local instead of returning early, so the
1973+ # ROCm-specific config below can be appended before the single return.
1974+ result_configs = None
19751975 if inductor_meta .get ("max_autotune" ) or inductor_meta .get ("max_autotune_pointwise" ):
19761976 pass # skip all these cases
19771977 elif reduction_hint == ReductionHint .INNER :
1978- return [contiguous_config ]
1978+ result_configs = [contiguous_config ]
19791979 elif reduction_hint == ReductionHint .OUTER :
1980- return [outer_config ]
1980+ result_configs = [outer_config ]
19811981 elif reduction_hint == ReductionHint .OUTER_TINY :
1982- return [tiny_config ]
1982+ result_configs = [tiny_config ]
1983- if disable_pointwise_autotuning (inductor_meta ):
1984- return [triton_config_reduction (size_hints , 32 , 128 )]
1985- return [
1983+ # Guard both fallbacks so a hint-specific list chosen above is not
1984+ # clobbered (preserves the original early-return behaviour).
1985+ if result_configs is None and disable_pointwise_autotuning (inductor_meta ):
1986+ result_configs = [triton_config_reduction (size_hints , 32 , 128 )]
1987+ if result_configs is None :
1988+ result_configs = [
19861989 contiguous_config ,
19871990 outer_config ,
19881991 tiny_config ,
@@ -1994,6 +1998,19 @@ def _reduction_configs(
19941998 triton_config_reduction (size_hints , 64 , 4 , num_warps = 8 ),
19951999 ]
19961993
1994+ # Additional reduction configs appended for ROCm builds
1995+ if torch .version .hip :
1996+ # Extra reduction config for ROCm: 1024x8 block, 4 warps, 1 stage
1997+ result_configs .append (triton_config_reduction (
1998+ size_hints ,
1999+ 1024 ,
2000+ 8 ,
2001+ num_warps = 4 ,
2002+ num_stages = 1
2003+ ))
2004+
2005+ return result_configs
2006+
19972007
19982008def reduction (
19992009 size_hints ,
@@ -2012,23 +2022,6 @@ def reduction(
20122022
20132023 configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta )
20142024
2015- # Additional tuning confirgs for ROCm builds
2016- # Add checks for reduction autotuning bools
2017- # if torch.version.hip and inductor_meta.get("max_autotune"):
2018- # configs = [
2019- # triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
2020- # triton_config_with_settings(
2021- # size_hints, bs // 2, num_elements_per_warp=64
2022- # ),
2023- # # triton_config_with_settings(
2024- # # size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
2025- # # ),
2026- # triton_config_with_settings(
2027- # size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
2028- # ),
2029- # *hinted_configs,
2030- # ]
2031-
20322025 return cached_autotune (
20332026 size_hints ,
20342027 configs = configs ,
0 commit comments