1
+ import contextlib
1
2
import functools
2
3
import logging
3
4
import threading
5
+ from email import contentmanager
4
6
5
7
from . import config
6
8
from . import convert_frame
@@ -24,26 +26,31 @@ def nothing():
24
26
pass
25
27
26
28
29
+ null_context = contextlib .nullcontext ()
30
+
27
31
unset = object ()
28
32
29
33
compile_lock = threading .Lock ()
30
34
31
35
32
36
class _TorchDynamoContext :
33
- def __init__ (self , callback , on_enter = nothing ):
37
+ def __init__ (self , callback , on_enter = nothing , extra_ctx = null_context ):
34
38
super ().__init__ ()
35
39
assert callable (callback ) or callback is False or callback is None
36
40
self .callback = callback
37
41
self .prior = unset
38
42
self .on_enter = on_enter
43
+ self .extra_ctx = extra_ctx
39
44
40
45
def __enter__ (self ):
41
46
self .on_enter ()
42
47
self .prior = set_eval_frame (self .callback )
48
+ self .extra_ctx .__enter__ ()
43
49
44
50
def __exit__ (self , exc_type , exc_val , exc_tb ):
45
51
set_eval_frame (self .prior )
46
52
self .prior = unset
53
+ self .extra_ctx .__exit__ (exc_type , exc_val , exc_tb )
47
54
48
55
def __call__ (self , fn ):
49
56
assert callable (fn )
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
69
76
70
77
71
78
class OptimizeContext (_TorchDynamoContext ):
72
- def __init__ (self , callback ):
73
- super ().__init__ (callback = callback , on_enter = install_generation_tagging_new )
79
+ def __init__ (self , callback , extra_ctx ):
80
+ super ().__init__ (
81
+ callback = callback ,
82
+ on_enter = install_generation_tagging_new ,
83
+ extra_ctx = extra_ctx ,
84
+ )
74
85
75
86
76
87
class RunOnlyContext (_TorchDynamoContext ):
@@ -107,8 +118,8 @@ def catch_errors(frame, cache_size):
107
118
return catch_errors
108
119
109
120
110
- def _optimize_catch_errors (compile_fn ):
111
- return OptimizeContext (catch_errors_wrapper (compile_fn ))
121
+ def _optimize_catch_errors (compile_fn , extra_ctx = null_context ):
122
+ return OptimizeContext (catch_errors_wrapper (compile_fn ), extra_ctx = extra_ctx )
112
123
113
124
114
125
def optimize (backend , nopython = False ):
@@ -136,16 +147,23 @@ def toy_example(a, b):
136
147
with torchdynamo.optimize(my_compiler):
137
148
...
138
149
"""
150
+
151
+ extra_ctx = null_context
152
+ if hasattr (backend , "extra_ctx" ):
153
+ extra_ctx = getattr (backend , "extra_ctx" )
154
+
139
155
if nopython :
140
- return optimize_assert (backend )
141
- return _optimize_catch_errors (convert_frame .convert_frame (backend ))
156
+ return optimize_assert (backend , extra_ctx )
157
+ return _optimize_catch_errors (convert_frame .convert_frame (backend ), extra_ctx )
142
158
143
159
144
- def optimize_assert (backend ):
160
+ def optimize_assert (backend , extra_ctx = null_context ):
145
161
"""
146
162
The same as `torchdynamo.optimize(backend, nopython=True)`
147
163
"""
148
- return _optimize_catch_errors (convert_frame .convert_frame_assert (backend ))
164
+ return _optimize_catch_errors (
165
+ convert_frame .convert_frame_assert (backend ), extra_ctx
166
+ )
149
167
150
168
151
169
def run (fn = None ):
0 commit comments