
Commit 696e668

pytorchmergebot authored and Chao1Han committed
Revert "[inductor] Expand use of generic benchmark function (pytorch#164938)"
This reverts commit 5c583e2. Reverted pytorch#164938 on behalf of https://github.com/clee2000 because it appears to have broken test/inductor/test_cuda_repro.py::CudaReproTests::test_epilogue_fusion_with_view on both rocm and the slow grad check for linux ([GH job link](https://github.com/pytorch/pytorch/actions/runs/18529735968/job/52813191763), [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/f58f301313d4fc89499fb35cdfb2ffb91d14d896)). It did run successfully on the cuda workflow on trunk, so this may be a gpu capability thing; no clue though ([comment](pytorch#164938 (comment)))
1 parent f99e6bb commit 696e668
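At its core, the revert moves Inductor's benchmarking call sites off the generic `benchmarker.benchmark(...)` entry point and back onto the device-specific helpers. A before/after sketch of the pattern that recurs throughout the hunks below (values such as `kernel_mod`, `args`, and `device_type` are illustrative):

# Direction being reverted (pytorch#164938): one generic entry point,
# with the target device passed explicitly.
ms = benchmarker.benchmark(kernel_mod.call, fn_args=(args,), device=device_type, rep=40)

# Direction being restored: choose the device-specific implementation
# at the call site and hand it a zero-argument callable.
ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)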

File tree

10 files changed: +45 −103 lines changed

torch/_inductor/codegen/multi_kernel.py

Lines changed: 8 additions & 13 deletions
@@ -8,7 +8,6 @@
 
 from torch._inductor.ir import MultiTemplateBuffer
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
-from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch.utils._ordered_set import OrderedSet
 
 from .. import config
@@ -370,20 +369,16 @@ def benchmark_sub_kernels(self, *args, **kwargs):
         be picked.
         """
 
-        def get_args_kwargs(kernel, index) -> tuple[tuple, dict[str, Any]]:  # type: ignore[type-arg]
-            filtered_args = self._get_filtered_args(args, index)
-            args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
-            return args_clone, kwargs_clone
+        def wrap_fn(kernel, index):
+            def inner():
+                filtered_args = self._get_filtered_args(args, index)
+                args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
+                return kernel.run(*args_clone, **kwargs_clone)
+
+            return inner
 
         return [
-            benchmarker.benchmark(
-                kernel.run,
-                *get_args_kwargs(kernel, index),
-                device=kernel.device_props.type
-                if isinstance(kernel, CachingAutotuner)
-                else None,
-                rep=40,
-            )
+            benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
             for index, kernel in enumerate(self.kernels)
         ]
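The restored `wrap_fn` is a closure factory: binding `kernel` and `index` as parameters snapshots them per sub-kernel, so each `benchmark_gpu` call gets its own zero-argument callable (a bare lambda closing over the comprehension variables would be vulnerable to Python's late-binding pitfall). A minimal standalone sketch of the same pattern, with purely illustrative stand-ins for the kernels:

def make_timed_fn(kernel, args):
    # Passing `kernel` and `args` as parameters binds their current values;
    # `lambda: kernel(*args)` defined in a loop would instead see whatever
    # the loop variables hold when the lambda is finally invoked.
    def inner():
        return kernel(*args)

    return inner

kernels = [sum, max, min]  # stand-ins for compiled sub-kernels
arg_lists = [(range(10),), (range(5),), ((3, 1, 2),)]
timed_fns = [make_timed_fn(k, a) for k, a in zip(kernels, arg_lists)]
print([fn() for fn in timed_fns])  # [45, 4, 1]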

torch/_inductor/codegen/subgraph.py

Lines changed: 1 addition & 4 deletions
@@ -109,10 +109,7 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
         bm_func([*sym_inputs, *args])
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
-        return benchmarker.benchmark(
-            bm_func,
-            fn_args=([*sym_inputs, *args],),
-        )
+        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
 
     def hash_key(self) -> str:
         return "-".join(

torch/_inductor/codegen/triton.py

Lines changed: 9 additions & 15 deletions
@@ -4682,7 +4682,7 @@ def codegen_kernel_benchmark(self, num_gb: Optional[float]) -> IndentedBuffer:
 
         result.writeline("args = get_args()")
         result.writeline(
-            f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type}, rep=40)"  # noqa: B950 line too long
+            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
         )
         result.writeline(f"num_gb = {num_gb}")
         result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -5624,21 +5624,18 @@ def load_cache():
             # skip benchmarking the kernel if there are register spills
             ms = float("inf")
         else:
-            device = V.graph.get_current_device_or_throw()
             # We have to clone the inplace updated arguments to avoid earlier calls
             # generating out of range indices for later calls.
-            ms = benchmarker.benchmark(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
-                device=device,
+            ms = benchmarker.benchmark_gpu(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0])
             )
             # overhead of cloning args gives bias for fusing the kernel
             # in the case of mutating/in-placeable second fusion
             # TODO - would be better as a hook in triton do_bench that reset
             # the input values between benchmarking
             if len(wrapped_jit_function.mutated_arg_names) > 0:
-                ms = ms - benchmarker.benchmark(
-                    lambda: wrapped_jit_function.clone_args(*args),
-                    device=str(device),
+                ms = ms - benchmarker.benchmark_gpu(
+                    lambda: wrapped_jit_function.clone_args(*args)
                 )
 
             log.debug(
@@ -5807,16 +5804,13 @@ def store_cache():
             # skip benchmarking the kernel if there are register spills
             ms = ms_clone = float("inf")
         else:
-            device = V.graph.get_current_device_or_throw()
             # We have to clone the inplace updated arguments to avoid earlier calls
             # generating out of range indices for later calls.
-            ms = benchmarker.benchmark(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
-                device=device,
+            ms = benchmarker.benchmark_gpu(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0])
             )
-            ms_clone = benchmarker.benchmark(
-                lambda: wrapped_jit_function.clone_args(*args)[0],
-                device=device,
+            ms_clone = benchmarker.benchmark_gpu(
+                lambda: wrapped_jit_function.clone_args(*args)[0]
             )
 
             log.debug(
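Both `load_cache` hunks time the kernel together with its argument cloning, then subtract a separate measurement of the clone alone when the kernel mutates its inputs, so clone overhead does not bias the fusion decision. The correction is plain arithmetic: if `t_total` covers clone + call and `t_clone` covers the clone only, the kernel's own cost is roughly `t_total - t_clone`. A self-contained sketch with wall-clock stand-ins (none of these helpers are the Inductor benchmarker):

import time

def bench(fn, iters=200):
    # Crude wall-clock stand-in for benchmarker.benchmark_gpu; returns ms/iter.
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) * 1e3 / iters

def clone_args(xs):
    return [list(x) for x in xs]  # stand-in for cloning mutated tensor args

def call(xs):
    return sum(map(sum, xs))  # stand-in for the compiled kernel

args = [[1.0] * 1024 for _ in range(8)]
t_total = bench(lambda: call(clone_args(args)))  # clone + kernel together
t_clone = bench(lambda: clone_args(args))        # clone overhead alone
kernel_ms = t_total - t_clone                    # corrected kernel estimate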

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 1 addition & 2 deletions
@@ -889,7 +889,6 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:
         result.writeline(f"return {', '.join(var_names)},")
 
         result.writelines(["\n", "\n", "def call(args):"])
-        device = V.graph.get_current_device_or_throw()
         index = V.graph.get_current_device_or_throw().index
         with result.indent():
             result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
@@ -924,7 +923,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:
 
         result.writeline("args = get_args()")
         result.writeline(
-            f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type},rep=40)"
+            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
         )
         result.writeline(f"num_gb = {num_gb}")
         result.writeline("gb_per_s = num_gb / (ms / 1e3)")

torch/_inductor/ir.py

Lines changed: 1 addition & 3 deletions
@@ -5050,9 +5050,7 @@ def benchmark(self, *args: Any, out: torch.Tensor) -> float:
         }
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)  # type: ignore[arg-type]
-        return benchmarker.benchmark(
-            algo, args, {"out": out}, device=None, **benchmark_configs
-        )
+        return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
 
     def call_name(self) -> str:
         raise NotImplementedError

torch/_inductor/runtime/benchmarking.py

Lines changed: 13 additions & 46 deletions
@@ -92,21 +92,15 @@ def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
 
 
 class Benchmarker:
-    """
-    A device-agnostic benchmarking utility for measuring the runtime of
-    inductor generated callables.
-    """
-
     def __init__(self: Self) -> None:
         pass
 
     @time_and_count
     def benchmark(
         self: Self,
         fn: Callable[..., Any],
-        fn_args: Optional[tuple[Any, ...]] = None,
-        fn_kwargs: Optional[dict[str, Any]] = None,
-        device: Optional[Union[str, torch.device]] = None,
+        fn_args: tuple[Any, ...],
+        fn_kwargs: dict[str, Any],
         **kwargs: Any,
     ) -> float:
         """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
@@ -115,61 +109,34 @@ def benchmark(
         device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
         `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
         if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
-        types are found. To bypass device inference, provide the device to the `device`
-        parameter.
+        types are found.
 
         Arguments:
         - fn: The function to benchmark.
         - fn_args: The function's arguments.
         - fn_kwargs: The function's kwargs.
 
         Keyword Arguments:
-        - device: Which device to use for benchmarking. If not provided the device will be attempted
-        to be inferred from `fn_args` and `fn_kwargs`.
         - **kwargs: The benchmarking implementation's kwargs.
 
         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        inferred_device: Optional[torch.device] = None
-        if device is not None:
-            inferred_device = (
-                torch.device(device) if isinstance(device, str) else device
-            )
-        else:
-            if fn_args is None and fn_kwargs is None:
+        inferred_device = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            if not isinstance(arg_or_kwarg, torch.Tensor):
+                continue
+            if inferred_device is None:
+                inferred_device = arg_or_kwarg.device
+            elif arg_or_kwarg.device != inferred_device:
                 raise ValueError(
-                    "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided."
+                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
                 )
-
-        fn_args = fn_args or tuple()
-        fn_kwargs = fn_kwargs or {}
-        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-            if not isinstance(arg_or_kwarg, torch.Tensor):
-                continue
-            if inferred_device is None:
-                inferred_device = arg_or_kwarg.device
-            elif arg_or_kwarg.device != inferred_device:
-                raise ValueError(
-                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
-                )
-
         if inferred_device is None:
             raise ValueError(
-                "Can't safely infer the device type of `fn` with no device types"
-                " in `fn_args` or `fn_kwargs` and `device` not explicitly provided!"
-                " You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."
+                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
            )
-
-        fn_args = fn_args or tuple()
-        fn_kwargs = fn_kwargs or {}
-
-        # No need to wrap if the callable takes no arguments
-        if len(fn_args) == 0 and len(fn_kwargs) == 0:
-            _callable = fn
-        else:
-            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
-
+        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
         if inferred_device == torch.device("cpu"):
             return self.benchmark_cpu(_callable, **kwargs)
         # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
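For reference, the device-inference loop this revert restores reduces to the following standalone helper; it is a simplified restatement of the `+` lines above (the function name is illustrative, not part of the module):

from itertools import chain

import torch

def infer_device(fn_args, fn_kwargs):
    # Scan every tensor in args/kwargs; all tensors must agree on one device.
    inferred_device = None
    for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
        if not isinstance(arg_or_kwarg, torch.Tensor):
            continue
        if inferred_device is None:
            inferred_device = arg_or_kwarg.device
        elif arg_or_kwarg.device != inferred_device:
            raise ValueError("multiple device types in `fn_args` and `fn_kwargs`")
    if inferred_device is None:
        # No tensors at all: call `.benchmark_cpu` or `.benchmark_gpu` directly.
        raise ValueError("no device types in `fn_args` or `fn_kwargs`")
    return inferred_device

print(infer_device((torch.ones(2), 3), {"out": torch.zeros(2)}))  # cpu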

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 5 additions & 5 deletions
@@ -927,11 +927,11 @@ def kernel_call():
 
             return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
 
-        benchmark_kwargs = {"rep": 40} if self.device_props.type == "cuda" else {}
-        return benchmarker.benchmark(
-            fn=kernel_call,
-            device=self.device_props.type,
-            **benchmark_kwargs,  # type: ignore[arg-type]
+        if self.device_props.type == "cpu":
+            return benchmarker.benchmark_cpu(kernel_call)
+
+        return benchmarker.benchmark_gpu(
+            kernel_call, rep=40, is_vetted_benchmarking=True
         )
 
     def copy_args_to_cpu_if_needed(self, *args, **kwargs):
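The restored call path makes the CPU/GPU split explicit at the call site. A hedged sketch of the same dispatch shape, assuming the `benchmarker` singleton from `torch._inductor.runtime.benchmarking` and treating `kernel_call`/`device_type` as placeholders (`is_vetted_benchmarking` is passed through exactly as in the hunk above):

from torch._inductor.runtime.benchmarking import benchmarker

def time_kernel(kernel_call, device_type):
    # Mirrors the restored branch: CPU timing, otherwise GPU timing with rep=40.
    if device_type == "cpu":
        return benchmarker.benchmark_cpu(kernel_call)
    return benchmarker.benchmark_gpu(kernel_call, rep=40, is_vetted_benchmarking=True)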

torch/_inductor/scheduler.py

Lines changed: 4 additions & 4 deletions
@@ -3269,8 +3269,8 @@ def speedup_by_fusion(
         device = node_list_1[0].get_device()
         assert device
 
-        # don't support benchmark fusion for CPU C++ backend right now.
-        if device.type == "cpu" and config.cpu_backend != "triton":
+        # don't support benchmark fusion for CPU right now.
+        if device.type == "cpu":
             return True
 
         node_list_2 = node2.get_nodes()
@@ -5569,8 +5569,8 @@ def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool:
         subkernel_nodes = nodes
         device = subkernel_nodes[0].get_device()
 
-        # don't support benchmark fusion for CPU C++ backend right now.
-        if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
+        # don't support benchmark fusion for CPU right now.
+        if device is None or device.type == "cpu":
             return True
 
         from triton.compiler.errors import CompilationError

torch/_inductor/select_algorithm.py

Lines changed: 2 additions & 4 deletions
@@ -2671,10 +2671,8 @@ def __call__(
 
         # Templates selected with input_gen_fns require specific input data to avoid IMA
         # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
-        # TODO(jgong5): support multi-template on CPU C++ backend
-        if input_gen_fns is not None or (
-            layout.device.type == "cpu" and config.cpu_backend != "triton"
-        ):
+        # TODO(jgong5): support multi-template on CPU
+        if input_gen_fns is not None or layout.device.type == "cpu":
             return_multi_template = False
 
         # TODO - assert that we have not mutating kernels here

torch/_inductor/wrapper_benchmark.py

Lines changed: 1 addition & 7 deletions
@@ -93,7 +93,6 @@ def benchmark_all_kernels(
             continue
 
         triton_kernel = get_triton_kernel(kernel_mod)
-        device_type = triton_kernel.device_props.type
         kernel_category = get_kernel_category(kernel_mod)
         args = kernel_mod.get_args()
         num_in_out_ptrs = len(
@@ -138,12 +137,7 @@ def get_info_str(
                     f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                 )
         else:
-            ms = benchmarker.benchmark(
-                kernel_mod.call,
-                fn_args=(args,),
-                device=device_type,
-                rep=40,
-            )
+            ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
             assert len(triton_kernel.launchers) == 1, (
                 "Autotuner should have selected the best config"
             )
