-
Notifications
You must be signed in to change notification settings - Fork 15
Add testing for backwards passes #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
0f3f721 to
c39c753
Compare
Summary: Here we add correctness tests for backwards passes of ops. This PR does the following things 1) Figures out which ops not to test. (explained in depth at the top of BackendBench/backwards_utils.py + avoiding inplace ops) For simplcity we are not testing a) in place ops as we cannot just pass in the test args, but need special casing b) ops that require special handling with their args, c) one off corner cases. Every other 2) To do backwards passes (since the tensors naturally don't require grad in our suites), right now we add a gradient to all tensors in args and kwargs. This logic (+ test for if we should even run a backwards pass) is put in the suite as this can be handled on a per test level. For example in a follow up PR for this, we can add a backwards pass column in the torchbench dataset. 3) We also compare gradients and clear gradients after use to validate the backwards pass. We use the same allclose function as before. Note we don't copy tensors/args as sometimes they are views (at least in opinfo) which makes cloning difficult. 4) There are also a bunch of unit tests added to make sure the gradient checking utils work as expected. Test Plan: With this really slow correctish [mm implementation](https://gist.github.com/PaliC/e62859f0286f6bfa338ccb4140e9e74f) we get ```bash uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards ... correctness score (mean pass rate over all operators): 1.00 performance score (geomean speedup over all operators): 0.00 perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00 backwards correctness score (mean pass rate over all operators which support backwards): 1.00 ``` With the bad monkey patched implementation we get ``` uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards ... correctness score (mean pass rate over all operators): 0.00 performance score (geomean speedup over all operators): 1.00 perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00 backwards correctness score (mean pass rate over all operators which support backwards): 0.00 ``` The following two commands with aten also work as expected (100% correctness on forwards and backwards) ``` ``uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --check-backwards`` `uv run python BackendBench/scripts/main.py --suite torchbench --topn 2 --backend aten --check-backwards` ``` Todo: - [ ] rename is_correct -> correct_output (originally in this pr but added noise for reviewers) - [ ] performance tests - [ ] for torchbench suite put backwards checking in dataset - [ ] Assuming the above support ops which have conditions on their args - [ ] support inplace ops
|
Hi @PaliC! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
Summary:
Here we add correctness tests for backwards passes of ops.
This PR does the following things
Figures out which ops not to test. (explained in depth at the top of BackendBench/backwards_utils.py + avoiding inplace ops) For simplcity we are not testing a) in place ops as we cannot just pass in the test args, but need special casing b) ops that require special handling with their args, c) one off corner cases. Every other
To do backwards passes (since the tensors naturally don't require grad in our suites), right now we add a gradient to all tensors in args and kwargs. This logic (+ test for if we should even run a backwards pass) is put in the suite as this can be handled on a per test level. For example in a follow up PR for this, we can add a backwards pass column in the torchbench dataset.
We also compare gradients and clear gradients after use to validate the backwards pass. We use the same allclose function as before. Note we don't copy tensors/args as sometimes they are views (at least in opinfo) which makes cloning difficult.
There are also a bunch of unit tests added to make sure the gradient checking utils work as expected.
Test Plan:
With this really slow correctish mm implementation we get
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards ... correctness score (mean pass rate over all operators): 1.00 performance score (geomean speedup over all operators): 0.00 perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00 backwards correctness score (mean pass rate over all operators which support backwards): 1.00With the bad monkey patched implementation we get
The following two commands with aten also work as expected (100% correctness on forwards and backwards)
Todo: