1+ import contextlib
12import functools
23import logging
34import threading
@@ -24,26 +25,32 @@ def nothing():
2425 pass
2526
2627
# Default "extra backend context" constructor: a no-op context manager,
# so code paths can unconditionally enter/exit a backend context.
null_context = contextlib.nullcontext

# Sentinel object: distinguishes "no prior eval-frame callback stored"
# from a legitimately stored None callback.
unset = object()

# NOTE(review): presumably serializes graph compilation across threads —
# confirm against the callers that acquire it (not visible in this chunk).
compile_lock = threading.Lock()
3033
3134
3235class _TorchDynamoContext :
33- def __init__ (self , callback , on_enter = nothing ):
36+ def __init__ (self , callback , on_enter = nothing , backend_ctx_ctor = null_context ):
3437 super ().__init__ ()
3538 assert callable (callback ) or callback is False or callback is None
3639 self .callback = callback
3740 self .prior = unset
3841 self .on_enter = on_enter
42+ self .extra_ctx_ctor = backend_ctx_ctor
3943
4044 def __enter__ (self ):
4145 self .on_enter ()
4246 self .prior = set_eval_frame (self .callback )
47+ self .backend_ctx = self .extra_ctx_ctor ()
48+ self .backend_ctx .__enter__ ()
4349
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore whatever eval-frame callback was active before
        # __enter__, then clear the stored sentinel.
        # NOTE(review): teardown is deliberately NOT the exact reverse of
        # __enter__ — the dynamo hook is removed before the backend
        # context exits, presumably so backend_ctx.__exit__ runs without
        # dynamo installed; confirm before reordering.
        set_eval_frame(self.prior)
        self.prior = unset
        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
4754
4855 def __call__ (self , fn ):
4956 assert callable (fn )
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
6976
7077
class OptimizeContext(_TorchDynamoContext):
    """Context manager/decorator that installs the dynamo compile
    callback, with ``install_generation_tagging_new`` as the on-enter
    hook.

    Args:
        callback: the eval-frame callback to install (see
            ``_TorchDynamoContext``).
        backend_ctx_ctor: zero-argument callable returning a context
            manager entered alongside the dynamo hook (e.g. a fuser
            context). Defaults to a no-op context so pre-existing
            ``OptimizeContext(callback)`` call sites keep working.
    """

    def __init__(self, callback, backend_ctx_ctor=null_context):
        super().__init__(
            callback=callback,
            on_enter=install_generation_tagging_new,
            backend_ctx_ctor=backend_ctx_ctor,
        )
7485
7586
7687class RunOnlyContext (_TorchDynamoContext ):
@@ -107,8 +118,10 @@ def catch_errors(frame, cache_size):
107118 return catch_errors
108119
109120
def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
    """Wrap ``compile_fn`` with error catching and return an
    ``OptimizeContext`` around the wrapped callback.

    Args:
        compile_fn: the per-frame compile function to guard.
        backend_ctx_ctor: forwarded to ``OptimizeContext``; defaults to
            a no-op context constructor.
    """
    guarded = catch_errors_wrapper(compile_fn)
    return OptimizeContext(guarded, backend_ctx_ctor=backend_ctx_ctor)
112125
113126
114127def optimize (backend , nopython = False ):
@@ -117,10 +130,13 @@ def optimize(backend, nopython=False):
117130 backend() to optimize extracted graphs.
118131
119132 Args:
120- backend: One of two things:
121- - Either, a function taking a torch.fx.GraphModule and
133+ backend: One of the two things:
134+ - Either, a function/callable taking a torch.fx.GraphModule and
122135 example_inputs and returning a python callable that runs the
123136 graph faster.
137+ One can also provide additional context for the backend, like
138+ torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
139+ See AOTAutogradMemoryEfficientFusionWithContext for the usage.
124140 - Or, a string backend name in `torchdynamo.list_backends()`
125141 nopython: If True, graph breaks will be errors and there will
126142 be a single whole-program graph.
@@ -136,16 +152,25 @@ def toy_example(a, b):
136152 with torchdynamo.optimize(my_compiler):
137153 ...
138154 """
155+
156+ backend_ctx_ctor = null_context
157+ if hasattr (backend , "backend_ctx_ctor" ):
158+ backend_ctx_ctor = getattr (backend , "backend_ctx_ctor" )
159+
139160 if nopython :
140- return optimize_assert (backend )
141- return _optimize_catch_errors (convert_frame .convert_frame (backend ))
161+ return optimize_assert (backend , backend_ctx_ctor )
162+ return _optimize_catch_errors (
163+ convert_frame .convert_frame (backend ), backend_ctx_ctor
164+ )
142165
143166
def optimize_assert(backend, backend_ctx_ctor=null_context):
    """
    The same as `torchdynamo.optimize(backend, nopython=True)`
    """
    # convert_frame_assert errors out on graph breaks instead of falling
    # back, giving a single whole-program graph.
    frame_converter = convert_frame.convert_frame_assert(backend)
    return _optimize_catch_errors(frame_converter, backend_ctx_ctor)
149174
150175
151176def run (fn = None ):
0 commit comments