 from functorch.experimental import functionalize
 from . import config
 from .decompositions import register_decomposition
-from .partitioners import default_partition
+from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules
 from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 from functools import wraps
@@ -70,7 +70,7 @@ def preserve_rng_state():
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -84,21 +84,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
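Aside: a minimal sketch (not from this commit) of what the joint forward-backward
computes for a toy function, assuming the same (primals, cotangents) convention;
fn, joint, and the shapes below are made up for illustration.

    import torch

    def fn(x, w):
        return (x * w).sum()

    def joint(primals, cotangents):
        # Run the forward, then feed the cotangents to autograd.grad as grad_outputs.
        outs = fn(*primals)
        grad_inputs = [p for p in primals if p.requires_grad]
        grads = torch.autograd.grad(
            outs, grad_inputs, grad_outputs=cotangents, allow_unused=True
        )
        return outs, grads

    x = torch.randn(3)
    w = torch.randn(3, requires_grad=True)
    out, grads = joint([x, w], [torch.ones(())])  # one cotangent per forward output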
@@ -152,16 +152,21 @@ def create_aot_autograd_function(
     if decompositions is None:
         decompositions = {}
     joint_forward_backward = create_joint_forward_backward(flat_fn)
-
+    # create_joint_forward_backward takes (primals, cotangents) as inputs;
+    # here the cotangents are the flat_grad_outs passed to backward.
+    j_b = None
     compiled_fw = None
-    compiled_bw = None
+    bw_modules = []
     num_outs = None
+    saved_value_names = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            # ctx.set_materialize_grads(False)
+            nonlocal compiled_fw, num_outs, saved_value_names, aot_decompositions, j_b
             # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
             # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
             old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
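Aside: the new bw_modules list exists because aot_function caches compiled functions
by function identity, so backward GraphModules that trace to identical code are
de-duplicated later by comparing their generated .code. A standalone sketch of that
idea (the traced function f below is made up, not from this commit):

    import torch.fx as fx

    def f(x):
        return x * 2

    gm1 = fx.symbolic_trace(f)
    gm2 = fx.symbolic_trace(f)
    assert gm1 is not gm2 and gm1.code == gm2.code  # same code, distinct objects

    bw_modules = []

    def dedupe(gm):
        # Reuse a previously seen module with identical generated code.
        for elem in bw_modules:
            if elem.code == gm.code:
                return elem
        bw_modules.append(gm)
        return gm

    assert dedupe(gm1) is gm1
    assert dedupe(gm2) is gm1  # the second trace maps onto the cached module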
@@ -184,8 +189,9 @@ def forward(ctx, *flat_tensor_args):
                         num_outs = 1
 
                     joint_inputs = (flat_tensor_args, out)
-                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                     with torch.set_grad_enabled(grad_state):
+                        # The joint graph here is traced from the input fn alone;
+                        # the backward re-trace later must also take grad_out for the saved intermediates.
                         fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                             *joint_inputs
                         )
@@ -196,33 +202,79 @@ def forward(ctx, *flat_tensor_args):
                         def fake_fn(primals, tangents):
                             return fx_g(primals, tangents)
                         fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
-                    fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                    # print(fw_module.code, bw_module.code)
-
+                    fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
+                    saved_value_names = [node.name for node in saved_value_nodes]
                     compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                     fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    j_b = create_joint_forward_backward(fw_module)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+            ctx.num_intermediate = len(fw_outs[num_outs:])
+            to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args)
+            ctx.save_for_backward(*to_be_saved)
             torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            return tuple(fw_outs[0:num_outs])
+            return tuple(fw_outs)
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
+        def backward(ctx, *flat_grad_outs):
             # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
             # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
             old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            nonlocal bw_modules, saved_value_names, num_outs, aot_decompositions, j_b
+            with preserve_rng_state():
+                intermediates = ctx.saved_tensors[:ctx.num_intermediate]
+                flat_tensor_args = ctx.saved_tensors[ctx.num_intermediate:]
+                flat_tensor_args = pytree.tree_map(
+                    lambda x: x.detach().requires_grad_(x.requires_grad)
+                    if isinstance(x, Tensor) else x, flat_tensor_args
+                )
+                inp_grad_outs = pytree.tree_map(
+                    lambda x: x.detach() if isinstance(x, Tensor) else x, flat_grad_outs
+                )
+                # inp_grad_outs = flat_grad_outs
+                with torch.set_grad_enabled(grad_state):
+                    fx_g_b = make_fx(j_b, aot_decompositions)(flat_tensor_args, inp_grad_outs)
+                if config.use_functionalize:
+                    # Functionalize the forward-backward graph. First create a
+                    # fake fn to make functionalize happy
+                    def fake_fn(primals, tangents):
+                        return fx_g_b(primals, tangents)
+                    fx_g_b = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
+                saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
+                assert len(saved_value_nodes) <= len(saved_value_names)
+                fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
+                if len(saved_values_new) != len(saved_value_names):
+                    new_intermediates = []
+                    # Forward saves more intermediates than needed
+                    assert len(saved_values_new) < len(saved_value_names)
+                    j = 0
+                    for node in saved_values_new:
+                        while node.name != saved_value_names[j]:
+                            j += 1
+                        new_intermediates.append(intermediates[j])
+                        j += 1
+                    intermediates = new_intermediates
+
+                # This is needed because aot function caching uses function id right now
+                bw_module_fn = None
+                for elem in bw_modules:
+                    if elem.code == bw_module_b.code:
+                        bw_module_fn = elem
+                        break
+                if bw_module_fn is None:
+                    bw_modules.append(bw_module_b)
+                    bw_module_fn = bw_module_b
+
+                f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+                out = f(*intermediates, *flat_grad_outs)
             torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            return tuple(out)
+            return tuple(normalize_as_list(out))
 
-    return CompiledFunction
+    def return_fn(*args, **kwargs):
+        out = CompiledFunction.apply(*args, **kwargs)
+        return out[0:num_outs]
+    return return_fn
 
 
 class _CompileCache(CompileCache):
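Aside: a minimal sketch (not from this commit) of the name-matching loop added above,
which drops saved forward intermediates that the re-partitioned backward no longer
needs; the names and values below are made-up stand-ins for fx node names and the
tensors saved by forward.

    saved_value_names = ["add", "mul", "relu", "sum_1"]      # recorded at forward time
    intermediates = ["t_add", "t_mul", "t_relu", "t_sum_1"]  # tensors saved by forward
    saved_values_new_names = ["mul", "sum_1"]                # needed by the re-traced backward

    new_intermediates = []
    j = 0
    for name in saved_values_new_names:
        # Advance until the forward-side name matches, then keep the paired tensor.
        while name != saved_value_names[j]:
            j += 1
        new_intermediates.append(intermediates[j])
        j += 1

    assert new_intermediates == ["t_mul", "t_sum_1"]

The wrapper return_fn then hides the extra forward outputs from callers: CompiledFunction
returns all of fw_outs so autograd can route gradients for the intermediates, while
return_fn slices out only the first num_outs real outputs.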
@@ -312,7 +364,7 @@ def rearrange(tensor_args, static_args, static_argnums):
     return args
 
 
-KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
+KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]
 
 
 def aot_function(
@@ -448,7 +500,6 @@ def returned_function(*args, **kwargs):
             hasher_type,
             *flat_args_for_cache,
         )
-
         # Compile the function and save it in the cache
         if cached_res is None:
             # Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -473,7 +524,7 @@ def flat_fn(*flat_tensor_args):
                 for i in flat_out:
                     is_known_type = False
                     for j in KNOWN_TYPES:
-                        if isinstance(i, j):
+                        if (i is None and j is None) or (j is not None and isinstance(i, j)):
                             is_known_type = True
                             break
                 if not is_known_type:
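Aside: a minimal sketch (not from this commit) of the intended effect of adding None
to KNOWN_TYPES: None outputs pass the check while genuinely unknown types still fail.
flat_out below is a made-up example output list.

    import torch

    KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]

    flat_out = [torch.randn(2), 3, None]
    for i in flat_out:
        is_known_type = any(
            (j is None and i is None) or (j is not None and isinstance(i, j))
            for j in KNOWN_TYPES
        )
        if not is_known_type:
            raise RuntimeError(f"Found {type(i)} in output, which is not a known type.")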
@@ -495,7 +546,7 @@ def flat_fn(*flat_tensor_args):
                 partition_fn,
                 decompositions,
                 grad_state=torch.is_grad_enabled(),
-            ).apply
+            )
             cached_res = (compiled_fn, out_spec)
 
             # Save the compiled_fn in the cache
@@ -635,7 +686,7 @@ def aot_function_simplified(
         partition_fn,
         decompositions,
         grad_state=torch.is_grad_enabled(),
-    ).apply
+    )
 
     return compiled_fn
 