26 changes: 20 additions & 6 deletions torch/_inductor/codegen/triton.py
@@ -1083,11 +1083,17 @@ def relu(x):

@staticmethod
def minimum(a, b):
return f"triton_helpers.minimum({a}, {b})"
if torch.version.hip:
return f"tl.minimum({a}, {b})"
else:
return f"triton_helpers.minimum({a}, {b})"

@staticmethod
def maximum(a, b):
return f"triton_helpers.maximum({a}, {b})"
if torch.version.hip:
return f"tl.maximum({a}, {b})"
else:
return f"tl.maximum({a}, {b})"

@staticmethod
def where(a, b, c):
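For context, these overrides are string-generating codegen hooks; the branch only changes which helper is named in the emitted kernel source. A minimal stand-alone sketch of the dispatch pattern (the free function here is illustrative, not part of the diff):

import torch

def emit_minimum(a: str, b: str) -> str:
    # torch.version.hip is a version string on ROCm builds and None otherwise,
    # so this picks the Triton builtin on ROCm and the triton_helpers wrapper elsewhere.
    if torch.version.hip:
        return f"tl.minimum({a}, {b})"
    return f"triton_helpers.minimum({a}, {b})"

print(emit_minimum("tmp0", "tmp1"))  # 'tl.minimum(tmp0, tmp1)' on ROCm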
@@ -1273,7 +1279,10 @@ def load_seed(name, offset):
@staticmethod
@maybe_upcast_float32()
def rsqrt(x):
return f"libdevice.rsqrt({x})"
if torch.version.hip:
return f"tl.rsqrt({x})"
else:
return f"libdevice.rsqrt({x})"

@staticmethod
@maybe_upcast_float32()
@@ -3598,9 +3607,14 @@ def codegen_body(self):
loop_end = (
"rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
)
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
)
if torch.version.hip:
self.body.writeline(
f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK, num_stages = 2):"
)
else:
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
)
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)

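The ROCm branch above swaps the plain Python range for tl.range with an explicit num_stages pipelining hint. A minimal stand-alone Triton kernel using the same knob (a sketch of the technique, not Inductor output; all names are illustrative):

import triton
import triton.language as tl

@triton.jit
def sum_kernel(in_ptr, out_ptr, numel, BLOCK: tl.constexpr):
    # num_stages=2 asks Triton to software-pipeline (double-buffer) the loads
    # in this loop, mirroring the loop header the codegen change emits on ROCm.
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for off in tl.range(0, numel, BLOCK, num_stages=2):
        idx = off + tl.arange(0, BLOCK)
        acc += tl.load(in_ptr + idx, mask=idx < numel, other=0.0)
    tl.store(out_ptr, tl.sum(acc, axis=0))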
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
@@ -1304,7 +1304,7 @@ class triton:
# So far we see a fixed 8 spilled registers for kernels using sin/cos.
# Raise the threshold to 16 to be safe.
# We should revisit this once we understand more of the source of register spills.
spill_threshold: int = 16
spill_threshold: int = 32

# Generate code containing the newer tl.make_block_ptr() API for loads/stores
use_block_ptr = False
2 changes: 1 addition & 1 deletion torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
# The following maximums only apply to runtime autotuning; when using FixedTritonConfig one may see larger values
# NOTE: if these assertions fail, submit a PR to increase them
TRITON_MAX_BLOCK = {
"X": 4096,
"X": 8192,
"Y": 1024,
"Z": 1024,
"R0_": 4096 * 16, # * 16 is multi-kernel only
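These caps are consulted by check_max_block in triton_heuristics.py (called in the hunks below). A minimal sketch of that kind of guard, assuming the dict is importable as defined here:

from torch._inductor.runtime.hints import TRITON_MAX_BLOCK

def assert_block_fits(prefix: str, block: int) -> None:
    # Heuristically chosen block sizes must stay within the per-dimension caps.
    cap = TRITON_MAX_BLOCK[prefix.upper()]
    assert block <= cap, f"{prefix.upper()}BLOCK={block} exceeds cap {cap}"

assert_block_fits("x", 8192)  # passes only with the raised X cap above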
148 changes: 92 additions & 56 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -800,7 +800,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
# for some (complicated) custom Triton kernels, a register-spilling
# config may yield the best latency.
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
"spill_threshold", 16
"spill_threshold", 32
):
log.debug(
"Skip config %s because of register spilling: %d",
@@ -2071,6 +2071,9 @@ def triton_config(
num_stages=1,
num_elements_per_warp=256,
min_elem_per_thread=0,
num_warps=None,
matrix_instr=None,
waves_per_eu=None,
) -> Config:
"""
Construct a pointwise triton config with some adjustment heuristics
Expand Down Expand Up @@ -2127,9 +2130,11 @@ def triton_config(
):
z *= 2

num_warps = _num_warps(
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
)
# Calculate num_warps if it was not explicitly passed to the config
if num_warps is None:
num_warps = _num_warps(
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
)
# we are going to arrive at 2 warps only if bs was too small due to
# numel being too small. However to workaround some ptx bugs we still
# want at least 4 warps if there's enough elements per thread
@@ -2159,7 +2164,15 @@
cfg["ZBLOCK"] = z
check_max_block(cfg)
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)

if torch.version.hip:
if matrix_instr is not None:
config.kwargs["matrix_instr_nonkdim"] = matrix_instr
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu

return config


def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
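A hedged usage sketch of the extended triton_config signature (the size-hint and argument values here are illustrative; on CUDA builds the two ROCm-only knobs are simply not attached):

from torch._inductor.runtime.triton_heuristics import triton_config

cfg = triton_config(
    {"x": 65536},    # size_hints
    1024,            # requested XBLOCK
    num_warps=4,     # bypasses the derived-warps heuristic above
    waves_per_eu=2,  # ROCm occupancy hint, lands in Config.kwargs
)
# On ROCm: cfg.kwargs includes {'waves_per_eu': 2}; elsewhere it does not.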
@@ -2207,6 +2220,7 @@ def triton_config_reduction(
num_stages=1,
num_warps=None,
register_intensive=False,
waves_per_eu=None,
) -> Config:
"""
Construct a reduction triton config with some adjustment heuristics
@@ -2250,7 +2264,13 @@ def total_numel() -> int:
cfg = _get_config({"x": x, **rnumels})
check_max_block(cfg)
check_config(cfg, xnumel=size_hints["x"])
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)

if torch.version.hip:
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu

return config


def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -2262,7 +2282,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:


def triton_config_tiled_reduction(
size_hints, x, y, r, num_stages=1, register_intensive=False
size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
):
"""
Construct a tile reduction triton config with some adjustment
@@ -2299,7 +2319,11 @@ def total_numel() -> int:
)
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
check_max_block(cfg)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
if torch.version.hip:
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu
return config


def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@@ -2388,6 +2412,9 @@ def pointwise(
triton_config_with_settings(
size_hints, bs // 2, num_elements_per_warp=64
),
triton_config_with_settings(
size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
),
*hinted_configs,
]
if len(size_hints) == 2:
@@ -2446,6 +2473,9 @@ def _reduction_configs(
# Convert reductions to 1D, to simplify heuristics.
rnumel = get_total_reduction_numel(size_hints)

# Check whether max-autotune is enabled
max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get(
    "max_autotune_pointwise"
)

register_intensive = False
MAX_R0_BLOCK = 2048
if (
@@ -2468,7 +2498,7 @@
MAX_R0_BLOCK = 1024
register_intensive = True

def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
# For 3D case with tiling scores, create an adapted version
if "y" in size_hints:
assert "tiling_scores" in inductor_meta
@@ -2480,6 +2510,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu,
)
else:
# For other cases, use the original function
@@ -2490,6 +2521,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu,
)

contiguous_config = make_config(
@@ -2503,32 +2535,39 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
min(rnumel, MAX_R0_BLOCK),
register_intensive=register_intensive,
)
# For 3d tiling, default to more autotuning initially
if "y" in size_hints:
pass
elif inductor_meta.get("max_autotune") or inductor_meta.get(
"max_autotune_pointwise"
):
pass # skip all these cases
elif reduction_hint == ReductionHint.INNER:
return [contiguous_config]
elif reduction_hint == ReductionHint.OUTER:
return [outer_config]
elif reduction_hint == ReductionHint.OUTER_TINY:
return [tiny_config]
if disable_pointwise_autotuning(inductor_meta):
return [make_config(32, 128)]
return [
contiguous_config,
outer_config,
tiny_config,
make_config(64, 64),
make_config(8, 512),
# halve the XBLOCK/Rn_BLOCK compared to outer_config
# TODO: this may only be beneficial when each iteration of the reduction
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
make_config(64, 4, num_warps=8),
]

result_configs = []

if not (max_autotune or "y" in size_hints):
if reduction_hint == ReductionHint.INNER:
result_configs = [contiguous_config]
elif reduction_hint == ReductionHint.OUTER:
result_configs = [outer_config]
elif reduction_hint == ReductionHint.OUTER_TINY:
result_configs = [tiny_config]
else:
result_configs = [make_config(32, 128)]
else:
result_configs = [
contiguous_config,
outer_config,
tiny_config,
make_config(64, 64),
make_config(8, 512),
# halve the XBLOCK/Rn_BLOCK compared to outer_config
# TODO: this may only be beneficial when each iteration of the reduction
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
make_config(64, 4, num_warps=8),
]

# Add ROCm-specific configs
if torch.version.hip:
result_configs.extend([
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
])

return result_configs

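For reference, the two ROCm entries added above correspond roughly to the following hand-written Triton configs (a sketch assuming the requested block sizes are not clamped by size_hints):

from triton import Config

rocm_extra = [
    Config({"XBLOCK": 1024, "R0_BLOCK": 8, "waves_per_eu": 2}, num_warps=4, num_stages=1),
    Config({"XBLOCK": 512, "R0_BLOCK": 8, "waves_per_eu": 1}, num_warps=4, num_stages=1),
]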

def match_target_block_product(
Expand Down Expand Up @@ -2586,6 +2625,7 @@ def adapt_config_for_tiling(
num_stages=1,
register_intensive=False,
persistent_reduction=False,
waves_per_eu=None,
) -> Config:
"""
Create an adapted configuration based on tiling scores,
@@ -2604,6 +2644,7 @@
block_sizes["r0_"],
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu,
)


@@ -2624,6 +2665,7 @@ def reduction(

configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)

return cached_autotune(
size_hints,
configs=configs,
@@ -2694,18 +2736,12 @@ def _persistent_reduction_configs(
or inductor_meta.get("max_autotune_pointwise")
)

configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
for xblock in (1, 8, 32, 128)
if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
]

if "y" not in size_hints:
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
for xblock in (1, 8, 32, 128)
if xblock == 1
or (rnumel * xblock <= 4096 and xblock <= xnumel)
or (xblock <= xnumel and rnumel * xblock <= 4096)
]
else:
configs = []
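A worked example of the xblock filter above (values are illustrative), assuming xnumel = 64, rnumel = 512, and max-autotune disabled:

xnumel, rnumel, max_autotune_enabled = 64, 512, False
kept = [
    xblock
    for xblock in (1, 8, 32, 128)
    if xblock == 1
    or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
]
assert kept == [1, 8]  # 32 fails the 4096 product cap; 128 exceeds xnumel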
@@ -2734,20 +2770,20 @@
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]

if reduction_hint == ReductionHint.OUTER_TINY:
tiny_configs = [
triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
)
]
if max_autotune_enabled:
for tconfig in tiny_configs:
if tconfig not in configs:
configs.append(tconfig)
else:
configs = tiny_configs
tiny_configs = [
triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
)
]

if max_autotune_enabled:
for conf in tiny_configs:
if conf not in configs:
configs.append(conf)
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = tiny_configs

for c in configs:
# we don't need Rn_BLOCK for persistent reduction