
Commit c39c753

[WIP] Add testing for backwards passes
Summary: Here we add correctness tests for the backwards passes of ops. This PR does the following:

1) Figures out which ops not to test (explained in depth at the top of BackendBench/backwards_utils.py, plus skipping in-place ops). For simplicity we are not testing a) in-place ops, since we cannot just pass in the test args but need special casing, b) ops that require special handling of their args, and c) one-off corner cases. Every other op is tested.
2) To run backwards passes (the tensors in our suites naturally don't require grad), we currently enable gradients on all tensors in args and kwargs. This logic (plus the check for whether we should even run a backwards pass) lives in the suite, since it can be handled on a per-test level. For example, a follow-up PR can add a backwards-pass column to the torchbench dataset.
3) We compare gradients, and clear them after use, to validate the backwards pass. We use the same allclose function as before.
4) A bunch of unit tests are added to make sure the gradient-checking utils work as expected.

Test Plan:

With this really slow, correct-ish [mm implementation](https://gist.github.com/PaliC/e62859f0286f6bfa338ccb4140e9e74f) we get

```bash
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 1.00
performance score (geomean speedup over all operators): 0.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 1.00
```

With the bad monkey-patched implementation we get

```bash
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 0.00
performance score (geomean speedup over all operators): 1.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 0.00
```

The following two commands with aten also work as expected (100% correctness on forwards and backwards):

```bash
uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --check-backwards
uv run python BackendBench/scripts/main.py --suite torchbench --topn 2 --backend aten --check-backwards
```

Todo:

- [ ] rename is_correct -> correct_output (originally in this PR but removed to reduce noise for reviewers)
- [ ] performance tests
- [ ] for the torchbench suite, put backwards checking in the dataset
- [ ] assuming the above, support ops which have conditions on their args
- [ ] support in-place ops
1 parent 6161729 commit c39c753
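As a hedged sketch (not part of this commit), items 2) and 3) of the commit message compose roughly as follows, using the helpers added in BackendBench/backwards_utils.py. The `check_backwards_once` wrapper and `_grads_match` helper are illustrative names, and the `.sum()` loss assumes the op returns a single floating-point tensor:

```python
import torch

from BackendBench.backwards_utils import (
    clear_gradients,
    collect_gradients,
    make_tensors_require_gradients,
)


def _grads_match(a, b, atol=1e-2, rtol=1e-2):
    # Mirrors the spirit of compare_gradients in eval.py: both missing is fine,
    # one missing is a mismatch, otherwise compare numerically.
    if a is None or b is None:
        return a is None and b is None
    return torch.allclose(a, b, atol=atol, rtol=rtol)


def check_backwards_once(op, impl, args, kwargs):
    # Suite inputs are inference tensors, so gradients are enabled here first.
    make_tensors_require_gradients(args, kwargs)

    # Reference gradients from the eager op.
    ref = op(*args, **kwargs)
    ref.sum().backward()
    ref_grads = collect_gradients(args, kwargs)
    clear_gradients(args, kwargs)

    # Gradients from the candidate implementation, collected in the same order.
    res = impl(*args, **kwargs)
    res.sum().backward()
    res_grads = collect_gradients(args, kwargs)
    clear_gradients(args, kwargs)

    return all(_grads_match(a, b) for a, b in zip(ref_grads, res_grads))


# Example: torch.mm checked against itself should produce matching gradients.
a, b = torch.randn(4, 4), torch.randn(4, 4)
print(check_backwards_once(torch.mm, torch.mm, (a, b), {}))
```

The real logic in eval.py below additionally handles list/tuple outputs via `_compute_loss` and raises when the gradients do not match.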

File tree: 8 files changed (+510, −20 lines)


BackendBench/backwards_utils.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for backwards pass checking and gradient verification.
"""

from typing import List

import torch

from BackendBench.scripts.op_map import query

# Operations that should be exempted from backwards pass testing
BACKWARDS_PASS_TESTING_EXCEMPTIONS = [
    # We skip this op for 2 reasons:
    # 1) This op has the args (shape, stride, storage_offset) where storage offset
    #    would change if a gradient is included in the tensor. Our suites (ie. opinfo)
    #    assume we are doing inference so storage is set to a bad value here.
    #    We'd have to write a custom suite for this.
    # 2) As this is a tensor manipulation op, it doesn't really make sense to test
    #    a backwards pass for this yet.
    "as_strided.default",
    # The function <op_name> is not differentiable with respect to argument 'running_mean'.
    # This input cannot have requires_grad True.
    # We likely need to handle this on the suite level.
    "native_batch_norm.default",
    "_native_batch_norm_legit.default",
    "_batch_norm_with_update.default",
    "native_batch_norm_backward.default",  # in torchbench only
    # The function 'soft_margin_loss' is not differentiable with respect to argument 'target'.
    # This input cannot have requires_grad True.
    "soft_margin_loss.default",
    # The function 'multi_margin_loss' is not differentiable with respect to argument 'weight'.
    # This input cannot have requires_grad True.
    "multi_margin_loss.default",
    # This op doesn't have a derivative unless it's defined explicitly. But there isn't
    # a good way of detecting the fact that this op has no derivative.
    "nextafter.default",
    # This is the only op that does not pass opinfo + aten on backwards passes
    # TODO: figure out why
    "grid_sampler_2d.default",
    # torchbench: gets IMA error when adding in the gradient on B200
    "max_pool2d_with_indices_backward.default",
]


def should_check_backwards_for_op(op_name: str, check_backwards: bool = True) -> bool:
    """
    Determine if backwards checking should be performed for a given operation.

    Args:
        op_name: The name of the operation (e.g., "aten.relu.default")
        check_backwards: Whether backwards checking is globally enabled

    Returns:
        True if backwards checking should be performed, False otherwise
    """
    if not check_backwards:
        return False

    # Check if op is in the exemption list
    if op_name in BACKWARDS_PASS_TESTING_EXCEMPTIONS:
        return False

    # Check if op is inplace (inplace ops are not supported for backwards checking)
    op_map_entries = query(op_name)
    if len(op_map_entries) == 1 and op_map_entries[0].get("is_inplace", False):
        return False

    return True


def _apply_to_tensors(obj, tensor_fn, container_fn=None, accumulator=None):
    """
    Generic functor to apply operations to tensors in nested data structures.

    Args:
        obj: The object to traverse (tensor, list, tuple, dict, or other)
        tensor_fn: Function to apply to each tensor. Should have signature (tensor, accumulator) -> Any
        container_fn: Optional function to handle container reconstruction.
            Signature: (container_type, transformed_items) -> Any
        accumulator: Optional accumulator object passed to tensor_fn

    Returns:
        Transformed object or None for in-place operations
    """
    if isinstance(obj, torch.Tensor):
        return tensor_fn(obj, accumulator)
    elif isinstance(obj, list):
        transformed = [
            _apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj
        ]
        return container_fn(list, transformed) if container_fn else transformed
    elif isinstance(obj, tuple):
        transformed = [
            _apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj
        ]
        return container_fn(tuple, transformed) if container_fn else tuple(transformed)
    elif isinstance(obj, dict):
        transformed = {
            key: _apply_to_tensors(value, tensor_fn, container_fn, accumulator)
            for key, value in obj.items()
        }
        return container_fn(dict, transformed) if container_fn else transformed
    else:
        # For immutable types or unknown types
        return obj


def collect_gradients(args, kwargs) -> List[torch.Tensor]:
    """
    Collect all gradients from args and kwargs into a flat list.

    Order is well-defined:
    1. Iterate through args in order
       - If arg is a tensor with grad, append grad
       - If arg is a list/tuple, iterate through elements in order and append tensor grads
    2. Iterate through kwargs in sorted key order
       - If kwarg is a tensor with grad, append grad
       - If kwarg is a list/tuple, iterate through elements in order and append tensor grads

    Args:
        args: The arguments (can contain tensors or lists/tuples of tensors).
        kwargs: The keyword arguments (can contain tensors or lists/tuples of tensors).

    Returns:
        List of gradients (torch.Tensor) in the order specified above.
        Returns empty list if no gradients are found.
    """
    gradients = []

    def collect_grad_fn(tensor, accumulator):
        accumulator.append(tensor.grad)

    # Collect from args
    for arg in args:
        _apply_to_tensors(arg, collect_grad_fn, accumulator=gradients)

    # Collect from kwargs in sorted key order for deterministic ordering
    for key in sorted(kwargs.keys()):
        _apply_to_tensors(kwargs[key], collect_grad_fn, accumulator=gradients)

    return gradients


def make_tensors_require_gradients(args, kwargs):
    def make_require_grad_fn(tensor, _):
        # check dtype is floating or complex
        if tensor.dtype not in [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.bfloat16,
            torch.complex64,
            torch.complex128,
        ]:
            return
        tensor.requires_grad = True

    _apply_to_tensors(args, make_require_grad_fn)
    _apply_to_tensors(kwargs, make_require_grad_fn)


def clear_gradients(args, kwargs):
    def clear_grad_fn(tensor, _):
        if tensor.grad is not None:
            tensor.grad = None

    _apply_to_tensors(args, clear_grad_fn)
    _apply_to_tensors(kwargs, clear_grad_fn)
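A quick, hypothetical usage example of the helpers above (not part of the diff); the shapes and the non-tensor `flag` kwarg are made up to show how nested containers are traversed:

```python
import torch

from BackendBench.backwards_utils import (
    clear_gradients,
    collect_gradients,
    make_tensors_require_gradients,
)

args = ([torch.randn(2, 2), torch.randn(2, 2)],)        # a list of tensors inside args
kwargs = {"weight": torch.randn(2, 2), "flag": True}    # mixed tensor / non-tensor kwargs

make_tensors_require_gradients(args, kwargs)            # only float/complex tensors get requires_grad
out = args[0][0] @ args[0][1] + kwargs["weight"]
out.sum().backward()

grads = collect_gradients(args, kwargs)                 # args first, then kwargs in sorted key order
assert len(grads) == 3                                  # two from the list, one from "weight"; "flag" is skipped
clear_gradients(args, kwargs)                           # .grad reset to None for the next test
```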

BackendBench/eval.py

Lines changed: 87 additions & 6 deletions
@@ -6,12 +6,17 @@
 
 import logging
 import math
+import time
 import traceback
 from dataclasses import dataclass
 from typing import List, Tuple
 
 import torch
 
+from BackendBench.backwards_utils import (
+    clear_gradients,
+    collect_gradients,
+)
 from BackendBench.utils import compute_errors, serialize_args, uses_cuda_stream
 
 
@@ -26,6 +31,8 @@ class CorrectnessTestResult:
     max_abs_error: float = -math.inf
     max_rel_error: float = -math.inf
     test_type: str = "correctness"
+    has_correct_gradients: bool = False
+    checked_backwards: bool = False
 
 
 @dataclass
@@ -90,25 +97,89 @@ def allclose(a, b, atol=1e-2, rtol=1e-2):
     return False
 
 
-def eval_correctness_test(op, impl, test) -> CorrectnessTestResult:
+def compare_gradients(res_grad, ref_grad, atol=1e-2, rtol=1e-2):
+    if res_grad is None and ref_grad is None:
+        return True
+    if res_grad is None or ref_grad is None:
+        raise ValueError("One of the gradients is None while the other is not.")
+    return allclose(res_grad, ref_grad, atol=atol, rtol=rtol)
+
+
+def _check_if_output_has_backwards(output):
+    if isinstance(output, torch.Tensor):
+        # todo: ask why we have to do this and why isinstance(output.grad_fn, NotImplementedType)
+        # doesn't work for outputs of ops with no derivative like floor_divide.default
+        has_grad_fn = not (type(output.grad_fn).__name__ == "NotImplemented")
+        return output.requires_grad and has_grad_fn
+    elif isinstance(output, list) or isinstance(output, tuple):
+        return all(_check_if_output_has_backwards(x) for x in output) and len(output) > 0
+    else:
+        return False
+
+
+def _compute_loss(output):
+    if isinstance(output, torch.Tensor):
+        return output.sum()
+    elif isinstance(output, list) or isinstance(output, tuple):
+        return sum(_compute_loss(x) for x in output)
+    else:
+        raise ValueError(f"Unsupported type: {type(output)}")
+
+
+def eval_correctness_test(op, impl, test, check_backwards=False) -> CorrectnessTestResult:
     """Evaluate impl of op against test.
 
     Returns:
         Tuple of (is_correct, error_message, absolute_error, relative_error)
     """
+
+    # Get the test_backwards flag from the test object if it exists
+    # The suite is responsible for setting this based on op capabilities
+    test_backwards = getattr(test, "test_backwards", False)
+
+    # Combine with global check_backwards flag
+    check_backwards = check_backwards and test_backwards
+
     args, kwargs = test.args, test.kwargs
     ref = op(*args, **kwargs)
+
+    # We now refine check_backwards with another check: ref must be something that can
+    # carry gradients (i.e. a torch.Tensor or a collection of torch.Tensors), as we cannot
+    # perform a backwards pass otherwise.
+    backwards_possible = _check_if_output_has_backwards(ref)
+
+    check_backwards = backwards_possible and check_backwards
+    if check_backwards:
+        loss = _compute_loss(ref)
+        loss.backward()
+        ref_grads = collect_gradients(args, kwargs)
+        clear_gradients(args, kwargs)
+    else:
+        ref_grads = None
+
     try:
         res = impl(*args, **kwargs)
+        if check_backwards:
+            loss = _compute_loss(res)
+            loss.backward()
+            res_grads = collect_gradients(args, kwargs)
+            clear_gradients(args, kwargs)
+            has_correct_gradients = compare_gradients(ref_grads, res_grads)
+        else:
+            res_grads = None
+            has_correct_gradients = False
         is_correct = allclose(ref, res)
 
         abs_error, rel_error = compute_errors(ref, res)
+        if check_backwards and not has_correct_gradients:
+            raise ValueError(
+                f"Gradients are not correct for {op.__name__} with args {serialize_args(args, kwargs)}"
+            )
         result = CorrectnessTestResult(
             op_name=op.__name__,
             args=serialize_args(args, kwargs),
             is_correct=is_correct,
             max_abs_error=abs_error,
             max_rel_error=rel_error,
+            has_correct_gradients=has_correct_gradients,
+            checked_backwards=check_backwards,
         )
         return result
     except Exception as e:
@@ -125,14 +196,16 @@ def eval_correctness_test(op, impl, test) -> CorrectnessTestResult:
         return result
 
 
-def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult]]:
+def eval_correctness(
+    op, impl, tests, check_backwards=False
+) -> Tuple[float, List[CorrectnessTestResult]]:
     """Evaluate correctness of impl against tests."""
     correct, total = 0, 0
     test_results: List[CorrectnessTestResult] = []
     for test in tests:
         args_str = serialize_args(test.args, test.kwargs)
         logging.debug(f"Testing {op.__name__} with args {args_str}")
-        result = eval_correctness_test(op, impl, test)
+        result = eval_correctness_test(op, impl, test, check_backwards)
         test_results.append(result)
         if result.is_correct:
             correct += 1
@@ -148,7 +221,6 @@ def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult
 
 def cpu_bench(fn, num_runs=100):
     """Simple CPU benchmarking using time.perf_counter."""
-    import time
 
     for _ in range(10):
         fn()
@@ -164,6 +236,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
     bench_fn = (
         triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
     )
+
     base_times = []
     test_times = []
     args_strs = []
@@ -176,6 +249,12 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
         args_str = serialize_args(cached_args, cached_kwargs)
         args_strs.append(args_str)
         logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
+        # Warmup: run both operations to compile CUDA kernels and warm up caches
+        for _ in range(25):
+            _ = op(*cached_args, **cached_kwargs)
+            _ = impl(*cached_args, **cached_kwargs)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         base_time = bench_fn(lambda: op(*cached_args, **cached_kwargs))
         base_times.append(base_time)
         # Note: If the test fails we consider the speedup to be 1.0
@@ -225,7 +304,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
 
 
 def eval_one_op(
-    op, impl, correctness_tests, performance_tests
+    op, impl, correctness_tests, performance_tests, check_backwards=False
 ) -> Tuple[float, float, List[CorrectnessTestResult], List[PerformanceTestResult]]:
     """Evaluate impl of op against correctness_tests and performance_tests.
 
@@ -261,7 +340,9 @@ def eval_one_op(
         )
         return 0, 1.0, correctness_results, performance_results
 
-    correctness_score, correctness_results = eval_correctness(op, impl, correctness_tests)
+    correctness_score, correctness_results = eval_correctness(
+        op, impl, correctness_tests, check_backwards
+    )
     performance_score, performance_results = eval_performance(op, impl, performance_tests)
     return (
        correctness_score,
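To illustrate point 2) of the commit message (the suite, not eval.py, decides per test whether to run a backwards pass), here is a hypothetical sketch of a suite test object. The `ExampleTest` class and the `"mm.default"` op-name string are illustrative assumptions; `test_backwards` is the attribute that `eval_correctness_test` reads via `getattr` in the diff above, and the helpers come from the new backwards_utils.py:

```python
import torch

from BackendBench.backwards_utils import (
    make_tensors_require_gradients,
    should_check_backwards_for_op,
)


class ExampleTest:
    """A minimal stand-in for a suite test that opts into backwards checking."""

    def __init__(self, op_name, args, kwargs, check_backwards):
        self.args = args
        self.kwargs = kwargs
        # The suite decides per test whether a backwards pass makes sense
        # (exemption list, in-place ops, global --check-backwards flag).
        self.test_backwards = should_check_backwards_for_op(op_name, check_backwards)
        if self.test_backwards:
            # Suite inputs are inference tensors, so gradients are enabled here.
            make_tensors_require_gradients(self.args, self.kwargs)


test = ExampleTest("mm.default", (torch.randn(4, 4), torch.randn(4, 4)), {}, True)
print(test.test_backwards)  # True unless the op is exempted or marked in-place in op_map
```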
