Skip to content

Commit bc2575b

Browse files
committed
[ModelSuite] Add model loading infrastructure
Here we introduce model suite (model.py). The idea here to start and codify the ideas from jiannanWang/BackendBenchExamples. Specifically this PR deals with model registration and loading. ### Model Registration This PR creates a way of adding models to the suite and automatically validates them through CI. It also loads the models as well. The way these models are added is detailed in this readme. The tl;dir is we use a format similar to kernelbench and SakanaAI/robust-kbench where we pair model code with a config. Importantly the configs contain initialization code, forward pass arguments (both in a similar format to torchbench), and a list of ops in the forward and backwards passes. These ops are fairly important as they are what we want to point out to the researcher when they are optimizing a model. There is a README.md to help folks setup proper model code / configs. We also further verify these registrations are correct through CI. Specifically we run test/test_model_ops_configs.py to ensure the configs are formatted correctly. ### Added models I added 2 models here as examples. SmokeTestModel - This is simple model that uses aten.ops.mm as we can implement a correct version of this op ToyCoreOpsModel - This is a model which explicitly calls the backwards passes which are both in torchbench + core. ### Small Things - Added a --model-filter to the CLI as it will be needed to support filtering in model suite as it chooses things to test based on the model not set of ops ### Testing New tests are added so pytest resolves things here
1 parent 65b7c1a commit bc2575b

File tree

10 files changed

+884
-2
lines changed

10 files changed

+884
-2
lines changed

BackendBench/scripts/main.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from BackendBench.output import save_results
2020
from BackendBench.suite import (
2121
FactoTestSuite,
22+
ModelSuite,
2223
OpInfoTestSuite,
2324
SmokeTestSuite,
2425
TorchBenchTestSuite,
@@ -50,7 +51,7 @@ def setup_logging(log_level):
5051
@click.option(
5152
"--suite",
5253
default="smoke",
53-
type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]),
54+
type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]),
5455
help="Which suite to run",
5556
)
5657
@click.option(
@@ -63,7 +64,13 @@ def setup_logging(log_level):
6364
"--ops",
6465
default=None,
6566
type=str,
66-
help="Comma-separated list of ops to run",
67+
help="Comma-separated list of ops to run (not supported for model suite)",
68+
)
69+
@click.option(
70+
"--model-filter",
71+
default=None,
72+
type=str,
73+
help="Comma-separated list of models to run (only for model suite)",
6774
)
6875
@click.option(
6976
"--topn-inputs",
@@ -147,6 +154,7 @@ def cli(
147154
suite,
148155
backend,
149156
ops,
157+
model_filter,
150158
topn_inputs,
151159
llm_attempts,
152160
llm_model,
@@ -166,9 +174,20 @@ def cli(
166174
if check_overhead_dominated_ops:
167175
raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite")
168176

177+
if suite == "model":
178+
if ops is not None:
179+
raise ValueError(
180+
"--ops filter is not supported for model suite. Use --model-filter instead"
181+
)
182+
183+
if suite != "model" and model_filter is not None:
184+
raise ValueError("--model-filter is only supported for model suite")
185+
169186
setup_logging(log_level)
170187
if ops:
171188
ops = ops.split(",")
189+
if model_filter:
190+
model_filter = model_filter.split(",")
172191

173192
suite = {
174193
"smoke": lambda: SmokeTestSuite,

BackendBench/suite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .base import OpTest, Test, TestSuite
1717
from .facto import FactoTestSuite
18+
from .model import ModelSuite
1819
from .opinfo import OpInfoTestSuite
1920
from .smoke import randn, SmokeTestSuite
2021
from .torchbench import TorchBenchOpTest, TorchBenchTestSuite
@@ -24,6 +25,7 @@
2425
"OpTest",
2526
"TestSuite",
2627
"FactoTestSuite",
28+
"ModelSuite",
2729
"OpInfoTestSuite",
2830
"SmokeTestSuite",
2931
"randn",

BackendBench/suite/model.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Model Suite for testing models defined in configs.
9+
"""
10+
11+
import importlib.util
12+
import json
13+
import logging
14+
import os
15+
from typing import Any, Dict, List, Optional
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def load_models(
21+
models_dir: str = "models", filter: Optional[List[str]] = None
22+
) -> List[Dict[str, Any]]:
23+
"""Load models using strict naming convention: folder_name/folder_name.py + folder_name.json
24+
25+
Args:
26+
models_dir: Directory containing models (default: "models")
27+
filter: Optional list of model names to load. If None, loads all models.
28+
29+
Returns:
30+
List of dictionaries with keys:
31+
- name: Model name (str)
32+
- class: Model class (type)
33+
- config: Configuration dictionary from JSON file
34+
"""
35+
models = []
36+
37+
if not os.path.exists(models_dir):
38+
raise FileNotFoundError(f"Models directory not found: {models_dir}")
39+
40+
for model_name in os.listdir(models_dir):
41+
model_dir = os.path.join(models_dir, model_name)
42+
if not os.path.isdir(model_dir):
43+
continue
44+
45+
# Skip if not in filter
46+
if filter is not None and model_name not in filter:
47+
continue
48+
49+
# Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json
50+
model_file = os.path.join(model_dir, f"{model_name}.py")
51+
config_file = os.path.join(model_dir, f"{model_name}.json")
52+
53+
# Check both files exist
54+
if not os.path.exists(model_file):
55+
raise FileNotFoundError(f"Model file not found: {model_file}")
56+
57+
if not os.path.exists(config_file):
58+
raise FileNotFoundError(f"Config file not found: {config_file}")
59+
60+
try:
61+
# Load config
62+
with open(config_file, "r") as f:
63+
config = json.load(f)
64+
65+
# Load model class dynamically
66+
spec = importlib.util.spec_from_file_location(model_name, model_file)
67+
module = importlib.util.module_from_spec(spec)
68+
spec.loader.exec_module(module)
69+
70+
# Find model class (must match model_name exactly)
71+
if not hasattr(module, model_name):
72+
raise RuntimeError(f"Model class '{model_name}' not found in {model_file}")
73+
74+
model_class = getattr(module, model_name)
75+
if not (isinstance(model_class, type) and hasattr(model_class, "forward")):
76+
raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class")
77+
78+
models.append({"name": model_name, "class": model_class, "config": config})
79+
logger.info(f"Loaded model: {model_name}")
80+
81+
except Exception as e:
82+
raise RuntimeError(f"Failed to load model {model_name}: {e}")
83+
84+
if filter is not None and len(models) == 0:
85+
raise ValueError(f"No models found matching filter: {filter}")
86+
87+
return models
88+
89+
90+
class ModelSuite:
91+
"""Model Suite for end-to-end model testing."""
92+
93+
def __init__(
94+
self,
95+
name: str = "model",
96+
filter: Optional[List[str]] = None,
97+
):
98+
"""Initialize ModelSuite.
99+
100+
Args:
101+
name: Suite name (default: "model")
102+
filter: Optional list of model names to load
103+
"""
104+
models_dir = os.path.join(os.path.dirname(__file__), "models")
105+
106+
# Load models
107+
models = load_models(models_dir=models_dir, filter=filter)
108+
logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}")
109+
110+
# Store loaded models
111+
self.models = models
112+
self.name = name
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Adding Models to BackendBench
2+
3+
## Quick Start
4+
5+
Models define operator lists and validate that custom backends work correctly in full model execution. Two files required:
6+
7+
```
8+
BackendBench/suite/models/YourModel/
9+
├── YourModel.py # nn.Module class
10+
└── YourModel.json # Configuration
11+
```
12+
13+
**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive)
14+
15+
## Adding a Model
16+
17+
### 1. Create Directory and Files
18+
19+
```bash
20+
cd BackendBench/suite/models
21+
mkdir MyModel
22+
cd MyModel
23+
touch MyModel.py MyModel.json
24+
```
25+
26+
### 2. Write Model Class (`MyModel.py`)
27+
28+
**Requirements:**
29+
- Class name = filename (exact match)
30+
- All `__init__` params need defaults
31+
- Add a main() / runner if you are inclined for sanity checking
32+
- Keep it simple - focus on specific operators you're testing
33+
- Look in this directory for examples
34+
35+
### 3. Write Config (`MyModel.json`)
36+
37+
**Key Fields:**
38+
- `model_config.init_args` - Args for `__init__()`, must match your defaults
39+
- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten.<op>.default"`)
40+
- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench)
41+
- Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc.
42+
- `metadata.description` - What this model tests
43+
- Look in this directory for examples
44+
45+
**Finding operator names:**
46+
```python
47+
from torch.profiler import profile, ProfilerActivity
48+
49+
with profile(activities=[ProfilerActivity.CPU]) as prof:
50+
output = model(x)
51+
loss = output.sum()
52+
loss.backward()
53+
54+
for event in prof.key_averages():
55+
if "aten::" in event.key:
56+
print(event.key)
57+
```
58+
59+
### 4. Test Your Model
60+
61+
```bash
62+
# Test standalone
63+
cd BackendBench/suite/models/MyModel
64+
python MyModel.py # Add main() for standalone testing
65+
66+
# Test with suite
67+
python -m BackendBench.scripts.main \
68+
--suite model \
69+
--backend aten \
70+
--model-filter MyModel
71+
72+
# Expected output:
73+
# Model: MyModel
74+
# Status: ✓ Passed (2/2 tests)
75+
# ✓ small
76+
# ✓ large
77+
```
78+
79+
### 5: Validation
80+
`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"model_config": {
3+
"init_args": {
4+
"input_dim": 128,
5+
"hidden_dim": 128,
6+
"output_dim": 128
7+
}
8+
},
9+
"ops": {
10+
"forward": [
11+
"aten.mm.default"
12+
],
13+
"backward": [
14+
"aten.mm.default"
15+
]
16+
},
17+
"model_tests": {
18+
"small_batch": "([], {'x': T([2, 128], f32)})",
19+
"medium_batch": "([], {'x': T([16, 128], f32)})",
20+
"large_batch": "([], {'x': T([32, 128], f32)})"
21+
},
22+
"metadata": {
23+
"description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes"
24+
}
25+
}

0 commit comments

Comments
 (0)