Commit 64e72d6

[ModelSuite] Add model ops coverage validation test
This PR adds another unit test for the model loading / config system introduced in the last PR. Specifically, it ensures that the ops declared in a model's config are actually invoked when the model runs. This is important because updates to torch could change how backward passes are decomposed into ops. Furthermore, if we expect folks to write kernels for a set of ops and then run the model, we should guarantee those ops are actually used.
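
For illustration, here is a minimal sketch of the kind of ops section this test checks against; the "ops" / "forward" / "backward" keys and the failure-message format match the test below, while the model name and operator names are made up:

# Hypothetical "ops" section of a model config (model and op names are illustrative only).
ops = {
    "forward": ["aten.convolution.default", "aten.relu.default"],
    "backward": ["aten.convolution_backward.default"],
}

# The test runs one forward and one backward pass under the torch profiler,
# collects the aten ops that actually executed, and reports any declared op
# that never appeared, e.g.:
#   my_model [FORWARD]: Missing ops: ['aten.relu.default']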
1 parent bc2575b commit 64e72d6

File tree

1 file changed: +201 −0 lines changed


test/test_model_ops_coverage.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Unit test to verify that models actually invoke all operators declared in their configs.

This test validates that:
1. Forward pass invokes all operators in config["ops"]["forward"]
2. Backward pass invokes all operators in config["ops"]["backward"]
3. Clear error messages indicate which operators are missing per model
"""

import os
import re
import unittest
from typing import Dict, Set

import torch

from BackendBench.suite.model import load_models


class OpTracker:
    """Track operators called during forward/backward passes using torch profiler."""

    def __init__(self):
        self.called_ops: Set[str] = set()
        self.profiler = None

    def __enter__(self):
        self.called_ops.clear()

        # Use torch profiler to track ops
        self.profiler = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=False,
            with_stack=False,
        )
        self.profiler.__enter__()
        return self

    def __exit__(self, *args):
        self.profiler.__exit__(*args)

        # Extract op names from profiler events
        for event in self.profiler.events():
            event_name = event.name
            # Look for aten operations
            if "::" in event_name:
                # Handle format like "aten::convolution" or "aten::convolution.default"
                parts = event_name.replace("::", ".").split(".")

                if len(parts) >= 2 and parts[0] == "aten":
                    if len(parts) == 2:
                        # No variant specified, add .default
                        op_name = f"{parts[0]}.{parts[1]}.default"
                    else:
                        # Keep as is
                        op_name = event_name.replace("::", ".")

                    self.called_ops.add(op_name)


class TestModelOpsCoverage(unittest.TestCase):
    """Test that models invoke all operators declared in their configs."""

    def test_all_models_ops_coverage(self):
        """Test that all models invoke their declared forward and backward ops."""
        models_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "BackendBench",
            "suite",
            "models",
        )

        models = load_models(models_dir=models_dir)
        self.assertGreater(len(models), 0, "Should load at least one model")

        failures = []

        for model_dict in models:
            model_name = model_dict["name"]
            model_class = model_dict["class"]
            config = model_dict["config"]

            # Get expected ops from config
            config_ops = config.get("ops", {})
            expected_forward = set(config_ops.get("forward", []))
            expected_backward = set(config_ops.get("backward", []))

            # Skip if no ops to check
            if not expected_forward and not expected_backward:
                continue

            try:
                # Initialize model
                model_config = config.get("model_config", {})
                init_args = model_config.get("init_args", {})

                if model_config.get("requires_init_seed"):
                    torch.manual_seed(42)

                model = model_class(**init_args)

                # Get a test input from model_tests
                model_tests = config.get("model_tests", {})
                if not model_tests:
                    failures.append(f"{model_name}: No model_tests in config")
                    continue

                # Use first test case
                test_name = list(model_tests.keys())[0]
                test_args_str = model_tests[test_name]

                # Parse test args (simple eval for now)
                # Format: "([], {'x': T([2, 3, 32, 32], f32)})"
                test_input = self._create_test_input_from_string(test_args_str)

                # Track forward pass
                tracker = OpTracker()
                with tracker:
                    output = model(**test_input)

                forward_ops = tracker.called_ops

                # Check forward ops coverage
                missing_forward = expected_forward - forward_ops
                if missing_forward:
                    failures.append(
                        f"{model_name} [FORWARD]: Missing ops: {sorted(missing_forward)}"
                    )

                # Track backward pass
                if expected_backward:
                    # Ensure output requires grad
                    for param in model.parameters():
                        param.requires_grad = True

                    # Create loss
                    if isinstance(output, torch.Tensor):
                        loss = output.sum()
                    else:
                        # Handle tuple/dict outputs
                        loss = sum(v.sum() for v in output.values() if isinstance(v, torch.Tensor))

                    tracker_backward = OpTracker()
                    with tracker_backward:
                        loss.backward()

                    backward_ops = tracker_backward.called_ops

                    # Check backward ops coverage
                    missing_backward = expected_backward - backward_ops
                    if missing_backward:
                        failures.append(
                            f"{model_name} [BACKWARD]: Missing ops: {sorted(missing_backward)}"
                        )

            except Exception as e:
                failures.append(f"{model_name}: Error during test: {e}")

        # Report all failures at once
        if failures:
            error_msg = "\n\nOperator Coverage Failures:\n" + "\n".join(
                f" - {failure}" for failure in failures
            )
            self.fail(error_msg)

    def _create_test_input_from_string(self, test_args_str: str) -> Dict[str, torch.Tensor]:
        """Parse test input string into actual tensors.

        Format: "([], {'x': T([2, 3, 32, 32], f32)})"
        """

        # Extract tensor specs: T([shape], dtype)
        tensor_pattern = r"'(\w+)':\s*T\(\[([\d,\s]+)\],\s*(\w+)\)"
        matches = re.findall(tensor_pattern, test_args_str)

        inputs = {}
        for name, shape_str, dtype_str in matches:
            shape = [int(x.strip()) for x in shape_str.split(",")]

            # Map dtype string to torch dtype
            dtype_map = {
                "f32": torch.float32,
                "f64": torch.float64,
                "i32": torch.int32,
                "i64": torch.int64,
            }
            dtype = dtype_map.get(dtype_str, torch.float32)

            inputs[name] = torch.randn(*shape, dtype=dtype)

        return inputs


if __name__ == "__main__":
    unittest.main(verbosity=2)
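
For context, a minimal sketch of using the OpTracker context manager above on its own; the toy module and the import path are assumptions for illustration, not part of this commit:

import torch

# Assumption: the test module is importable as a package from the repo root.
from test.test_model_ops_coverage import OpTracker

model = torch.nn.Linear(4, 2)  # toy module, illustrative only
x = torch.randn(8, 4)

with OpTracker() as tracker:
    y = model(x)

# Op names are normalized to the "aten.<op>.default" form used in the configs,
# e.g. "aten.addmm.default" for the linear layer above.
print(sorted(tracker.called_ops))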
