173 changes: 173 additions & 0 deletions BackendBench/backwards_utils.py
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for backwards pass checking and gradient verification.
"""

from typing import List

import torch

from BackendBench.scripts.op_map import query

# Operations that should be exempted from backwards pass testing
BACKWARDS_PASS_TESTING_EXEMPTIONS = [
    # We skip this op for two reasons:
    # 1) Its args are (shape, stride, storage_offset), and the storage offset
    #    would change once a gradient is attached to the tensor. Our suites
    #    (e.g. OpInfo) assume inference, so the storage offset captured here is
    #    invalid for backwards testing; we'd need a custom suite for this.
    # 2) As a tensor-manipulation op, it doesn't really make sense to test a
    #    backwards pass for it yet.
"as_strided.default",
# The function <op_name> is not differentiable with respect to argument 'running_mean'.
# This input cannot have requires_grad True.
# We likely need to handle this on the suite level.
"native_batch_norm.default",
"_native_batch_norm_legit.default",
"_batch_norm_with_update.default",
"native_batch_norm_backward.default", # in torchbench only
# The function 'soft_margin_loss' is not differentiable with respect to argument 'target'.
# This input cannot have requires_grad True.
"soft_margin_loss.default",
# The function 'multi_margin_loss' is not differentiable with respect to argument 'weight'.
# This input cannot have requires_grad True.
"multi_margin_loss.default",
    # This op has no derivative unless one is defined explicitly, and there is
    # no good way to detect that it has none.
"nextafter.default",
# This is the only op that does not pass opinfo + aten on backwards passes
# TODO: figure out why
"grid_sampler_2d.default",
    # torchbench: hits an IMA (illegal memory access) error when the gradient is added on B200
"max_pool2d_with_indices_backward.default",
]


def should_check_backwards_for_op(op_name: str, check_backwards: bool = True) -> bool:
"""
Determine if backwards checking should be performed for a given operation.

Args:
op_name: The name of the operation (e.g., "aten.relu.default")
check_backwards: Whether backwards checking is globally enabled

Returns:
True if backwards checking should be performed, False otherwise
"""
if not check_backwards:
return False

# Check if op is in the exemption list
    if op_name in BACKWARDS_PASS_TESTING_EXEMPTIONS:
return False

# Check if op is inplace (inplace ops are not supported for backwards checking)
op_map_entries = query(op_name)
if len(op_map_entries) == 1 and op_map_entries[0].get("is_inplace", False):
return False

return True
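
A quick usage sketch (illustrative only; that op_map marks "relu_.default" as in-place is an assumption, not something this diff shows):

    # should_check_backwards_for_op("as_strided.default")   -> False (exempted above)
    # should_check_backwards_for_op("relu_.default")        -> False (assumed in-place in op_map)
    # should_check_backwards_for_op("relu.default")         -> True
    # should_check_backwards_for_op("relu.default", check_backwards=False)  -> False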


def _apply_to_tensors(obj, tensor_fn, container_fn=None, accumulator=None):
"""
Generic functor to apply operations to tensors in nested data structures.

Args:
obj: The object to traverse (tensor, list, tuple, dict, or other)
tensor_fn: Function to apply to each tensor. Should have signature (tensor, accumulator) -> Any
container_fn: Optional function to handle container reconstruction.
Signature: (container_type, transformed_items) -> Any
accumulator: Optional accumulator object passed to tensor_fn

Returns:
Transformed object or None for in-place operations
"""
if isinstance(obj, torch.Tensor):
return tensor_fn(obj, accumulator)
elif isinstance(obj, list):
transformed = [
_apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj
]
return container_fn(list, transformed) if container_fn else transformed
elif isinstance(obj, tuple):
transformed = [
_apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj
]
return container_fn(tuple, transformed) if container_fn else tuple(transformed)
elif isinstance(obj, dict):
transformed = {
key: _apply_to_tensors(value, tensor_fn, container_fn, accumulator)
for key, value in obj.items()
}
return container_fn(dict, transformed) if container_fn else transformed
else:
# For immutable types or unknown types
return obj
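
To make the functor contract concrete, here is a minimal sketch (illustrative, not part of the diff) that accumulates tensor shapes the same way collect_gradients below accumulates gradients:

    # Collect the shape of every tensor in a nested structure.
    shapes = []

    def record_shape(tensor, accumulator):
        accumulator.append(tuple(tensor.shape))

    nested = ([torch.ones(2, 3)], {"bias": torch.zeros(4)})
    _apply_to_tensors(nested, record_shape, accumulator=shapes)
    # shapes == [(2, 3), (4,)]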


def collect_gradients(args, kwargs) -> List[torch.Tensor]:
"""
Collect all gradients from args and kwargs into a flat list.

Order is well-defined:
1. Iterate through args in order
- If arg is a tensor with grad, append grad
- If arg is a list/tuple, iterate through elements in order and append tensor grads
2. Iterate through kwargs in sorted key order
- If kwarg is a tensor with grad, append grad
- If kwarg is a list/tuple, iterate through elements in order and append tensor grads

Args:
args: The arguments (can contain tensors or lists/tuples of tensors).
kwargs: The keyword arguments (can contain tensors or lists/tuples of tensors).

Returns:
List of gradients (torch.Tensor) in the order specified above.
Returns empty list if no gradients are found.
"""
gradients = []

def collect_grad_fn(tensor, accumulator):
accumulator.append(tensor.grad)

# Collect from args
for arg in args:
_apply_to_tensors(arg, collect_grad_fn, accumulator=gradients)

# Collect from kwargs in sorted key order for deterministic ordering
for key in sorted(kwargs.keys()):
_apply_to_tensors(kwargs[key], collect_grad_fn, accumulator=gradients)

return gradients
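
For example (illustrative names), the ordering guarantee means:

    # args = (t1,), kwargs = {"b": t3, "a": t2}
    # collect_gradients(args, kwargs) == [t1.grad, t2.grad, t3.grad]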


def make_tensors_require_gradients(args, kwargs):
def make_require_grad_fn(tensor, _):
        # Only floating-point and complex dtypes can require gradients
if tensor.dtype not in [
torch.float32,
torch.float64,
torch.float16,
torch.bfloat16,
torch.complex64,
torch.complex128,
]:
return
tensor.requires_grad = True

_apply_to_tensors(args, make_require_grad_fn)
_apply_to_tensors(kwargs, make_require_grad_fn)


def clear_gradients(args, kwargs):
def clear_grad_fn(tensor, _):
if tensor.grad is not None:
tensor.grad = None

_apply_to_tensors(args, clear_grad_fn)
_apply_to_tensors(kwargs, clear_grad_fn)
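
Taken together, these helpers support the gradient comparison in eval.py below. A minimal sketch of the intended flow (torch.relu stands in for the op under test; float leaf inputs assumed):

    args, kwargs = (torch.randn(3),), {}
    make_tensors_require_gradients(args, kwargs)

    out = torch.relu(*args, **kwargs)
    out.sum().backward()                   # scalar loss, mirroring _compute_loss in eval.py
    grads = collect_gradients(args, kwargs)
    clear_gradients(args, kwargs)          # reset .grad before the next run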
93 changes: 87 additions & 6 deletions BackendBench/eval.py
@@ -6,12 +6,17 @@

import logging
import math
import time
import traceback
from dataclasses import dataclass
from typing import List, Tuple

import torch

from BackendBench.backwards_utils import (
clear_gradients,
collect_gradients,
)
from BackendBench.utils import compute_errors, serialize_args, uses_cuda_stream


@@ -26,6 +31,8 @@ class CorrectnessTestResult:
max_abs_error: float = -math.inf
max_rel_error: float = -math.inf
test_type: str = "correctness"
has_correct_gradients: bool = False
checked_backwards: bool = False


@dataclass
@@ -90,25 +97,89 @@ def allclose(a, b, atol=1e-2, rtol=1e-2):
return False


def eval_correctness_test(op, impl, test) -> CorrectnessTestResult:
def compare_gradients(res_grad, ref_grad, atol=1e-2, rtol=1e-2):
if res_grad is None and ref_grad is None:
return True
if res_grad is None or ref_grad is None:
raise ValueError("One of the gradients is None while the other is not.")
return allclose(res_grad, ref_grad, atol=atol, rtol=rtol)
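The contract here: two absent gradients match, mismatched presence is an error, and otherwise the values are compared numerically:

    # compare_gradients(None, None)      -> True
    # compare_gradients(g, None)         -> raises ValueError
    # compare_gradients(g_res, g_ref)    -> allclose(g_res, g_ref, atol, rtol)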


def _check_if_output_has_backwards(output):
    if isinstance(output, torch.Tensor):
        # TODO: figure out why this string check is needed and why
        # isinstance(output.grad_fn, NotImplementedType) does not work for
        # outputs of ops with no derivative (e.g. floor_divide.default).
        has_grad_fn = type(output.grad_fn).__name__ != "NotImplemented"
        return output.requires_grad and has_grad_fn
    elif isinstance(output, (list, tuple)):
        return len(output) > 0 and all(_check_if_output_has_backwards(x) for x in output)
    else:
        return False


def _compute_loss(output):
    if isinstance(output, torch.Tensor):
        return output.sum()
    elif isinstance(output, (list, tuple)):
        return sum(_compute_loss(x) for x in output)
    else:
        raise ValueError(f"Unsupported type: {type(output)}")


def eval_correctness_test(op, impl, test, check_backwards=False) -> CorrectnessTestResult:
"""Evaluate impl of op against test.

    Returns:
        A CorrectnessTestResult with correctness status, error metrics, and
        (when backwards checking ran) gradient-comparison results.
"""

# Get the test_backwards flag from the test object if it exists
# The suite is responsible for setting this based on op capabilities
test_backwards = getattr(test, "test_backwards", False)

# Combine with global check_backwards flag
check_backwards = check_backwards and test_backwards

args, kwargs = test.args, test.kwargs
ref = op(*args, **kwargs)

    # Refine check_backwards with one more condition: ref must actually support
    # a backwards pass, i.e. the op returns a torch.Tensor (or a collection of
    # torch.Tensors) with requires_grad and a grad_fn; otherwise we cannot call
    # backward().
backwards_possible = _check_if_output_has_backwards(ref)

check_backwards = backwards_possible and check_backwards
if check_backwards:
loss = _compute_loss(ref)
loss.backward()
ref_grads = collect_gradients(args, kwargs)
clear_gradients(args, kwargs)
else:
ref_grads = None

try:
res = impl(*args, **kwargs)
if check_backwards:
loss = _compute_loss(res)
loss.backward()
res_grads = collect_gradients(args, kwargs)
clear_gradients(args, kwargs)
has_correct_gradients = compare_gradients(ref_grads, res_grads)
else:
res_grads = None
has_correct_gradients = False
is_correct = allclose(ref, res)

abs_error, rel_error = compute_errors(ref, res)
if check_backwards and not has_correct_gradients:
raise ValueError(
f"Gradients are not correct for {op.__name__} with args {serialize_args(args, kwargs)}"
)
result = CorrectnessTestResult(
op_name=op.__name__,
args=serialize_args(args, kwargs),
is_correct=is_correct,
max_abs_error=abs_error,
max_rel_error=rel_error,
has_correct_gradients=has_correct_gradients,
checked_backwards=check_backwards,
)
return result
except Exception as e:
@@ -125,14 +196,16 @@ def eval_correctness_test(op, impl, test) -> CorrectnessTestResult:
return result


def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult]]:
def eval_correctness(
op, impl, tests, check_backwards=False
) -> Tuple[float, List[CorrectnessTestResult]]:
"""Evaluate correctness of impl against tests."""
correct, total = 0, 0
test_results: List[CorrectnessTestResult] = []
for test in tests:
args_str = serialize_args(test.args, test.kwargs)
logging.debug(f"Testing {op.__name__} with args {args_str}")
result = eval_correctness_test(op, impl, test)
result = eval_correctness_test(op, impl, test, check_backwards)
test_results.append(result)
if result.is_correct:
correct += 1
@@ -148,7 +221,6 @@ def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult

def cpu_bench(fn, num_runs=100):
"""Simple CPU benchmarking using time.perf_counter."""
import time

for _ in range(10):
fn()
@@ -164,6 +236,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
bench_fn = (
triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
)

base_times = []
test_times = []
args_strs = []
@@ -176,6 +249,12 @@
args_str = serialize_args(cached_args, cached_kwargs)
args_strs.append(args_str)
logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
# Warmup: run both operations to compile CUDA kernels and warm up caches
for _ in range(25):
_ = op(*cached_args, **cached_kwargs)
_ = impl(*cached_args, **cached_kwargs)
if torch.cuda.is_available():
torch.cuda.synchronize()
base_time = bench_fn(lambda: op(*cached_args, **cached_kwargs))
base_times.append(base_time)
# Note: If the test fails we consider the speedup to be 1.0
@@ -225,7 +304,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult


def eval_one_op(
op, impl, correctness_tests, performance_tests
op, impl, correctness_tests, performance_tests, check_backwards=False
) -> Tuple[float, float, List[CorrectnessTestResult], List[PerformanceTestResult]]:
"""Evaluate impl of op against correctness_tests and performance_tests.

@@ -261,7 +340,9 @@
)
return 0, 1.0, correctness_results, performance_results

correctness_score, correctness_results = eval_correctness(op, impl, correctness_tests)
correctness_score, correctness_results = eval_correctness(
op, impl, correctness_tests, check_backwards
)
performance_score, performance_results = eval_performance(op, impl, performance_tests)
return (
correctness_score,