File tree Expand file tree Collapse file tree 1 file changed +11
-10
lines changed
tritonbench/operators/grouped_gemm Expand file tree Collapse file tree 1 file changed +11
-10
lines changed Original file line number Diff line number Diff line change @@ -108,17 +108,18 @@ def _inner():
108108 # TODO: Does not work on hip
109109 @register_benchmark (enabled = is_cuda ())
110110 def preprocessed_pt2_triton_grouped_mm (self , group_A , group_B ):
111- def _inner ():
112- torch ._dynamo .reset ()
111+ torch ._dynamo .reset ()
113112
114- with inductor_config .patch (
115- max_autotune = True ,
116- max_autotune_gemm_backends = "TRITON" ,
117- autotune_fallback_to_aten = False ,
118- ):
119- A_packed , B_shared , offs = self .list_input_to_jagged (group_A , group_B )
120- compiled = torch .compile (torch ._grouped_mm , dynamic = False )
121- return compiled (A_packed , B_shared , offs = offs , bias = None )
113+ with inductor_config .patch (
114+ max_autotune = True ,
115+ max_autotune_gemm_backends = "TRITON" ,
116+ autotune_fallback_to_aten = False ,
117+ ):
118+ A_packed , B_shared , offs = self .list_input_to_jagged (group_A , group_B )
119+ compiled = torch .compile (torch ._grouped_mm , dynamic = False )
120+
121+ def _inner ():
122+ return compiled (A_packed , B_shared , offs = offs , bias = None )
122123
123124 return _inner
124125
You can’t perform that action at this time.
0 commit comments