Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit ba8dc15

Browse files
committed
Pass backend-related ctx to TorchDynamo Optimize Context
1 parent bcb169d commit ba8dc15

2 files changed

Lines changed: 37 additions & 10 deletions

File tree

torchdynamo/eval_frame.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import functools
23
import logging
34
import threading
@@ -24,26 +25,31 @@ def nothing():
2425
pass
2526

2627

28+
null_context = contextlib.nullcontext()
29+
2730
unset = object()
2831

2932
compile_lock = threading.Lock()
3033

3134

3235
class _TorchDynamoContext:
33-
def __init__(self, callback, on_enter=nothing):
36+
def __init__(self, callback, on_enter=nothing, extra_ctx=null_context):
3437
super().__init__()
3538
assert callable(callback) or callback is False or callback is None
3639
self.callback = callback
3740
self.prior = unset
3841
self.on_enter = on_enter
42+
self.extra_ctx = extra_ctx
3943

4044
def __enter__(self):
4145
self.on_enter()
4246
self.prior = set_eval_frame(self.callback)
47+
self.extra_ctx.__enter__()
4348

4449
def __exit__(self, exc_type, exc_val, exc_tb):
4550
set_eval_frame(self.prior)
4651
self.prior = unset
52+
self.extra_ctx.__exit__(exc_type, exc_val, exc_tb)
4753

4854
def __call__(self, fn):
4955
assert callable(fn)
@@ -69,8 +75,12 @@ def _fn(*args, **kwargs):
6975

7076

7177
class OptimizeContext(_TorchDynamoContext):
72-
def __init__(self, callback):
73-
super().__init__(callback=callback, on_enter=install_generation_tagging_new)
78+
def __init__(self, callback, extra_ctx):
79+
super().__init__(
80+
callback=callback,
81+
on_enter=install_generation_tagging_new,
82+
extra_ctx=extra_ctx,
83+
)
7484

7585

7686
class RunOnlyContext(_TorchDynamoContext):
@@ -107,8 +117,8 @@ def catch_errors(frame, cache_size):
107117
return catch_errors
108118

109119

110-
def _optimize_catch_errors(compile_fn):
111-
return OptimizeContext(catch_errors_wrapper(compile_fn))
120+
def _optimize_catch_errors(compile_fn, extra_ctx=null_context):
121+
return OptimizeContext(catch_errors_wrapper(compile_fn), extra_ctx=extra_ctx)
112122

113123

114124
def optimize(backend, nopython=False):
@@ -136,16 +146,23 @@ def toy_example(a, b):
136146
with torchdynamo.optimize(my_compiler):
137147
...
138148
"""
149+
150+
extra_ctx = null_context
151+
if hasattr(backend, "extra_ctx"):
152+
extra_ctx = getattr(backend, "extra_ctx")
153+
139154
if nopython:
140-
return optimize_assert(backend)
141-
return _optimize_catch_errors(convert_frame.convert_frame(backend))
155+
return optimize_assert(backend, extra_ctx)
156+
return _optimize_catch_errors(convert_frame.convert_frame(backend), extra_ctx)
142157

143158

144-
def optimize_assert(backend):
159+
def optimize_assert(backend, extra_ctx=null_context):
145160
"""
146161
The same as `torchdynamo.optimize(backend, nopython=True)`
147162
"""
148-
return _optimize_catch_errors(convert_frame.convert_frame_assert(backend))
163+
return _optimize_catch_errors(
164+
convert_frame.convert_frame_assert(backend), extra_ctx
165+
)
149166

150167

151168
def run(fn=None):

torchdynamo/optimizations/training.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,14 @@ def candidate(self):
143143
return BACKENDS["aot_autograd"](self.gm, self.example_inputs)
144144

145145

146-
aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn
146+
class AOTAutogradMemoryEfficientFusionWithContext:
147+
"""Pass nvfuser context to TorchDynamo"""
148+
149+
def __init__(self):
150+
self.extra_ctx = torch.jit.fuser("fuser2")
151+
152+
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
153+
return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs)
154+
155+
156+
aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()

0 commit comments

Comments (0)