Commit c39c753
committed
[WIP] Add testing for backwards passes
Summary:
Here we add correctness tests for backwards passes of ops.
This PR does the following things
1) Figures out which ops not to test. (explained in depth at the top of BackendBench/backwards_utils.py + avoiding inplace ops) For simplcity we are not testing a) in place ops as we cannot just pass in the test args, but need special casing b) ops that require special handling with their args, c) one off corner cases. Every other
2) To do backwards passes (since the tensors naturally don't require grad in our suites), right now we add a gradient to all tensors in args and kwargs. This logic (+ test for if we should even run a backwards pass) is put in the suite as this can be handled on a per test level. For example in a follow up PR for this, we can add a backwards pass column in the torchbench dataset.
3) We also compare gradients and clear gradients after use to validate the backwards pass. We use the same allclose function as before.
4) There are also a bunch of unit tests added to make sure the gradient checking utils work as expected.
Test Plan:
With this really slow correctish [mm implementation](https://gist.github.com/PaliC/e62859f0286f6bfa338ccb4140e9e74f) we get
```bash
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 1.00
performance score (geomean speedup over all operators): 0.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 1.00
```
With the bad monkey patched implementation we get
```
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 0.00
performance score (geomean speedup over all operators): 1.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 0.00
```
The following two commands with aten also work as expected (100% correctness on forwards and backwards)
```
``uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --check-backwards``
`uv run python BackendBench/scripts/main.py --suite torchbench --topn 2 --backend aten --check-backwards`
```
Todo:
- [ ] rename is_correct -> correct_output (originally in this pr but added noise for reviewers)
- [ ] performance tests
- [ ] for torchbench suite put backwards checking in dataset
- [ ] Assuming the above support ops which have conditions on their args
- [ ] support inplace ops1 parent 6161729 commit c39c753
File tree
8 files changed
+510
-20
lines changed- BackendBench
- scripts
- suite
- test
8 files changed
+510
-20
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| 9 | + | |
9 | 10 | | |
10 | 11 | | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
15 | 20 | | |
16 | 21 | | |
17 | 22 | | |
| |||
26 | 31 | | |
27 | 32 | | |
28 | 33 | | |
| 34 | + | |
| 35 | + | |
29 | 36 | | |
30 | 37 | | |
31 | 38 | | |
| |||
90 | 97 | | |
91 | 98 | | |
92 | 99 | | |
93 | | - | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
94 | 129 | | |
95 | 130 | | |
96 | 131 | | |
97 | 132 | | |
98 | 133 | | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
99 | 142 | | |
100 | 143 | | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
101 | 157 | | |
102 | 158 | | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
103 | 168 | | |
104 | 169 | | |
105 | 170 | | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
106 | 175 | | |
107 | 176 | | |
108 | 177 | | |
109 | 178 | | |
110 | 179 | | |
111 | 180 | | |
| 181 | + | |
| 182 | + | |
112 | 183 | | |
113 | 184 | | |
114 | 185 | | |
| |||
125 | 196 | | |
126 | 197 | | |
127 | 198 | | |
128 | | - | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
129 | 202 | | |
130 | 203 | | |
131 | 204 | | |
132 | 205 | | |
133 | 206 | | |
134 | 207 | | |
135 | | - | |
| 208 | + | |
136 | 209 | | |
137 | 210 | | |
138 | 211 | | |
| |||
148 | 221 | | |
149 | 222 | | |
150 | 223 | | |
151 | | - | |
152 | 224 | | |
153 | 225 | | |
154 | 226 | | |
| |||
164 | 236 | | |
165 | 237 | | |
166 | 238 | | |
| 239 | + | |
167 | 240 | | |
168 | 241 | | |
169 | 242 | | |
| |||
176 | 249 | | |
177 | 250 | | |
178 | 251 | | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
179 | 258 | | |
180 | 259 | | |
181 | 260 | | |
| |||
225 | 304 | | |
226 | 305 | | |
227 | 306 | | |
228 | | - | |
| 307 | + | |
229 | 308 | | |
230 | 309 | | |
231 | 310 | | |
| |||
261 | 340 | | |
262 | 341 | | |
263 | 342 | | |
264 | | - | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
265 | 346 | | |
266 | 347 | | |
267 | 348 | | |
| |||
0 commit comments