diff --git a/BackendBench/backwards_utils.py b/BackendBench/backwards_utils.py
new file mode 100644
index 00000000..71ebdad9
--- /dev/null
+++ b/BackendBench/backwards_utils.py
@@ -0,0 +1,173 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utilities for backwards pass checking and gradient verification.
+"""
+
+from typing import List
+
+import torch
+
+from BackendBench.scripts.op_map import query
+
+# Operations that should be exempted from backwards pass testing
+BACKWARDS_PASS_TESTING_EXEMPTIONS = [
+    # We skip this op for 2 reasons:
+    # 1) This op takes the args (shape, stride, storage_offset), and the storage_offset
+    #    would change if a gradient is included in the tensor. Our suites (i.e., opinfo)
+    #    assume we are doing inference, so storage is set to a bad value here.
+    #    We'd have to write a custom suite for this.
+    # 2) As this is a tensor manipulation op, it doesn't really make sense to test
+    #    a backwards pass for it yet.
+    "as_strided.default",
+    # The function is not differentiable with respect to argument 'running_mean'.
+    # This input cannot have requires_grad True.
+    # We likely need to handle this on the suite level.
+    "native_batch_norm.default",
+    "_native_batch_norm_legit.default",
+    "_batch_norm_with_update.default",
+    "native_batch_norm_backward.default",  # in torchbench only
+    # The function 'soft_margin_loss' is not differentiable with respect to argument 'target'.
+    # This input cannot have requires_grad True.
+    "soft_margin_loss.default",
+    # The function 'multi_margin_loss' is not differentiable with respect to argument 'weight'.
+    # This input cannot have requires_grad True.
+    "multi_margin_loss.default",
+    # This op has no derivative unless one is defined explicitly, and there is no good way of detecting that an op has no derivative.
+    "nextafter.default",
+    # This is the only op that does not pass opinfo + aten on backwards passes
+    # TODO: figure out why
+    "grid_sampler_2d.default",
+    # torchbench: gets an illegal memory access (IMA) error when adding in the gradient on B200
+    "max_pool2d_with_indices_backward.default",
+]
+
+
+def should_check_backwards_for_op(op_name: str, check_backwards: bool = True) -> bool:
+    """
+    Determine if backwards checking should be performed for a given operation.
+
+    Args:
+        op_name: The name of the operation (e.g., "relu.default")
+        check_backwards: Whether backwards checking is globally enabled
+
+    Returns:
+        True if backwards checking should be performed, False otherwise
+    """
+    if not check_backwards:
+        return False
+
+    # Check if op is in the exemption list
+    if op_name in BACKWARDS_PASS_TESTING_EXEMPTIONS:
+        return False
+
+    # Check if op is inplace (inplace ops are not supported for backwards checking)
+    op_map_entries = query(op_name)
+    if len(op_map_entries) == 1 and op_map_entries[0].get("is_inplace", False):
+        return False
+
+    return True
+
+
+def _apply_to_tensors(obj, tensor_fn, container_fn=None, accumulator=None):
+    """
+    Generic functor to apply operations to tensors in nested data structures.
+
+    Args:
+        obj: The object to traverse (tensor, list, tuple, dict, or other)
+        tensor_fn: Function to apply to each tensor. Should have signature (tensor, accumulator) -> Any
+        container_fn: Optional function to handle container reconstruction.
+                      Signature: (container_type, transformed_items) -> Any
+        accumulator: Optional accumulator object passed to tensor_fn
+
+    Returns:
+        Transformed object, or None for in-place operations
+    """
+    if isinstance(obj, torch.Tensor):
+        return tensor_fn(obj, accumulator)
+    elif isinstance(obj, list):
+        transformed = [
+            _apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj
+        ]
+        return container_fn(list, transformed) if container_fn else transformed
+    elif isinstance(obj, tuple):
+        transformed = [
+            _apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj
+        ]
+        return container_fn(tuple, transformed) if container_fn else tuple(transformed)
+    elif isinstance(obj, dict):
+        transformed = {
+            key: _apply_to_tensors(value, tensor_fn, container_fn, accumulator)
+            for key, value in obj.items()
+        }
+        return container_fn(dict, transformed) if container_fn else transformed
+    else:
+        # For immutable types or unknown types
+        return obj
+
+
+def collect_gradients(args, kwargs) -> List[torch.Tensor]:
+    """
+    Collect all gradients from args and kwargs into a flat list.
+
+    Order is well-defined:
+    1. Iterate through args in order
+       - If arg is a tensor, append its grad (None if no gradient is set)
+       - If arg is a list/tuple, iterate through elements in order and append tensor grads
+    2. Iterate through kwargs in sorted key order
+       - If kwarg is a tensor, append its grad (None if no gradient is set)
+       - If kwarg is a list/tuple, iterate through elements in order and append tensor grads
+
+    Args:
+        args: The arguments (can contain tensors or lists/tuples of tensors).
+        kwargs: The keyword arguments (can contain tensors or lists/tuples of tensors).
+
+    Returns:
+        List of gradients (torch.Tensor) in the order specified above.
+        Returns empty list if no gradients are found.
+    """
+    gradients = []
+
+    def collect_grad_fn(tensor, accumulator):
+        accumulator.append(tensor.grad)
+
+    # Collect from args
+    for arg in args:
+        _apply_to_tensors(arg, collect_grad_fn, accumulator=gradients)
+
+    # Collect from kwargs in sorted key order for deterministic ordering
+    for key in sorted(kwargs.keys()):
+        _apply_to_tensors(kwargs[key], collect_grad_fn, accumulator=gradients)
+
+    return gradients
+
+
+def make_tensors_require_gradients(args, kwargs):
+    def make_require_grad_fn(tensor, _):
+        # check dtype is floating or complex
+        if tensor.dtype not in [
+            torch.float32,
+            torch.float64,
+            torch.float16,
+            torch.bfloat16,
+            torch.complex64,
+            torch.complex128,
+        ]:
+            return
+        tensor.requires_grad = True
+
+    _apply_to_tensors(args, make_require_grad_fn)
+    _apply_to_tensors(kwargs, make_require_grad_fn)
+
+
+def clear_gradients(args, kwargs):
+    def clear_grad_fn(tensor, _):
+        if tensor.grad is not None:
+            tensor.grad = None
+
+    _apply_to_tensors(args, clear_grad_fn)
+    _apply_to_tensors(kwargs, clear_grad_fn)
diff --git a/BackendBench/eval.py b/BackendBench/eval.py
index c88ee37f..44e2b427 100644
--- a/BackendBench/eval.py
+++ b/BackendBench/eval.py
@@ -6,12 +6,17 @@
 import logging
 import math
+import time
 import traceback
 from dataclasses import dataclass
 from typing import List, Tuple
 
 import torch
 
+from BackendBench.backwards_utils import (
+    clear_gradients,
+    collect_gradients,
+)
 from BackendBench.utils import compute_errors, serialize_args, uses_cuda_stream
 
 
@@ -26,6 +31,8 @@ class CorrectnessTestResult:
     max_abs_error: float = -math.inf
     max_rel_error: float = -math.inf
     test_type: str = "correctness"
+    has_correct_gradients: bool = False
+    checked_backwards: bool = False
 
 
 @dataclass
@@ -90,25 +97,89 @@ def allclose(a, b, atol=1e-2, rtol=1e-2):
     return False
 
 
-def eval_correctness_test(op, impl, test) -> CorrectnessTestResult:
+def compare_gradients(res_grad, ref_grad, atol=1e-2, rtol=1e-2):
+    if res_grad is None and ref_grad is None:
+        return True
+    if res_grad is None or ref_grad is None:
+        raise ValueError("One of the gradients is None while the other is not.")
+    return allclose(res_grad, ref_grad, atol=atol, rtol=rtol)
+
+
+def _check_if_output_has_backwards(output):
+    if isinstance(output, torch.Tensor):
+        # TODO: figure out why this is needed and why isinstance(output.grad_fn, NotImplementedType) doesn't work for outputs of ops with no derivative, such as floor_divide.default
+        has_grad_fn = not (type(output.grad_fn).__name__ == "NotImplemented")
+        return output.requires_grad and has_grad_fn
+    elif isinstance(output, list) or isinstance(output, tuple):
+        return all(_check_if_output_has_backwards(x) for x in output) and len(output) > 0
+    else:
+        return False
+
+
+def _compute_loss(output):
+    if isinstance(output, torch.Tensor):
+        return output.sum()
+    elif isinstance(output, list) or isinstance(output, tuple):
+        return sum(_compute_loss(x) for x in output)
+    else:
+        raise ValueError(f"Unsupported type: {type(output)}")
+
+
+def eval_correctness_test(op, impl, test, check_backwards=False) -> CorrectnessTestResult:
     """Evaluate impl of op against test.
 
     Returns:
         Tuple of (is_correct, error_message, absolute_error, relative_error)
     """
+
+    # Get the test_backwards flag from the test object if it exists.
+    # The suite is responsible for setting this based on op capabilities.
+    test_backwards = getattr(test, "test_backwards", False)
+
+    # Combine with the global check_backwards flag
+    check_backwards = check_backwards and test_backwards
+
     args, kwargs = test.args, test.kwargs
     ref = op(*args, **kwargs)
+
+    # Narrow check_backwards further: ref must be something we can backpropagate through,
+    # i.e. a torch.Tensor (or a collection of torch.Tensors) that carries gradients;
+    # otherwise a backwards pass cannot be performed.
+    backwards_possible = _check_if_output_has_backwards(ref)
+
+    check_backwards = backwards_possible and check_backwards
+    if check_backwards:
+        loss = _compute_loss(ref)
+        loss.backward()
+        ref_grads = collect_gradients(args, kwargs)
+        clear_gradients(args, kwargs)
+    else:
+        ref_grads = None
+
     try:
         res = impl(*args, **kwargs)
+        if check_backwards:
+            loss = _compute_loss(res)
+            loss.backward()
+            res_grads = collect_gradients(args, kwargs)
+            clear_gradients(args, kwargs)
+            has_correct_gradients = compare_gradients(ref_grads, res_grads)
+        else:
+            res_grads = None
+            has_correct_gradients = False
         is_correct = allclose(ref, res)
         abs_error, rel_error = compute_errors(ref, res)
+        if check_backwards and not has_correct_gradients:
+            raise ValueError(
+                f"Gradients are not correct for {op.__name__} with args {serialize_args(args, kwargs)}"
+            )
         result = CorrectnessTestResult(
             op_name=op.__name__,
             args=serialize_args(args, kwargs),
             is_correct=is_correct,
             max_abs_error=abs_error,
             max_rel_error=rel_error,
+            has_correct_gradients=has_correct_gradients,
+            checked_backwards=check_backwards,
         )
         return result
     except Exception as e:
@@ -125,14 +196,16 @@ def eval_correctness_test(op, impl, test) -> CorrectnessTestResult:
         return result
 
 
-def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult]]:
+def eval_correctness(
+    op, impl, tests, check_backwards=False
+) -> Tuple[float, List[CorrectnessTestResult]]:
     """Evaluate correctness of impl against tests."""
     correct, total = 0, 0
     test_results: List[CorrectnessTestResult] = []
     for test in tests:
         args_str = serialize_args(test.args, test.kwargs)
         logging.debug(f"Testing {op.__name__} with args {args_str}")
-        result = eval_correctness_test(op, impl, test)
+        result = eval_correctness_test(op, impl, test, check_backwards)
         test_results.append(result)
         if result.is_correct:
             correct += 1
@@ -148,7 +221,6 @@ def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult
 
 def cpu_bench(fn, num_runs=100):
     """Simple CPU benchmarking using time.perf_counter."""
-    import time
 
     for _ in range(10):
         fn()
@@ -164,6 +236,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
     bench_fn = (
         triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
     )
+
     base_times = []
     test_times = []
     args_strs = []
@@ -176,6 +249,12 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
         args_str = serialize_args(cached_args, cached_kwargs)
         args_strs.append(args_str)
         logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
+        # Warmup: run both operations to compile CUDA kernels and warm up caches
+        for _ in range(25):
+            _ = op(*cached_args, **cached_kwargs)
+            _ = impl(*cached_args, **cached_kwargs)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         base_time = bench_fn(lambda: op(*cached_args, **cached_kwargs))
         base_times.append(base_time)
         # Note: If the test fails we consider the speedup to be 1.0
@@ -225,7 +304,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
 
 
 def eval_one_op(
-    op, impl, correctness_tests, performance_tests
+    op, impl, correctness_tests, performance_tests, check_backwards=False
 ) -> Tuple[float, float, List[CorrectnessTestResult], List[PerformanceTestResult]]:
     """Evaluate impl of op against correctness_tests and performance_tests.
@@ -261,7 +340,9 @@ def eval_one_op(
         )
         return 0, 1.0, correctness_results, performance_results
 
-    correctness_score, correctness_results = eval_correctness(op, impl, correctness_tests)
+    correctness_score, correctness_results = eval_correctness(
+        op, impl, correctness_tests, check_backwards
+    )
     performance_score, performance_results = eval_performance(op, impl, performance_tests)
     return (
         correctness_score,
diff --git a/BackendBench/multiprocessing_eval.py b/BackendBench/multiprocessing_eval.py
index 09f86116..c24d36a4 100644
--- a/BackendBench/multiprocessing_eval.py
+++ b/BackendBench/multiprocessing_eval.py
@@ -43,6 +43,7 @@ class EvalTask:
     correctness_tests: List[Any]
     performance_tests: List[Any]
     device: str
+    check_backwards: bool = False
 
 
 @dataclass
@@ -116,7 +117,13 @@ def test_to_device_iterator(tests, device):
                 performance_score,
                 correctness_results,
                 performance_results,
-            ) = eval_one_op(op, impl, correctness_tests, performance_tests)
+            ) = eval_one_op(
+                op,
+                impl,
+                correctness_tests,
+                performance_tests,
+                check_backwards=task.check_backwards,
+            )
             result = EvalResult(
                 task_id=task.task_id,
                 correctness_score=correctness_score,
@@ -239,7 +246,9 @@ def __init__(self, num_workers: int = 1):
 
         logger.info(f"Initialized MultiprocessingEvaluator with {num_workers} workers")
 
-    def submit_task(self, op, impl, correctness_tests, performance_tests) -> int:
+    def submit_task(
+        self, op, impl, correctness_tests, performance_tests, check_backwards=False
+    ) -> int:
         task_id = self.next_task_id
         self.next_task_id += 1
         if not is_pickleable(op):
@@ -276,6 +285,7 @@ def submit_task(self, op, impl, correctness_tests, performance_tests) -> int:
             correctness_tests=cpu_correctness_tests,
             performance_tests=cpu_performance_tests,
             device=str(orig_device),
+            check_backwards=check_backwards,
         )
 
         self.task_queue.put(task)
diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py
index 3039e11a..51960221 100644
--- a/BackendBench/scripts/main.py
+++ b/BackendBench/scripts/main.py
@@ -148,6 +148,12 @@ def setup_logging(log_level):
     type=click.Choice(["triton", "pytorch", "cutedsl"]),
     help="Which DSL to use for LLM backend",
 )
+@click.option(
+    "--check-backwards",
+    default=False,
+    is_flag=True,
+    help="Check gradients of the result and reference",
+)
 def cli(
     log_level,
     suite,
@@ -166,6 +172,7 @@ def cli(
     check_overhead_dominated_ops,
     p,
     dsl,
+    check_backwards,
 ):
     if suite != "torchbench":
         if topn_inputs is not None:
@@ -184,6 +191,7 @@ def cli(
             "cuda",
             torch.bfloat16,
             filter=ops,
+            check_backwards=check_backwards,
         ),
         "torchbench": lambda: TorchBenchTestSuite(
             "torchbench",
@@ -191,6 +199,7 @@ def cli(
             filter=ops,
             topn=topn_inputs,
             check_overhead_dominated_ops=check_overhead_dominated_ops,
+            check_backwards=check_backwards,
         ),
         "facto": lambda: FactoTestSuite(
             "facto_cuda_bfloat16",
@@ -231,6 +240,11 @@ def cli(
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     log_dir = f"backendbench_output_{timestamp}"
 
+    if check_backwards:
+        assert backend_name == "directory" or backend_name == "aten", (
+            "check-backwards is only supported for directory backend or aten backend (for smoketests)"
+        )
+
     overall_correctness = []
     overall_performance = []
     all_correctness_results = []
@@ -248,6 +262,7 @@ def cli(
                 backend[test.op],
                 test.correctness_tests,
                 test.performance_tests,
+                check_backwards=check_backwards,
             )
 
             overall_correctness.append(all(result.is_correct for result in correctness_results))
@@ -270,6 +285,7 @@ def cli(
                 backend[test.op],
                 test.correctness_tests,
                 test.performance_tests,
+                check_backwards=check_backwards,
             )
 
     # Start evaluation
@@ -299,6 +315,23 @@ def cli(
         f"perf@p score (rate of correct samples with a speedup greater than p, p={p}): {perf_at_p_score:.2f}"
     )
 
+    if check_backwards:
+        backwards_correctness = (
+            torch.tensor(
+                [
+                    result.has_correct_gradients
+                    for result in all_correctness_results
+                    if result.checked_backwards
+                ]
+            )
+            .float()
+            .mean()
+            .item()
+        )
+        print(
+            f"backwards correctness score (mean pass rate over all operators which support backwards): {backwards_correctness:.2f}"
+        )
+
     command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:])
 
     # Save results if not disabled
diff --git a/BackendBench/suite/base.py b/BackendBench/suite/base.py
index 1f5fe635..71ec4800 100644
--- a/BackendBench/suite/base.py
+++ b/BackendBench/suite/base.py
@@ -6,9 +6,10 @@
 
 
 class Test:
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, test_backwards=False, **kwargs):
         self._args = args
         self._kwargs = kwargs
+        self.test_backwards = test_backwards
 
     @property
     def args(self):
diff --git a/BackendBench/suite/opinfo.py b/BackendBench/suite/opinfo.py
index 611c5063..bef3d3c6 100644
--- a/BackendBench/suite/opinfo.py
+++ b/BackendBench/suite/opinfo.py
@@ -10,6 +10,10 @@
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.utils._python_dispatch import TorchDispatchMode
 
+from BackendBench.backwards_utils import (
+    make_tensors_require_gradients,
+    should_check_backwards_for_op,
+)
 from BackendBench.eval import allclose
 
 from .base import OpTest, TestSuite
@@ -18,24 +22,33 @@
 
 
 class OpInfoTest:
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, test_backwards=False, **kwargs):
         self.args = args
         self.kwargs = kwargs
+        self.test_backwards = test_backwards
 
 
 class OpInfoOpTest(OpTest):
-    def __init__(self, op, correctness_tests, indices):
+    def __init__(self, op, correctness_tests, indices, check_backwards=False):
         self.op = op
         self._correctness_tests = correctness_tests
         self.indices = set(indices)
         self.performance_tests = []
+        self._check_backwards = check_backwards
 
     @property
     def correctness_tests(self):
+        # Determine if this op should check backwards
+        test_backwards = should_check_backwards_for_op(self.op.__name__, self._check_backwards)
+
        for idx, test in enumerate(self._correctness_tests):
             if idx in self.indices:
                 # print(f"{idx} {test.input=} {test.args=} {test.kwargs=}")
-                yield OpInfoTest(test.input, *test.args, **test.kwargs)
+                if test_backwards:
+                    make_tensors_require_gradients(test.args, test.kwargs)
+                yield OpInfoTest(
+                    test.input, *test.args, test_backwards=test_backwards, **test.kwargs
+                )
 
 
 class OpTracerMode(TorchDispatchMode):
@@ -48,10 +61,11 @@ def __torch_dispatch__(self, fn, types, args=(), kwargs={}):
         self.ops.append(fn)
         self.args.append(args)
         self.kwargs.append(kwargs)
+        return fn(*args, **kwargs)
 
 
-def build_op_tests(device, dtype, filter=None):
+def build_op_tests(device, dtype, filter=None, check_backwards=False):
     op_info_op_tests = []
     for op in op_db:
         if filter and op.name not in filter:
@@ -85,11 +99,13 @@ def build_op_tests(device, dtype, filter=None):
 
     for overload, indices in op_indices.items():
         if len(indices) > 0:
-            op_info_op_tests.append(OpInfoOpTest(overload, sample_inputs, indices))
+            op_info_op_tests.append(
+                OpInfoOpTest(overload, sample_inputs, indices, check_backwards)
+            )
 
     return op_info_op_tests
 
 
 class OpInfoTestSuite(TestSuite):
-    def __init__(self, name, device, dtype, filter=None):
-        super().__init__(name, build_op_tests(device, dtype, filter))
+    def __init__(self, name, device, dtype, filter=None, check_backwards=False):
+        super().__init__(name, build_op_tests(device, dtype, filter, check_backwards))
diff --git a/BackendBench/suite/torchbench.py b/BackendBench/suite/torchbench.py
index 2ee3d698..c45a0166 100644
--- a/BackendBench/suite/torchbench.py
+++ b/BackendBench/suite/torchbench.py
@@ -27,6 +27,10 @@
 import torch  # noqa: F401
 
+from BackendBench.backwards_utils import (
+    make_tensors_require_gradients,
+    should_check_backwards_for_op,
+)
 from BackendBench.data_loaders import (
     _args_size,
     load_ops_from_source,
@@ -37,16 +41,18 @@
 
 
 class TorchBenchTest:
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, test_backwards=False, **kwargs):
         self.args = args
         self.kwargs = kwargs
+        self.test_backwards = test_backwards
 
 
 class TorchBenchOpTest:
-    def __init__(self, op, inputs, topn):
+    def __init__(self, op, inputs, topn, check_backwards=False):
         self.op = eval(f"torch.ops.{op}")
         self.inputs = inputs
         self.topn = topn
+        self._check_backwards = check_backwards
 
     def tests(self):
         inputs_and_sizes = []
@@ -59,9 +65,14 @@ def tests(self):
 
     @property
     def correctness_tests(self):
+        # Determine if this op should check backwards
+        test_backwards = should_check_backwards_for_op(self.op.__name__, self._check_backwards)
+
         for inp in self.tests():
             args, kwargs = deserialize_args(inp)
-            yield TorchBenchTest(*args, **kwargs)
+            if test_backwards:
+                make_tensors_require_gradients(args, kwargs)
+            yield TorchBenchTest(*args, test_backwards=test_backwards, **kwargs)
 
     @property
     def performance_tests(self):
@@ -78,9 +89,11 @@ def __init__(
         filter=None,
         topn=None,
         check_overhead_dominated_ops=False,
+        check_backwards=False,
     ):
         self.name = name
         self.topn = topn
+        self.check_backwards = check_backwards
         # Load operations using the shared data loader
         ops_list = load_ops_from_source(
             source=filename,
@@ -102,4 +115,4 @@ def __iter__(self):
         for op, inputs in self.optests.items():
             if any(s in op for s in UNSUPPORTED_OPERATORS):
                 continue
-            yield TorchBenchOpTest(op, inputs, self.topn)
+            yield TorchBenchOpTest(op, inputs, self.topn, self.check_backwards)
diff --git a/test/test_gradient_checks.py b/test/test_gradient_checks.py
new file mode 100644
index 00000000..7cea2183
--- /dev/null
+++ b/test/test_gradient_checks.py
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from BackendBench.backwards_utils import make_tensors_require_gradients
+from BackendBench.eval import (
+    _check_if_output_has_backwards,
+    clear_gradients,
+    collect_gradients,
+    eval_correctness_test,
+)
+
+
+class TestCollectGradients:
+    """Test the collect_gradients function."""
+
+    def test_collect_gradients_single_tensor(self):
+        """Test collecting gradients from a single tensor."""
+        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+        y = x.sum()
+        y.backward()
+
+        grads = collect_gradients([x], {})
+        assert len(grads) == 1
+        assert torch.allclose(grads[0], torch.ones(3))
+
+    def test_collect_gradients_multiple_tensors(self):
+        """Test collecting gradients from multiple tensors."""
+        x = torch.tensor([1.0, 2.0], requires_grad=True)
+        y = torch.tensor([3.0, 4.0], requires_grad=True)
+        z = x.sum() + y.sum()
+        z.backward()
+
+        grads = collect_gradients([x, y], {})
+        assert len(grads) == 2
+        assert torch.allclose(grads[0], torch.ones(2))
+        assert torch.allclose(grads[1], torch.ones(2))
+
+    def test_collect_gradients_nested_list(self):
+        """Test collecting gradients from nested lists."""
+        x = torch.tensor([1.0], requires_grad=True)
+        y = torch.tensor([2.0], requires_grad=True)
+        z = torch.tensor([3.0], requires_grad=True)
+        loss = (x + y + z).sum()
+        loss.backward()
+
+        grads = collect_gradients([[x, y], z], {})
+        assert len(grads) == 3
+        for grad in grads:
+            assert torch.allclose(grad, torch.ones(1))
+
+    def test_collect_gradients_no_grad(self):
+        """Test collecting when tensors have no gradients."""
+        x = torch.tensor([1.0, 2.0], requires_grad=True)
+        y = torch.tensor([3.0, 4.0], requires_grad=True)
+        # No backward call, so no gradients
+
+        grads = collect_gradients([x, y], {})
+        assert len(grads) == 2
+        assert grads[0] is None
+        assert grads[1] is None
+
+
+class TestMakeTensorsRequireGradients:
+    """Test the make_tensors_require_gradients function."""
+
+    def test_make_tensors_require_grad(self):
+        """Test that integer tensors don't get requires_grad."""
+        x = torch.tensor([1, 2, 3])  # int tensor
+        y = torch.tensor([1.0, 2.0, 3.0])  # float tensor
+
+        make_tensors_require_gradients([x, y], {})
+
+        assert not x.requires_grad  # int tensors can't require grad
+        assert y.requires_grad
+
+
+class TestClearGradients:
+    """Test the clear_gradients function."""
+
+    def test_clear_gradients_single(self):
+        """Test clearing gradient from single tensor."""
+        x = torch.tensor([1.0, 2.0], requires_grad=True)
+        y = x.sum()
+        y.backward()
+
+        assert x.grad is not None
+
+        clear_gradients([x], {})
+
+        assert x.grad is None
+
+    def test_clear_gradients_no_grad(self):
+        """Test clearing when there are no gradients."""
+        x = torch.tensor([1.0], requires_grad=True)
+
+        # Should not raise error
+        clear_gradients([x], {})
+
+        assert x.grad is None
+
+
+class TestCheckIfOutputHasBackwards:
+    """Test the _check_if_output_has_backwards function."""
+
+    def test_check_tensor_with_grad_fn(self):
+        """Test tensor with grad_fn."""
+        x = torch.tensor([1.0], requires_grad=True)
+        y = x * 2
+
+        assert _check_if_output_has_backwards(y)
+
+    def test_check_tensor_without_grad_fn(self):
+        """Test tensor without grad_fn."""
+        x = torch.tensor([1.0], requires_grad=False)
+
+        assert not _check_if_output_has_backwards(x)
+
+
+class TestEvalCorrectnessWithBackwards:
+    """Integration tests for eval_correctness_test with backwards checking."""
+
+    def test_eval_correctness_without_backwards(self):
+        """Test correctness evaluation without backwards checking."""
+        op = torch.ops.aten.relu.default
+        impl = torch.ops.aten.relu.default
+
+        class TestCase:
+            def __init__(self, args, kwargs):
+                self.args = args
+                self.kwargs = kwargs
+                self.test_backwards = False
+
+        test = TestCase([torch.tensor([-1.0, 0.0, 1.0])], {})
+
+        result = eval_correctness_test(op, impl, test, check_backwards=False)
+
+        assert result.is_correct
+        assert not result.checked_backwards
+        assert not result.has_correct_gradients
+
+    def test_eval_correctness_backwards(self):
+        """Test backwards checking with multiple inputs."""
+        op = torch.ops.aten.add.Tensor
+        impl = torch.ops.aten.add.Tensor
+
+        class TestCase:
+            def __init__(self, args, kwargs):
+                self.args = args
+                self.kwargs = kwargs
+                self.test_backwards = True
+
+        test = TestCase([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], {})
+
+        make_tensors_require_gradients(test.args, test.kwargs)
+
+        result = eval_correctness_test(op, impl, test, check_backwards=True)
+
+        assert result.is_correct
+        assert result.checked_backwards
+        assert result.has_correct_gradients
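---
Reviewer note (illustration only, not part of the patch): a minimal sketch of how the new helpers in BackendBench/backwards_utils.py are meant to be combined when gradient checking is enabled, mirroring what eval_correctness_test does. The relu op and the random input shape below are arbitrary choices for the example.

    import torch

    from BackendBench.backwards_utils import (
        clear_gradients,
        collect_gradients,
        make_tensors_require_gradients,
    )

    # Mark floating-point/complex inputs as requiring grad, run the op, reduce the
    # output to a scalar loss, backpropagate, then harvest and reset input gradients.
    args, kwargs = [torch.randn(4, 4)], {}
    make_tensors_require_gradients(args, kwargs)

    out = torch.ops.aten.relu.default(*args, **kwargs)
    out.sum().backward()

    ref_grads = collect_gradients(args, kwargs)  # one entry per tensor input (None if no grad)
    clear_gradients(args, kwargs)                # reset .grad before the backend impl is evaluated
    print([g.shape for g in ref_grads])

The same sequence runs twice inside eval_correctness_test (once for the reference op and once for the backend implementation), and the two gradient lists are then compared with compare_gradients.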