diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0e763772911ca..ff1755cb55534 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1083,11 +1083,17 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b})"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b})"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1273,7 +1279,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -3598,9 +3607,14 @@ def codegen_body(self):
                loop_end = (
                    "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
                )
-                self.body.writeline(
-                    f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
-                )
+                if torch.version.hip:
+                    self.body.writeline(
+                        f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages=2):"
+                    )
+                else:
+                    self.body.writeline(
+                        f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                    )
                with self.body.indent(offset=level + 1):
                    self.iteration_ranges_codegen_header(tree, self.body)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index e5b5fe224cc81..9f172d8acaa39 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1304,7 +1304,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
-    # Raise the threshold to 16 to be safe.
+    # Raise the threshold to 32 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 2732b9cecfb21..899b2f2d6d770 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
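Note on the `tl.range(..., num_stages=2)` change above: on ROCm it asks Triton to software-pipeline the reduction loop, so the global loads for iteration `i+1` can overlap with the compute for iteration `i`. A minimal sketch of the loop shape this emits, assuming a simple 1D sum reduction and a Triton version where `tl.range` accepts `num_stages` (the kernel and pointer names here are illustrative, not Inductor's actual generated code):

```python
import triton
import triton.language as tl


@triton.jit
def sum_kernel(in_ptr, out_ptr, rnumel, RBLOCK: tl.constexpr):
    rbase = tl.arange(0, RBLOCK)
    acc = tl.zeros([RBLOCK], dtype=tl.float32)
    # On HIP the generated loop uses tl.range(..., num_stages=2), which lets
    # the compiler double-buffer the loads; other backends keep plain range().
    for roffset in tl.range(0, rnumel, RBLOCK, num_stages=2):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        acc += tl.load(in_ptr + rindex, mask=rmask, other=0.0)
    tl.store(out_ptr, tl.sum(acc, axis=0))
    # launch, e.g.: sum_kernel[(1,)](x, out, x.numel(), RBLOCK=1024)
```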
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index bbe9b04243e6c..3a39bc677a1cf 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -800,7 +800,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
        # for some (complicated) custom Triton kernels, a register-spilling
        # config may yield the best latency.
        if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32
        ):
            log.debug(
                "Skip config %s because of register spilling: %d",
@@ -2071,6 +2071,9 @@ def triton_config(
    num_stages=1,
    num_elements_per_warp=256,
    min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
@@ -2127,9 +2130,11 @@ def triton_config(
    ):
        z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Compute num_warps only when it is not passed in explicitly.
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
    # we are going to arrive at 2 warps only if bs was too small due to
    # numel being too small. However to workaround some ptx bugs we still
    # want at least 4 warps if there's enough elements per thread
@@ -2159,7 +2164,15 @@ def triton_config(
        cfg["ZBLOCK"] = z
    check_max_block(cfg)
    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2207,6 +2220,7 @@ def triton_config_reduction(
    num_stages=1,
    num_warps=None,
    register_intensive=False,
+    waves_per_eu=None,
 ) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
@@ -2250,7 +2264,13 @@ def total_numel() -> int:
    cfg = _get_config({"x": x, **rnumels})
    check_max_block(cfg)
    check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -2262,7 +2282,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
 
 
 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
    """
    Construct a tile reduction triton config with some adjustment
@@ -2299,7 +2319,11 @@ def total_numel() -> int:
    )
    check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
    check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@@ -2388,6 +2412,9 @@ def pointwise(
            triton_config_with_settings(
                size_hints, bs // 2, num_elements_per_warp=64
            ),
+            triton_config_with_settings(
+                size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+            ),
            *hinted_configs,
        ]
    if len(size_hints) == 2:
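The new `matrix_instr` and `waves_per_eu` parameters are not first-class `triton.Config` arguments; they ride along in `Config.kwargs` and are picked up by the AMD backend at compile time, so CUDA configs are left byte-for-byte unchanged. A standalone sketch of that pattern (the helper name `attach_rocm_kwargs` and the explicit `is_hip` flag are ours, standing in for the inline `torch.version.hip` check):

```python
from triton import Config


def attach_rocm_kwargs(config, waves_per_eu=None, matrix_instr=None, is_hip=False):
    # waves_per_eu is an AMD occupancy hint (wavefronts per execution unit);
    # matrix_instr_nonkdim constrains the MFMA tile shape. Both are attached
    # only on HIP so non-ROCm builds never see an unknown kernel kwarg.
    if is_hip:
        if matrix_instr is not None:
            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
        if waves_per_eu is not None:
            config.kwargs["waves_per_eu"] = waves_per_eu
    return config


cfg = attach_rocm_kwargs(
    Config({"XBLOCK": 1024}, num_warps=4, num_stages=1), waves_per_eu=2, is_hip=True
)
```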
@@ -2446,6 +2473,9 @@ def _reduction_configs(
    # Convert reductions to 1D, to simplify heuristics.
    rnumel = get_total_reduction_numel(size_hints)
 
+    # Whether max-autotune is enabled.
+    max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+
    register_intensive = False
    MAX_R0_BLOCK = 2048
    if (
@@ -2468,7 +2498,7 @@ def _reduction_configs(
        MAX_R0_BLOCK = 1024
        register_intensive = True
 
-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
        # For 3D case with tiling scores, create an adapted version
        if "y" in size_hints:
            assert "tiling_scores" in inductor_meta
@@ -2480,6 +2510,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                num_warps=num_warps,
                num_stages=num_stages,
                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
            )
        else:
            # For other cases, use the original function
@@ -2490,6 +2521,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                num_warps=num_warps,
                num_stages=num_stages,
                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
            )
 
    contiguous_config = make_config(
@@ -2503,32 +2535,39 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
        min(rnumel, MAX_R0_BLOCK),
        register_intensive=register_intensive,
    )
-    # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
-        pass  # skip all these cases
-    elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [make_config(32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+
+    result_configs = []
+
+    if not (max_autotune or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = [tiny_config]
+        else:
+            result_configs = [make_config(32, 128)]
+    else:
+        result_configs = [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
+
+    # Add ROCm-specific configs
+    if torch.version.hip:
+        result_configs.extend([
+            make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+        ])
+
+    return result_configs
 
 
 def match_target_block_product(
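For reference, the two HIP-only entries appended to `result_configs` roughly expand to the configs below on the 1D path; this is only an approximation, since `make_config` may still clamp `XBLOCK`/`R0_BLOCK` against `size_hints` and `MAX_R0_BLOCK`:

```python
from triton import Config

# Approximate shape of the two ROCm-only reduction configs added above,
# assuming the 1D path where nothing is clamped by size_hints.
rocm_extra_configs = [
    Config({"XBLOCK": 1024, "R0_BLOCK": 8}, num_warps=4, num_stages=1),
    Config({"XBLOCK": 512, "R0_BLOCK": 8}, num_warps=4, num_stages=1),
]
# As in triton_config_reduction(), waves_per_eu lands in Config.kwargs.
rocm_extra_configs[0].kwargs["waves_per_eu"] = 2
rocm_extra_configs[1].kwargs["waves_per_eu"] = 1
```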
@@ -2586,6 +2625,7 @@ def adapt_config_for_tiling(
    num_stages=1,
    register_intensive=False,
    persistent_reduction=False,
+    waves_per_eu=None,
 ) -> Config:
    """
    Create an adapted configuration based on tiling scores,
@@ -2604,6 +2644,7 @@ def adapt_config_for_tiling(
        block_sizes["r0_"],
        num_stages=num_stages,
        register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
    )
 
 
@@ -2624,6 +2665,7 @@ def reduction(
    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
 
    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+
    return cached_autotune(
        size_hints,
        configs=configs,
@@ -2694,18 +2736,12 @@ def _persistent_reduction_configs(
        or inductor_meta.get("max_autotune_pointwise")
    )
 
-    configs = [
-        triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
-        for xblock in (1, 8, 32, 128)
-        if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
-    ]
-
    if "y" not in size_hints:
        configs = [
            triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
            for xblock in (1, 8, 32, 128)
            if xblock == 1
-            or (rnumel * xblock <= 4096 and xblock <= xnumel)
+            or (xblock <= xnumel and rnumel * xblock <= 4096)
        ]
    else:
        configs = []
@@ -2734,20 +2770,20 @@ def _persistent_reduction_configs(
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
 
-    if reduction_hint == ReductionHint.OUTER_TINY:
-        tiny_configs = [
-            triton_config_reduction(
-                size_hints,
-                2 * (256 // rnumel) if rnumel <= 256 else 1,
-                rnumel,
-            )
-        ]
-        if max_autotune_enabled:
-            for tconfig in tiny_configs:
-                if tconfig not in configs:
-                    configs.append(tconfig)
-        else:
-            configs = tiny_configs
+    tiny_configs = [
+        triton_config_reduction(
+            size_hints,
+            2 * (256 // rnumel) if rnumel <= 256 else 1,
+            rnumel,
+        )
+    ]
+
+    if max_autotune_enabled:
+        for conf in tiny_configs:
+            if conf not in configs:
+                configs.append(conf)
+    elif reduction_hint == ReductionHint.OUTER_TINY:
+        configs = tiny_configs
 
    for c in configs:
        # we don't need Rn_BLOCK for persistent reduction
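The tiny-config block size used above, spelled out: for short rows (`rnumel <= 256`) the heuristic packs `2 * (256 // rnumel)` rows per block, otherwise one row per block. Note that under max-autotune this config is now appended for every reduction hint, not just `OUTER_TINY`. A minimal illustration of the formula (the helper name `tiny_xblock` is ours):

```python
def tiny_xblock(rnumel: int) -> int:
    # XBLOCK of the tiny persistent-reduction config:
    # pack multiple rows per block while a row fits in 256 elements.
    return 2 * (256 // rnumel) if rnumel <= 256 else 1


assert tiny_xblock(32) == 16
assert tiny_xblock(256) == 2
assert tiny_xblock(2048) == 1
```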