
Commit 0ef64c6

[Model Suite] Add model correctness testing
This PR adds end-to-end model correctness testing to the model suite by comparing outputs and gradients (after a backward pass) over one iteration of the model. We also integrate it into CI.

### Testing

Running `uv run python BackendBench/scripts/main.py --suite model --backend directory` with a working mm kernel and a watermarked kernel for everything else yields:

```bash
[2025-10-02 07:16:13][INFO][main.py] ============================================================
[2025-10-02 07:16:13][INFO][main.py] MODEL EVALUATION RESULTS
[2025-10-02 07:16:13][INFO][main.py] ============================================================
[2025-10-02 07:16:13][INFO][model.py] Model: ToyCoreOpsModel
[2025-10-02 07:16:13][INFO][model.py] Status: ✗ Failed (0/3 tests)
[2025-10-02 07:16:13][INFO][model.py]   ✗ small_batch
[2025-10-02 07:16:13][INFO][model.py]     Error: Model ToyCoreOpsModel::small_batch failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [2, 3, 32, 32] and num_groups=8
[2025-10-02 07:16:13][INFO][model.py]   ✗ medium_batch
[2025-10-02 07:16:13][INFO][model.py]     Error: Model ToyCoreOpsModel::medium_batch failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [4, 3, 64, 64] and num_groups=8
[2025-10-02 07:16:13][INFO][model.py]   ✗ large_input
[2025-10-02 07:16:13][INFO][model.py]     Error: Model ToyCoreOpsModel::large_input failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [2, 3, 128, 128] and num_groups=8
[2025-10-02 07:16:13][INFO][model.py] Model: SmokeTestModel
[2025-10-02 07:16:13][INFO][model.py] Status: ✓ Passed (3/3 tests)
[2025-10-02 07:16:13][INFO][model.py]   ✓ small_batch
[2025-10-02 07:16:13][INFO][model.py]     Output match: ✓ Gradients match: ✓ (4 gradients)
[2025-10-02 07:16:13][INFO][model.py]   ✓ medium_batch
[2025-10-02 07:16:13][INFO][model.py]     Output match: ✓ Gradients match: ✓ (4 gradients)
[2025-10-02 07:16:13][INFO][model.py]   ✓ large_batch
[2025-10-02 07:16:13][INFO][model.py]     Output match: ✓ Gradients match: ✓ (4 gradients)
[2025-10-02 07:16:13][INFO][main.py] ============================================================
```

### Future work with Model Suite

#181
1 parent 07fca91 commit 0ef64c6

3 files changed: +356 −2 lines

BackendBench/eval_model.py

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Model-level evaluation utilities for testing full model correctness."""

import logging
import random
import traceback
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch

import BackendBench
from BackendBench.eval import allclose
from BackendBench.utils import deserialize_args

logger = logging.getLogger(__name__)


@dataclass
class ModelCorrectnessTestResult:
    """Result from testing a model configuration."""

    model_name: str
    test_name: str
    is_correct: bool = False
    error_msg: str = ""
    error_type: str = ""
    traceback: str = ""
    output_match: bool = False
    gradients_match: bool = False
    num_gradients: int = 0


def eval_model_correctness_test(
    model_name: str,
    model_class: type,
    model_config: Dict[str, Any],
    test_name: str,
    test_args: str,
    kernel_dir: str = None,
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> ModelCorrectnessTestResult:
    """Evaluate model correctness by comparing eager vs backend execution.

    Similar to eval_correctness_test in eval.py, but for full models instead of individual ops.

    Args:
        model_name: Name of the model being tested
        model_class: Model class to instantiate
        model_config: Model configuration dict with init_args
        test_name: Name of this test configuration
        test_args: Serialized arguments string for forward pass
        kernel_dir: Optional directory containing kernels for backend
        atol: Absolute tolerance for allclose
        rtol: Relative tolerance for allclose

    Returns:
        ModelCorrectnessTestResult with detailed comparison results
    """
    try:
        # Generate a single seed to use for both eager and backend runs
        # This ensures both runs use the same model initialization
        seed = random.randint(0, 2**32 - 1)

        # Run in eager mode (reference)
        eager_out, eager_grads = _run_model(
            model_class,
            model_config,
            test_args,
            backend_enabled=False,
            kernel_dir=None,
            seed=seed,
        )

        # Run with backend (implementation)
        backend_out, backend_grads = _run_model(
            model_class,
            model_config,
            test_args,
            backend_enabled=True,
            kernel_dir=kernel_dir,
            seed=seed,
        )

        # Compare outputs
        output_match = allclose(eager_out, backend_out, atol=atol, rtol=rtol)

        # Compare gradients
        gradients_match = True
        if len(eager_grads) != len(backend_grads):
            gradients_match = False
        else:
            for eager_grad, backend_grad in zip(eager_grads, backend_grads):
                if not allclose(eager_grad, backend_grad, atol=atol, rtol=rtol):
                    gradients_match = False
                    break

        is_correct = output_match and gradients_match

        return ModelCorrectnessTestResult(
            model_name=model_name,
            test_name=test_name,
            is_correct=is_correct,
            output_match=output_match,
            gradients_match=gradients_match,
            num_gradients=len(eager_grads),
        )

    except Exception as e:
        error_msg = f"Model {model_name}::{test_name} failed: {e}"
        logger.error(error_msg)
        return ModelCorrectnessTestResult(
            model_name=model_name,
            test_name=test_name,
            is_correct=False,
            error_msg=error_msg,
            error_type=str(type(e)),
            traceback=traceback.format_exc(),
        )


def _move_model_to_input_device(
    model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
) -> torch.nn.Module:
    """Move model to the same device as input tensor.

    Args:
        model: Model to move
        args: Positional arguments list
        kwargs: Keyword arguments dict

    Returns:
        Model on input device (or original model if no input tensor found)
    """
    # this is specific to our configs atm, we should generalize this
    input_tensor = kwargs["x"]
    if input_tensor is not None:
        device = input_tensor.device
        model = model.to(device)
    return model


def _collect_gradients(
    model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
) -> List[torch.Tensor]:
    """Collect gradients from input and model parameters.

    Args:
        model: Model with computed gradients
        args: Positional arguments list
        kwargs: Keyword arguments dict

    Returns:
        List of gradient tensors [input_grad, param1_grad, ...]
    """
    grads = []

    # Input gradient - check both args and kwargs
    input_grad = None
    if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None:
        input_grad = args[0].grad
    elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None:
        input_grad = kwargs["x"].grad

    if input_grad is not None:
        grads.append(input_grad.clone())

    # Parameter gradients
    for param in model.parameters():
        if param.grad is not None:
            grads.append(param.grad.clone())

    return grads


def _run_model(
    model_class: type,
    model_config: Dict[str, Any],
    test_args: str,
    backend_enabled: bool,
    kernel_dir: str = "generated_kernels",
    seed: int = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Run model with or without backend enabled.

    Args:
        model_class: Model class to instantiate
        model_config: Model configuration dict with init_args
        test_args: Serialized arguments string for forward pass
        backend_enabled: If True, use BackendBench context manager
        kernel_dir: Optional directory containing kernels
        seed: Random seed for reproducibility. If None, generates a random seed.

    Returns:
        Tuple of (output, gradients) where:
        - output: Model output tensor (detached)
        - gradients: List of gradient tensors [input_grad, param1_grad, ...]
    """
    # Generate seed dynamically and set for deterministic behavior
    # IMPORTANT: Must set seed BEFORE deserializing args, because deserialization
    # may create random tensors!
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)

    # Deserialize test arguments (now uses the seed we just set)
    args, kwargs = deserialize_args(test_args)

    # Extract model initialization args
    init_args = model_config.get("init_args", {}).copy()

    # Create fresh model instance
    model = model_class(**init_args)
    model.train()

    # Move model to same device as input
    model = _move_model_to_input_device(model, args, kwargs)

    ctx = (
        BackendBench.BackendBench.enable(kernel_dir=kernel_dir)
        if backend_enabled
        else nullcontext()
    )

    # Run forward + backward with or without backend
    with ctx:
        output = model(*args, **kwargs)
        loss = output.sum()
        loss.backward()

    # Collect gradients
    grads = _collect_gradients(model, args, kwargs)

    return output.detach(), grads
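For readers skimming the diff, the following is a minimal, self-contained sketch of the seed-pinning comparison that `eval_model_correctness_test` and `_run_model` perform above. It uses a toy `torch.nn.Linear` and plain `torch.allclose` instead of a suite model and BackendBench's `allclose`, and the second run merely stands in for the backend-enabled run; everything in it is illustrative rather than part of this PR.

```python
import torch

def run_once(seed: int):
    # Pin the seed before building the model and inputs so both runs
    # start from identical weights and data (mirrors _run_model).
    torch.manual_seed(seed)
    model = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    out = model(x)
    out.sum().backward()
    grads = [x.grad.clone()] + [p.grad.clone() for p in model.parameters()]
    return out.detach(), grads

seed = 1234
ref_out, ref_grads = run_once(seed)    # eager reference run
test_out, test_grads = run_once(seed)  # in this PR, the second run is wrapped in the backend context

assert torch.allclose(ref_out, test_out, atol=1e-2, rtol=1e-2)
assert len(ref_grads) == len(test_grads)
assert all(torch.allclose(a, b, atol=1e-2, rtol=1e-2) for a, b in zip(ref_grads, test_grads))
```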

BackendBench/scripts/main.py

Lines changed: 20 additions & 2 deletions
@@ -41,6 +41,21 @@ def setup_logging(log_level):
     )
 
 
+# Helper function as model suite gets fleshed out
+def _test_full_models(suite, backend):
+    assert suite.name == "model"
+    all_results = []
+    for model in suite.models:
+        results = suite.eval_model(model, backend)
+        all_results.append(results)
+    logger.info("=" * 60)
+    logger.info("MODEL EVALUATION RESULTS")
+    logger.info("=" * 60)
+    for result in all_results:
+        suite.print_results(result)
+    logger.info("=" * 60)
+
+
 @click.command()
 @click.option(
     "--log-level",
@@ -179,8 +194,6 @@ def cli(
         raise ValueError(
             "--ops filter is not supported for model suite. Use --model-filter instead"
         )
-        # remove this in later PR as model suite is supported
-        raise NotImplementedError("Model suite is not supported yet")
 
     if suite != "model" and model_filter is not None:
         raise ValueError("--model-filter is only supported for model suite")
@@ -246,6 +259,11 @@ def cli(
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     log_dir = f"backendbench_output_{timestamp}"
 
+    if suite.name == "model":
+        _test_full_models(suite, backend)
+        # currently model suite does not support op testing so now we're done
+        return
+
     overall_correctness = []
     overall_performance = []
     all_correctness_results = []

BackendBench/suite/model.py

Lines changed: 95 additions & 0 deletions
@@ -14,6 +14,8 @@
 import os
 from typing import Any, Dict, List, Optional
 
+from BackendBench.eval_model import eval_model_correctness_test
+
 logger = logging.getLogger(__name__)
 
 
@@ -110,3 +112,96 @@ def __init__(
         # Store loaded models
         self.models = models
         self.name = name
+
+    def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]:
+        """Run evaluation on a single model.
+
+        Args:
+            model_dict: Dictionary with keys 'name', 'class', 'config'
+            backend: Backend to use for evaluation
+
+        Returns:
+            Dictionary with evaluation results including correctness and performance
+        """
+        model_class = model_dict["class"]
+        model_name = model_dict["name"]
+        config = model_dict["config"]
+
+        # Extract model configuration and tests
+        model_config = config.get("model_config", {})
+        model_tests = config.get("model_tests", {})
+
+        if not model_tests:
+            return {
+                "model_name": model_name,
+                "passed": False,
+                "error": "No model_tests found in config",
+                "test_results": [],
+            }
+
+        # Get kernel_dir from backend if available
+        kernel_dir = getattr(backend, "ops_dir", None)
+
+        # Run each test
+        test_results = []
+        for test_name, test_args in model_tests.items():
+            result = eval_model_correctness_test(
+                model_name=model_name,
+                model_class=model_class,
+                model_config=model_config,
+                test_name=test_name,
+                test_args=test_args,
+                kernel_dir=kernel_dir,
+            )
+            test_results.append(result)
+
+        # Aggregate results
+        all_passed = all(r.is_correct for r in test_results)
+        num_passed = sum(1 for r in test_results if r.is_correct)
+        num_total = len(test_results)
+
+        return {
+            "model_name": model_name,
+            "passed": all_passed,
+            "num_passed": num_passed,
+            "num_total": num_total,
+            "test_results": test_results,
+        }
+
+    def print_results(self, results: Dict[str, Any]) -> None:
+        """Print model evaluation results.
+
+        Args:
+            results: Dictionary with evaluation results from eval_model
+        """
+        model_name = results.get("model_name", "Unknown")
+        passed = results.get("passed", False)
+        num_passed = results.get("num_passed", 0)
+        num_total = results.get("num_total", 0)
+
+        logger.info(f"\nModel: {model_name}")
+        logger.info(
+            f"Status: {'✓ Passed' if passed else '✗ Failed'} ({num_passed}/{num_total} tests)"
+        )
+
+        # Print details for each test
+        test_results = results.get("test_results", [])
+        for result in test_results:
+            status = "✓" if result.is_correct else "✗"
+            logger.info(f"  {status} {result.test_name}")
+
+            if not result.is_correct:
+                if result.error_msg:
+                    logger.info(f"    Error: {result.error_msg}")
+                else:
+                    # Show what failed
+                    if not result.output_match:
+                        logger.info("    Output mismatch")
+                    if not result.gradients_match:
+                        logger.info(f"    Gradient mismatch ({result.num_gradients} gradients)")
+            else:
+                # Show success details
+                logger.info(
+                    f"    Output match: ✓ Gradients match: ✓ ({result.num_gradients} gradients)"
+                )
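To make the data flow above concrete, here is a hypothetical shape of a per-model `config` entry as `eval_model` reads it. The key names come from the code in this diff; the `model_tests` values are placeholders only, since their exact serialization is whatever `BackendBench.utils.deserialize_args` expects and is defined outside this file.

```python
# Hypothetical example; the test-args strings below are placeholders,
# not the real serialization format consumed by deserialize_args.
config = {
    "model_config": {
        "init_args": {"num_groups": 8},  # forwarded to the model constructor in _run_model
    },
    "model_tests": {
        # test_name -> serialized (args, kwargs); kwargs["x"] is treated as the input tensor
        "small_batch": "<serialized args for a [2, 3, 32, 32] input>",
        "medium_batch": "<serialized args for a [4, 3, 64, 64] input>",
    },
}
```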
