@@ -166,6 +166,60 @@ def _do_bench_inductor(fn, warmup, rep, return_mode="all", grad_to_none=None):
     return _summarize_statistics(times, quantiles=None, return_mode=return_mode)


+def _do_bench_cudagraph_with_cache_clear(
+    fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
+):
+    """Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing."""
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
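+    # This buffer is zeroed before every kernel invocation to flush stale data out of the GPU L2 cache.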
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        cache.zero_()
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
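+        # Time five cache-clear + kernel iterations to estimate per-call latency.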
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            cache.zero_()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
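+        # Capture n_repeat iterations (each preceded by a cache clear) into one CUDA graph.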
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                cache.zero_()
+                fn()
+        torch.cuda.synchronize()
+
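+        # Replay the captured graph n_retries times; each sample is the elapsed time averaged over n_repeat calls.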
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret.append(start_event.elapsed_time(end_event) / n_repeat)
+
+        times = torch.tensor(ret, dtype=torch.float)
+        return _summarize_statistics(times, quantiles, return_mode)
+
+
 def _do_bench_profiler(
     fn, warmup, rep, return_mode="all", grad_to_none=None, use_cudagraph=False
 ):
@@ -383,7 +437,7 @@ def do_bench_wrapper(
     if latency_measure_mode == "profiler":
         bench_fn = partial(_do_bench_profiler, warmup=1, use_cudagraph=True)
     else:
-        bench_fn = triton.testing.do_bench_cudagraph
+        bench_fn = _do_bench_cudagraph_with_cache_clear

     return Latency(
         times=bench_fn(
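
For reference, below is a minimal sketch of calling the new helper directly on a standalone workload. The `_square` function, tensor shape, and `rep` value are illustrative assumptions rather than part of this commit, and the call assumes the module's existing imports (`torch`, `triton`) and its `_summarize_statistics` helper are in scope.

    import torch

    def _square(x):
        # Hypothetical workload; any callable that is CUDA-graph capturable works here.
        return x * x

    x = torch.randn(4096, 4096, device="cuda")
    # With the default return_mode="mean", this returns the mean per-call latency in ms,
    # where each retry's measurement is averaged over the n_repeat graph-captured calls.
    latency_ms = _do_bench_cudagraph_with_cache_clear(lambda: _square(x), rep=20)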