Skip to content

Commit 89986a7

Browse files
committed
Pass backend-related ctx to TorchDynamo Optimize Context
1 parent bcb169d commit 89986a7

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

torchdynamo/eval_frame.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import contextlib
12
import functools
23
import logging
34
import threading
5+
from email import contentmanager
46

57
from . import config
68
from . import convert_frame
@@ -24,26 +26,31 @@ def nothing():
2426
pass
2527

2628

29+
null_context = contextlib.nullcontext()
30+
2731
unset = object()
2832

2933
compile_lock = threading.Lock()
3034

3135

3236
class _TorchDynamoContext:
33-
def __init__(self, callback, on_enter=nothing):
37+
def __init__(self, callback, on_enter=nothing, extra_ctx=null_context):
3438
super().__init__()
3539
assert callable(callback) or callback is False or callback is None
3640
self.callback = callback
3741
self.prior = unset
3842
self.on_enter = on_enter
43+
self.extra_ctx = extra_ctx
3944

4045
def __enter__(self):
4146
self.on_enter()
4247
self.prior = set_eval_frame(self.callback)
48+
self.extra_ctx.__enter__()
4349

4450
def __exit__(self, exc_type, exc_val, exc_tb):
4551
set_eval_frame(self.prior)
4652
self.prior = unset
53+
self.extra_ctx.__exit__(exc_type, exc_val, exc_tb)
4754

4855
def __call__(self, fn):
4956
assert callable(fn)
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
6976

7077

7178
class OptimizeContext(_TorchDynamoContext):
72-
def __init__(self, callback):
73-
super().__init__(callback=callback, on_enter=install_generation_tagging_new)
79+
def __init__(self, callback, extra_ctx):
80+
super().__init__(
81+
callback=callback,
82+
on_enter=install_generation_tagging_new,
83+
extra_ctx=extra_ctx,
84+
)
7485

7586

7687
class RunOnlyContext(_TorchDynamoContext):
@@ -107,8 +118,8 @@ def catch_errors(frame, cache_size):
107118
return catch_errors
108119

109120

110-
def _optimize_catch_errors(compile_fn):
111-
return OptimizeContext(catch_errors_wrapper(compile_fn))
121+
def _optimize_catch_errors(compile_fn, extra_ctx=null_context):
122+
return OptimizeContext(catch_errors_wrapper(compile_fn), extra_ctx=extra_ctx)
112123

113124

114125
def optimize(backend, nopython=False):
@@ -136,16 +147,23 @@ def toy_example(a, b):
136147
with torchdynamo.optimize(my_compiler):
137148
...
138149
"""
150+
151+
extra_ctx = null_context
152+
if hasattr(backend, "extra_ctx"):
153+
extra_ctx = getattr(backend, "extra_ctx")
154+
139155
if nopython:
140-
return optimize_assert(backend)
141-
return _optimize_catch_errors(convert_frame.convert_frame(backend))
156+
return optimize_assert(backend, extra_ctx)
157+
return _optimize_catch_errors(convert_frame.convert_frame(backend), extra_ctx)
142158

143159

144-
def optimize_assert(backend):
160+
def optimize_assert(backend, extra_ctx=null_context):
145161
"""
146162
The same as `torchdynamo.optimize(backend, nopython=True)`
147163
"""
148-
return _optimize_catch_errors(convert_frame.convert_frame_assert(backend))
164+
return _optimize_catch_errors(
165+
convert_frame.convert_frame_assert(backend), extra_ctx
166+
)
149167

150168

151169
def run(fn=None):

torchdynamo/optimizations/training.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,14 @@ def candidate(self):
143143
return BACKENDS["aot_autograd"](self.gm, self.example_inputs)
144144

145145

146-
aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn
146+
class AOTAutogradMemoryEfficientFusionWithContext:
147+
"""Pass nvfuser context to TorchDynamo"""
148+
149+
def __init__(self):
150+
self.extra_ctx = torch.jit.fuser("fuser2")
151+
152+
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
153+
return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs)
154+
155+
156+
aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()

0 commit comments

Comments (0)