Separate forward and backward compilation and support higher order derivatives for aot_function #856
base: gh/anjali411/1/base
Conversation
functorch/_src/aot_autograd.py
@@ -140,12 +140,15 @@ def create_aot_autograd_function(
compiled_fw = None
compiled_bw = None
num_outs = None
joint_inputs = None
need to save these tensors in the context
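The reviewer's point about saving state on the context can be illustrated with a minimal, self-contained sketch (not the PR's code): a custom `torch.autograd.Function` stashes whatever the backward pass will need via `ctx.save_for_backward`, which is the standard place for tensors like `joint_inputs` to live between forward and backward.

```python
import torch

# Minimal sketch of saving tensors on the autograd context, as the
# review comment suggests. `Square` is a hypothetical stand-in for the
# PR's AOT autograd function.
class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)    # stash inputs needed by backward
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors    # recover them in backward
        return grad_out * 2 * x

x = torch.tensor([3.0], requires_grad=True)
y = Square.apply(x)
y.backward()
print(x.grad)  # d(x^2)/dx at x=3 is 6
```

Saving via the context (rather than module-level globals like `compiled_fw`/`compiled_bw`) keeps per-call state tied to the autograd graph node that needs it.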
functorch/_src/aot_autograd.py
func_code = bw_module.code.split('self, ')
# print(func_code[0] + func_code[1])
exec(func_code[0] + func_code[1], globals())
f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state)
Two questions:
- Why are we passing forward to create_aot_autograd_function? I would have expected us to pass bw_module.code without the self argument.
- What is the exec for? Are you trying to test this without the create_aot_autograd_function line?
- forward is the name of the function generated by running bw_module.code.
- exec executes bw_module.code to create a backward function, which is the forward for the next pass.
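The exec pattern under discussion can be sketched without building a real fx graph. An fx GraphModule's `code` attribute is a string defining `def forward(self, ...)`; splitting on `'self, '` drops the `self` parameter so `exec` produces a free function named `forward`. The `generated_code` string below is a hypothetical stand-in for `bw_module.code`:

```python
# Hand-written stand-in for bw_module.code (fx-generated source text).
generated_code = "def forward(self, x, y):\n    return x * 2 + y\n"

# Drop the `self` parameter, as the PR does with
# bw_module.code.split('self, '), then exec the result.
parts = generated_code.split('self, ')
namespace = {}
exec(parts[0] + parts[1], namespace)  # defines `forward` in `namespace`
forward = namespace['forward']
print(forward(3, 4))  # 3 * 2 + 4 -> 10
```

Exec-ing into a dedicated dict rather than `globals()` (as the PR does) avoids silently overwriting an existing `forward` binding in the module namespace.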
def fake_fn(primals, tangents):
    return fx_g_b(primals, tangents)
fx_g_b = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
Unfortunately this approach doesn't always work because the newly generated fx graph may not have the same nodes as the previous graph. We need an alternate way to select nodes of interest in this new graph!
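The failure mode the reviewer describes can be sketched in plain Python: if saved values are looked up by node name and retracing renames nodes, the lookup silently misses. `get_saved_values` below is a hypothetical simplification of the PR's `_get_saved_values`, operating on name lists instead of fx nodes:

```python
# Hypothetical sketch of name-based node selection across fx graphs.
# Matching by name is fragile: a regenerated graph may assign fresh
# names (e.g. add_1 -> add_2), so previously saved names go missing.
def get_saved_values(node_names, saved_value_names):
    found = [n for n in node_names if n in saved_value_names]
    missing = [n for n in saved_value_names if n not in node_names]
    return found, missing

old_graph = ['primals_1', 'add_1', 'mul_1']
new_graph = ['primals_1', 'add_2', 'mul_1']   # retrace renamed add_1

print(get_saved_values(new_graph, ['add_1', 'mul_1']))
# 'mul_1' is found, but 'add_1' no longer exists in the new graph
```

A more robust selection would key on something stable across retraces (e.g. the operation and its argument structure) rather than the generated node name alone.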
Stack from ghstack (oldest at bottom):
Test Plan: Existing tests should pass