@@ -79,25 +79,13 @@ def get_ops_for_key(key):
7979batched_registrations = get_ops_for_key ('FuncTorchBatched' )
8080all_ops = get_ops_for_key (None )
8181
82- # Find all occurrences of things inside of STOP_DECOMPOSE(...) using regex
83- # Look in ../functorch/csrc/BatchRulesStopDecomposition.cpp
84- # Example:
85- # STOP_DECOMPOSE(sin); => sin
86- with open ('../functorch/csrc/BatchRulesStopDecomposition.cpp' ) as f :
87- content = f .read ()
88- stop_decomposition_regex = re .compile (r'STOP_DECOMPOSE\((.*)\);' )
89- stop_decomposition_matches = stop_decomposition_regex .findall (content )
90- stop_decomposition_matches = [m .strip () for m in stop_decomposition_matches ]
91- stop_decomposition_ops = set (stop_decomposition_matches )
92-
9382composite_ops = get_ops_for_key ('CompositeImplicitAutograd' )
94- decomposed_ops = composite_ops - stop_decomposition_ops
9583
9684
97- vmap_ops = ( batched_registrations - stop_decomposition_ops ) | ( composite_ops - stop_decomposition_ops )
85+ vmap_ops = batched_registrations
9886noncomposite_ops = all_ops - composite_ops
9987
100- ops = yaml .load (open ('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml' , 'r' ).read ())
88+ ops = yaml .load (open ('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml' , 'r' ).read (), Loader = yaml . CLoader )
10189
10290annotated_ops = {a .strip (): b .strip () for a ,b in list (csv .reader (open ('annotated_ops.txt' )))}
10391from collections import defaultdict
@@ -133,8 +121,6 @@ def annotate_ops(ops, is_unique):
133121 categorization ['inplace' ] += 1
134122 op ['meta' ] = 'inplace'
135123 continue
136- if 'slow_conv3d_backward.grad_input' in op ['full_name' ]:
137- import pdb ; pdb .set_trace ()
138124 if not is_unique and 'a!' in op ['func' ].lower ():
139125 categorization ['out' ] += 1
140126 op ['meta' ] = 'out'
0 commit comments