1+ import contextlib
12import functools
23import logging
34import threading
@@ -24,26 +25,31 @@ def nothing():
2425 pass
2526
2627
28+ null_context = contextlib .nullcontext ()
29+
2730unset = object ()
2831
2932compile_lock = threading .Lock ()
3033
3134
3235class _TorchDynamoContext :
33- def __init__ (self , callback , on_enter = nothing ):
36+ def __init__ (self , callback , on_enter = nothing , extra_ctx = null_context ):
3437 super ().__init__ ()
3538 assert callable (callback ) or callback is False or callback is None
3639 self .callback = callback
3740 self .prior = unset
3841 self .on_enter = on_enter
42+ self .extra_ctx = extra_ctx
3943
4044 def __enter__ (self ):
4145 self .on_enter ()
4246 self .prior = set_eval_frame (self .callback )
47+ self .extra_ctx .__enter__ ()
4348
4449 def __exit__ (self , exc_type , exc_val , exc_tb ):
4550 set_eval_frame (self .prior )
4651 self .prior = unset
52+ self .extra_ctx .__exit__ (exc_type , exc_val , exc_tb )
4753
4854 def __call__ (self , fn ):
4955 assert callable (fn )
@@ -69,8 +75,12 @@ def _fn(*args, **kwargs):
6975
7076
7177class OptimizeContext (_TorchDynamoContext ):
72- def __init__ (self , callback ):
73- super ().__init__ (callback = callback , on_enter = install_generation_tagging_new )
78+ def __init__ (self , callback , extra_ctx ):
79+ super ().__init__ (
80+ callback = callback ,
81+ on_enter = install_generation_tagging_new ,
82+ extra_ctx = extra_ctx ,
83+ )
7484
7585
7686class RunOnlyContext (_TorchDynamoContext ):
@@ -107,8 +117,8 @@ def catch_errors(frame, cache_size):
107117 return catch_errors
108118
109119
110- def _optimize_catch_errors (compile_fn ):
111- return OptimizeContext (catch_errors_wrapper (compile_fn ))
120+ def _optimize_catch_errors (compile_fn , extra_ctx = null_context ):
121+ return OptimizeContext (catch_errors_wrapper (compile_fn ), extra_ctx = extra_ctx )
112122
113123
114124def optimize (backend , nopython = False ):
@@ -136,16 +146,23 @@ def toy_example(a, b):
136146 with torchdynamo.optimize(my_compiler):
137147 ...
138148 """
149+
150+ extra_ctx = null_context
151+ if hasattr (backend , "extra_ctx" ):
152+ extra_ctx = getattr (backend , "extra_ctx" )
153+
139154 if nopython :
140- return optimize_assert (backend )
141- return _optimize_catch_errors (convert_frame .convert_frame (backend ))
155+ return optimize_assert (backend , extra_ctx )
156+ return _optimize_catch_errors (convert_frame .convert_frame (backend ), extra_ctx )
142157
143158
144- def optimize_assert (backend ):
159+ def optimize_assert (backend , extra_ctx = null_context ):
145160 """
146161 The same as `torchdynamo.optimize(backend, nopython=True)`
147162 """
148- return _optimize_catch_errors (convert_frame .convert_frame_assert (backend ))
163+ return _optimize_catch_errors (
164+ convert_frame .convert_frame_assert (backend ), extra_ctx
165+ )
149166
150167
151168def run (fn = None ):
0 commit comments