Skip to content

Commit 6e6334b

Browse files
committed
[ModelSuite] Add model loading infrastructure
### Model Registration This PR creates a way of adding models to the suite and automatically validates them through CI. It also loads the models as well. The way these models are added is detailed in this readme. The tl;dir is we use a format similar to kernelbench and SakanaAI/robust-kbench where we pair model code with a config. Importantly the configs contain initialization code, forward pass arguments (both in a similar format to torchbench), and a list of ops in the forward and backwards passes. These ops are fairly important as they are what we want to point out to the researcher when they are optimizing a model. There is a README.md to help folks setup proper model code / configs. We also further verify these registrations are correct through CI. Specifically we run test/test_model_ops_configs.py to ensure the configs are formatted correctly. ### Small Things - Added a --model-filter to the CLI as it will be needed to support filtering in model suite as it chooses things to test based on the model not set of ops ### Testing New tests are added so pytest resolves things here ### Future work with Model Suite #181
1 parent daebfc1 commit 6e6334b

File tree

5 files changed

+412
-2
lines changed

5 files changed

+412
-2
lines changed

BackendBench/scripts/main.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from BackendBench.output import save_results
2020
from BackendBench.suite import (
2121
FactoTestSuite,
22+
ModelSuite,
2223
OpInfoTestSuite,
2324
SmokeTestSuite,
2425
TorchBenchTestSuite,
@@ -50,7 +51,7 @@ def setup_logging(log_level):
5051
@click.option(
5152
"--suite",
5253
default="smoke",
53-
type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]),
54+
type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]),
5455
help="Which suite to run",
5556
)
5657
@click.option(
@@ -63,7 +64,13 @@ def setup_logging(log_level):
6364
"--ops",
6465
default=None,
6566
type=str,
66-
help="Comma-separated list of ops to run",
67+
help="Comma-separated list of ops to run (not supported for model suite)",
68+
)
69+
@click.option(
70+
"--model-filter",
71+
default=None,
72+
type=str,
73+
help="Comma-separated list of models to run (only for model suite)",
6774
)
6875
@click.option(
6976
"--topn-inputs",
@@ -147,6 +154,7 @@ def cli(
147154
suite,
148155
backend,
149156
ops,
157+
model_filter,
150158
topn_inputs,
151159
llm_attempts,
152160
llm_model,
@@ -166,9 +174,22 @@ def cli(
166174
if check_overhead_dominated_ops:
167175
raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite")
168176

177+
if suite == "model":
178+
if ops is not None:
179+
raise ValueError(
180+
"--ops filter is not supported for model suite. Use --model-filter instead"
181+
)
182+
# remove this in later PR as model suite is supported
183+
raise NotImplementedError("Model suite is not supported yet")
184+
185+
if suite != "model" and model_filter is not None:
186+
raise ValueError("--model-filter is only supported for model suite")
187+
169188
setup_logging(log_level)
170189
if ops:
171190
ops = ops.split(",")
191+
if model_filter:
192+
model_filter = model_filter.split(",")
172193

173194
suite = {
174195
"smoke": lambda: SmokeTestSuite,
@@ -191,6 +212,7 @@ def cli(
191212
torch.bfloat16,
192213
filter=ops,
193214
),
215+
"model": lambda: ModelSuite(filter=model_filter),
194216
}[suite]()
195217

196218
backend_name = backend

BackendBench/suite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .base import OpTest, Test, TestSuite
1717
from .facto import FactoTestSuite
18+
from .model import ModelSuite
1819
from .opinfo import OpInfoTestSuite
1920
from .smoke import randn, SmokeTestSuite
2021
from .torchbench import TorchBenchOpTest, TorchBenchTestSuite
@@ -24,6 +25,7 @@
2425
"OpTest",
2526
"TestSuite",
2627
"FactoTestSuite",
28+
"ModelSuite",
2729
"OpInfoTestSuite",
2830
"SmokeTestSuite",
2931
"randn",

BackendBench/suite/model.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Model Suite for testing models defined in configs.
9+
"""
10+
11+
import importlib.util
12+
import json
13+
import logging
14+
import os
15+
from typing import Any, Dict, List, Optional
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def load_models(
21+
models_dir: str = "models", filter: Optional[List[str]] = None
22+
) -> List[Dict[str, Any]]:
23+
"""Load models using strict naming convention: folder_name/folder_name.py + folder_name.json
24+
25+
Args:
26+
models_dir: Directory containing models (default: "models")
27+
filter: Optional list of model names to load. If None, loads all models.
28+
29+
Returns:
30+
List of dictionaries with keys:
31+
- name: Model name (str)
32+
- class: Model class (type)
33+
- config: Configuration dictionary from JSON file
34+
"""
35+
models = []
36+
37+
if not os.path.exists(models_dir):
38+
raise FileNotFoundError(f"Models directory not found: {models_dir}")
39+
40+
for model_name in os.listdir(models_dir):
41+
model_dir = os.path.join(models_dir, model_name)
42+
if not os.path.isdir(model_dir):
43+
continue
44+
45+
# Skip if not in filter
46+
if filter is not None and model_name not in filter:
47+
continue
48+
49+
# Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json
50+
model_file = os.path.join(model_dir, f"{model_name}.py")
51+
config_file = os.path.join(model_dir, f"{model_name}.json")
52+
53+
# Check both files exist
54+
if not os.path.exists(model_file):
55+
raise FileNotFoundError(f"Model file not found: {model_file}")
56+
57+
if not os.path.exists(config_file):
58+
raise FileNotFoundError(f"Config file not found: {config_file}")
59+
60+
try:
61+
# Load config
62+
with open(config_file, "r") as f:
63+
config = json.load(f)
64+
65+
# Load model class dynamically
66+
spec = importlib.util.spec_from_file_location(model_name, model_file)
67+
module = importlib.util.module_from_spec(spec)
68+
spec.loader.exec_module(module)
69+
70+
# Find model class (must match model_name exactly)
71+
if not hasattr(module, model_name):
72+
raise RuntimeError(f"Model class '{model_name}' not found in {model_file}")
73+
74+
model_class = getattr(module, model_name)
75+
if not (isinstance(model_class, type) and hasattr(model_class, "forward")):
76+
raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class")
77+
78+
models.append({"name": model_name, "class": model_class, "config": config})
79+
logger.info(f"Loaded model: {model_name}")
80+
81+
except Exception as e:
82+
raise RuntimeError(f"Failed to load model {model_name}: {e}")
83+
84+
if filter is not None and len(models) == 0:
85+
raise ValueError(f"No models found matching filter: {filter}")
86+
87+
return models
88+
89+
90+
class ModelSuite:
91+
"""Model Suite for end-to-end model testing."""
92+
93+
def __init__(
94+
self,
95+
name: str = "model",
96+
filter: Optional[List[str]] = None,
97+
):
98+
"""Initialize ModelSuite.
99+
100+
Args:
101+
name: Suite name (default: "model")
102+
filter: Optional list of model names to load
103+
"""
104+
models_dir = os.path.join(os.path.dirname(__file__), "models")
105+
106+
# Load models
107+
models = load_models(models_dir=models_dir, filter=filter)
108+
logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}")
109+
110+
# Store loaded models
111+
self.models = models
112+
self.name = name

0 commit comments

Comments
 (0)