1+ import contextlib
12import functools
23import logging
34import threading
5+ from email import contentmanager
46
57from . import config
68from . import convert_frame
@@ -24,26 +26,31 @@ def nothing():
2426 pass
2527
2628
29+ null_context = contextlib .nullcontext ()
30+
2731unset = object ()
2832
2933compile_lock = threading .Lock ()
3034
3135
3236class _TorchDynamoContext :
33- def __init__ (self , callback , on_enter = nothing ):
37+ def __init__ (self , callback , on_enter = nothing , extra_ctx = null_context ):
3438 super ().__init__ ()
3539 assert callable (callback ) or callback is False or callback is None
3640 self .callback = callback
3741 self .prior = unset
3842 self .on_enter = on_enter
43+ self .extra_ctx = extra_ctx
3944
4045 def __enter__ (self ):
4146 self .on_enter ()
4247 self .prior = set_eval_frame (self .callback )
48+ self .extra_ctx .__enter__ ()
4349
4450 def __exit__ (self , exc_type , exc_val , exc_tb ):
4551 set_eval_frame (self .prior )
4652 self .prior = unset
53+ self .extra_ctx .__exit__ (exc_type , exc_val , exc_tb )
4754
4855 def __call__ (self , fn ):
4956 assert callable (fn )
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
6976
7077
7178class OptimizeContext (_TorchDynamoContext ):
72- def __init__ (self , callback ):
73- super ().__init__ (callback = callback , on_enter = install_generation_tagging_new )
79+ def __init__ (self , callback , extra_ctx ):
80+ super ().__init__ (
81+ callback = callback ,
82+ on_enter = install_generation_tagging_new ,
83+ extra_ctx = extra_ctx ,
84+ )
7485
7586
7687class RunOnlyContext (_TorchDynamoContext ):
@@ -107,8 +118,8 @@ def catch_errors(frame, cache_size):
107118 return catch_errors
108119
109120
110- def _optimize_catch_errors (compile_fn ):
111- return OptimizeContext (catch_errors_wrapper (compile_fn ))
121+ def _optimize_catch_errors (compile_fn , extra_ctx = null_context ):
122+ return OptimizeContext (catch_errors_wrapper (compile_fn ), extra_ctx = extra_ctx )
112123
113124
114125def optimize (backend , nopython = False ):
@@ -136,16 +147,23 @@ def toy_example(a, b):
136147 with torchdynamo.optimize(my_compiler):
137148 ...
138149 """
150+
151+ extra_ctx = null_context
152+ if hasattr (backend , "extra_ctx" ):
153+ extra_ctx = getattr (backend , "extra_ctx" )
154+
139155 if nopython :
140- return optimize_assert (backend )
141- return _optimize_catch_errors (convert_frame .convert_frame (backend ))
156+ return optimize_assert (backend , extra_ctx )
157+ return _optimize_catch_errors (convert_frame .convert_frame (backend ), extra_ctx )
142158
143159
144- def optimize_assert (backend ):
160+ def optimize_assert (backend , extra_ctx = null_context ):
145161 """
146162 The same as `torchdynamo.optimize(backend, nopython=True)`
147163 """
148- return _optimize_catch_errors (convert_frame .convert_frame_assert (backend ))
164+ return _optimize_catch_errors (
165+ convert_frame .convert_frame_assert (backend ), extra_ctx
166+ )
149167
150168
151169def run (fn = None ):
0 commit comments