From 41745bac2e91a888f0116e1d1cc3e10c42122be4 Mon Sep 17 00:00:00 2001
From: Aleksandar Samardžić
Date: Fri, 3 Oct 2025 16:52:22 +0000
Subject: [PATCH] Do the "preprocessing" right for PyTorch compiled grouped GEMM

---
 .../operators/grouped_gemm/operator.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/tritonbench/operators/grouped_gemm/operator.py b/tritonbench/operators/grouped_gemm/operator.py
index 27e93d3f..d6511f7b 100644
--- a/tritonbench/operators/grouped_gemm/operator.py
+++ b/tritonbench/operators/grouped_gemm/operator.py
@@ -108,17 +108,18 @@ def _inner():
     # TODO: Does not work on hip
     @register_benchmark(enabled=is_cuda())
     def preprocessed_pt2_triton_grouped_mm(self, group_A, group_B):
-        def _inner():
-            torch._dynamo.reset()
+        torch._dynamo.reset()
 
-            with inductor_config.patch(
-                max_autotune=True,
-                max_autotune_gemm_backends="TRITON",
-                autotune_fallback_to_aten=False,
-            ):
-                A_packed, B_shared, offs = self.list_input_to_jagged(group_A, group_B)
-                compiled = torch.compile(torch._grouped_mm, dynamic=False)
-                return compiled(A_packed, B_shared, offs=offs, bias=None)
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            A_packed, B_shared, offs = self.list_input_to_jagged(group_A, group_B)
+            compiled = torch.compile(torch._grouped_mm, dynamic=False)
+
+        def _inner():
+            return compiled(A_packed, B_shared, offs=offs, bias=None)
 
         return _inner
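
Note (not part of the patch): the intent of the change is that the benchmark harness times only the `_inner` callable returned by the registered function, so `torch._dynamo.reset()`, the jagged input packing, and the `torch.compile` wrapping now run once during setup instead of inside every measured call. Below is a minimal sketch of that pattern under stated assumptions: `make_inputs` and `timeit` are hypothetical stand-ins for TritonBench's own helpers, and plain `torch.mm` stands in for `torch._grouped_mm`.

import time

import torch


def make_inputs():
    # Hypothetical stand-in for list_input_to_jagged(): build the packed
    # operands once, up front, rather than inside the timed closure.
    A = torch.randn(64, 32)
    B = torch.randn(32, 16)
    return A, B


def preprocessed_benchmark():
    # One-time preprocessing happens here, outside the returned closure,
    # mirroring the structure the patch gives preprocessed_pt2_triton_grouped_mm.
    A, B = make_inputs()
    compiled = torch.compile(torch.mm, dynamic=False)

    def _inner():
        # Only the compiled kernel invocation is measured.
        return compiled(A, B)

    return _inner


def timeit(fn, iters=10):
    # Hypothetical timing loop; the first call is a warm-up so that the
    # compilation (triggered lazily on the first invocation) is not measured.
    fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    print(timeit(preprocessed_benchmark()))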