From daebfc1ae585f9fe37de77a001a0a101ff8520e9 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:26:59 +0000 Subject: [PATCH 1/6] [ModelSuite] Add Toy Models Summary: Here we introduce model suite (model.py). The idea here to start and codify the ideas from jiannanWang/BackendBenchExamples. Specifically this PR adds some example models / configs which are to be loaded + a Readme. (It may be useful to look at the PR above this as well since it's the model loading logic). This PR adds two toy models to model suite SmokeTestModel - This is simple model that uses aten.ops.mm as we can implement a correct version of this op ToyCoreOpsModel - This is a model which explicitly calls the backwards passes which are both in torchbench + core. Test Plan: the test infra is in the pr above, so tests passing on the PR above should be sufficient here ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- BackendBench/suite/models/README.md | 80 +++++++++++++++++ .../models/SmokeTestModel/SmokeTestModel.json | 25 ++++++ .../models/SmokeTestModel/SmokeTestModel.py | 68 +++++++++++++++ .../ToyCoreOpsModel/ToyCoreOpsModel.json | 34 ++++++++ .../models/ToyCoreOpsModel/ToyCoreOpsModel.py | 87 +++++++++++++++++++ 5 files changed, 294 insertions(+) create mode 100644 BackendBench/suite/models/README.md create mode 100644 BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json create mode 100644 BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py create mode 100644 BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json create mode 100644 BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md new file mode 100644 index 00000000..57e707dc --- /dev/null +++ b/BackendBench/suite/models/README.md @@ -0,0 +1,80 @@ +# Adding Models to BackendBench + +## Quick Start + +Models define operator lists and validate that custom backends work correctly in full model execution. Two files required: + +``` +BackendBench/suite/models/YourModel/ +├── YourModel.py # nn.Module class +└── YourModel.json # Configuration +``` + +**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive) + +## Adding a Model + +### 1. Create Directory and Files + +```bash +cd BackendBench/suite/models +mkdir MyModel +cd MyModel +touch MyModel.py MyModel.json +``` + +### 2. Write Model Class (`MyModel.py`) + +**Requirements:** +- Class name = filename (exact match) +- All `__init__` params need defaults +- Add a main() / runner if you are inclined for sanity checking +- Keep it simple - focus on specific operators you're testing +- Look in this directory for examples + +### 3. Write Config (`MyModel.json`) + +**Key Fields:** +- `model_config.init_args` - Args for `__init__()`, must match your defaults +- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten..default"`) +- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench) + - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc. +- `metadata.description` - What this model tests +- Look in this directory for examples + +**Finding operator names:** +```python +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CPU]) as prof: + output = model(x) + loss = output.sum() + loss.backward() + +for event in prof.key_averages(): + if "aten::" in event.key: + print(event.key) +``` + +### 4. Test Your Model + +```bash +# Test standalone +cd BackendBench/suite/models/MyModel +python MyModel.py # Add main() for standalone testing + +# Test with suite +python -m BackendBench.scripts.main \ + --suite model \ + --backend aten \ + --model-filter MyModel + +# Expected output: +# Model: MyModel +# Status: ✓ Passed (2/2 tests) +# ✓ small +# ✓ large +``` + +### 5: Validation +`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly. diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json new file mode 100644 index 00000000..b7d286ae --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json @@ -0,0 +1,25 @@ +{ + "model_config": { + "init_args": { + "input_dim": 128, + "hidden_dim": 128, + "output_dim": 128 + } + }, + "ops": { + "forward": [ + "aten.mm.default" + ], + "backward": [ + "aten.mm.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 128], f32)})", + "medium_batch": "([], {'x': T([16, 128], f32)})", + "large_batch": "([], {'x': T([32, 128], f32)})" + }, + "metadata": { + "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes" + } +} diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py new file mode 100644 index 00000000..3bf627e4 --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple model that tests matrix multiplication operations using explicit +torch.mm calls. +""" + +import torch +import torch.nn as nn + + +class SmokeTestModel(nn.Module): + """ + Model that uses explicit torch.mm operations to test aten.mm.default + in forward/backward. + """ + + def __init__( + self, + input_dim: int = 128, + hidden_dim: int = 128, + output_dim: int = 128, + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim)) + self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim)) + self.bias1 = nn.Parameter(torch.randn(hidden_dim)) + self.bias2 = nn.Parameter(torch.randn(output_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2) + """ + x = torch.mm(x, self.weight1) + self.bias1 + x = torch.relu(x) + x = torch.mm(x, self.weight2) + self.bias2 + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128) + batch_size = 4 + input_tensor = torch.randn(batch_size, 128, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + + +if __name__ == "__main__": + main() diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json new file mode 100644 index 00000000..1586273e --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -0,0 +1,34 @@ +{ + "model_config": { + "init_args": { + "in_channels": 3, + "hidden_channels": 32, + "out_channels": 8, + "num_groups": 8 + } + }, + "ops": { + "forward": [ + "aten.convolution.default", + "aten.native_group_norm.default", + "aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], + "backward": [ + "aten.convolution_backward.default", + "aten.native_group_norm_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d_backward.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", + "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", + "large_input": "([], {'x': T([2, 3, 128, 128], f32)})" + }, + "metadata": { + "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" + } +} \ No newline at end of file diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py new file mode 100644 index 00000000..410e4c4f --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +CNN model that triggers core PyTorch backward operators: +- convolution_backward +- native_group_norm_backward +- max_pool2d_with_indices_backward +- avg_pool2d_backward +- _adaptive_avg_pool2d_backward +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ToyCoreOpsModel(nn.Module): + """CNN that uses conv, group norm, max pool, avg pool, and adaptive avg pool.""" + + def __init__( + self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + ): + super().__init__() + + if hidden_channels % num_groups != 0: + raise ValueError( + f"hidden_channels ({hidden_channels}) must be divisible by " + f"num_groups ({num_groups})" + ) + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_groups = num_groups + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm1 = nn.GroupNorm(num_groups, hidden_channels) + self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm2 = nn.GroupNorm(num_groups, hidden_channels) + self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through: Conv->GroupNorm->ReLU->MaxPool->Conv-> + GroupNorm->ReLU->AvgPool->AdaptiveAvgPool->Conv + Output is always (batch, out_channels, 4, 4) regardless of + input size. + """ + x = F.relu(self.group_norm1(self.conv1(x))) + x, _ = F.max_pool2d(x, kernel_size=2, return_indices=True) + x = F.relu(self.group_norm2(self.conv2(x))) + x = F.avg_pool2d(x, kernel_size=2) + x = F.adaptive_avg_pool2d(x, output_size=(4, 4)) + x = self.conv_out(x) + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = ToyCoreOpsModel(in_channels=3, hidden_channels=32, out_channels=8, num_groups=8) + batch_size = 2 + input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + return model + + +if __name__ == "__main__": + main() From 6e6334b6bd3d7d2cb3c0db57e828588b9a55f222 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:27:51 +0000 Subject: [PATCH 2/6] [ModelSuite] Add model loading infrastructure ### Model Registration This PR creates a way of adding models to the suite and automatically validates them through CI. It also loads the models as well. The way these models are added is detailed in this readme. The tl;dir is we use a format similar to kernelbench and SakanaAI/robust-kbench where we pair model code with a config. Importantly the configs contain initialization code, forward pass arguments (both in a similar format to torchbench), and a list of ops in the forward and backwards passes. These ops are fairly important as they are what we want to point out to the researcher when they are optimizing a model. There is a README.md to help folks setup proper model code / configs. We also further verify these registrations are correct through CI. Specifically we run test/test_model_ops_configs.py to ensure the configs are formatted correctly. ### Small Things - Added a --model-filter to the CLI as it will be needed to support filtering in model suite as it chooses things to test based on the model not set of ops ### Testing New tests are added so pytest resolves things here ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- BackendBench/scripts/main.py | 26 +++- BackendBench/suite/__init__.py | 2 + BackendBench/suite/model.py | 112 +++++++++++++++++ test/test_model_ops_configs.py | 221 +++++++++++++++++++++++++++++++++ test/test_model_suite.py | 53 ++++++++ 5 files changed, 412 insertions(+), 2 deletions(-) create mode 100644 BackendBench/suite/model.py create mode 100644 test/test_model_ops_configs.py create mode 100644 test/test_model_suite.py diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 479e5805..2240e2d1 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -19,6 +19,7 @@ from BackendBench.output import save_results from BackendBench.suite import ( FactoTestSuite, + ModelSuite, OpInfoTestSuite, SmokeTestSuite, TorchBenchTestSuite, @@ -50,7 +51,7 @@ def setup_logging(log_level): @click.option( "--suite", default="smoke", - type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]), + type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]), help="Which suite to run", ) @click.option( @@ -63,7 +64,13 @@ def setup_logging(log_level): "--ops", default=None, type=str, - help="Comma-separated list of ops to run", + help="Comma-separated list of ops to run (not supported for model suite)", +) +@click.option( + "--model-filter", + default=None, + type=str, + help="Comma-separated list of models to run (only for model suite)", ) @click.option( "--topn-inputs", @@ -147,6 +154,7 @@ def cli( suite, backend, ops, + model_filter, topn_inputs, llm_attempts, llm_model, @@ -166,9 +174,22 @@ def cli( if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") + if suite == "model": + if ops is not None: + raise ValueError( + "--ops filter is not supported for model suite. Use --model-filter instead" + ) + # remove this in later PR as model suite is supported + raise NotImplementedError("Model suite is not supported yet") + + if suite != "model" and model_filter is not None: + raise ValueError("--model-filter is only supported for model suite") + setup_logging(log_level) if ops: ops = ops.split(",") + if model_filter: + model_filter = model_filter.split(",") suite = { "smoke": lambda: SmokeTestSuite, @@ -191,6 +212,7 @@ def cli( torch.bfloat16, filter=ops, ), + "model": lambda: ModelSuite(filter=model_filter), }[suite]() backend_name = backend diff --git a/BackendBench/suite/__init__.py b/BackendBench/suite/__init__.py index 410a5d6e..e9c332a9 100644 --- a/BackendBench/suite/__init__.py +++ b/BackendBench/suite/__init__.py @@ -15,6 +15,7 @@ from .base import OpTest, Test, TestSuite from .facto import FactoTestSuite +from .model import ModelSuite from .opinfo import OpInfoTestSuite from .smoke import randn, SmokeTestSuite from .torchbench import TorchBenchOpTest, TorchBenchTestSuite @@ -24,6 +25,7 @@ "OpTest", "TestSuite", "FactoTestSuite", + "ModelSuite", "OpInfoTestSuite", "SmokeTestSuite", "randn", diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py new file mode 100644 index 00000000..12f98256 --- /dev/null +++ b/BackendBench/suite/model.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Model Suite for testing models defined in configs. +""" + +import importlib.util +import json +import logging +import os +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +def load_models( + models_dir: str = "models", filter: Optional[List[str]] = None +) -> List[Dict[str, Any]]: + """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json + + Args: + models_dir: Directory containing models (default: "models") + filter: Optional list of model names to load. If None, loads all models. + + Returns: + List of dictionaries with keys: + - name: Model name (str) + - class: Model class (type) + - config: Configuration dictionary from JSON file + """ + models = [] + + if not os.path.exists(models_dir): + raise FileNotFoundError(f"Models directory not found: {models_dir}") + + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + # Skip if not in filter + if filter is not None and model_name not in filter: + continue + + # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json + model_file = os.path.join(model_dir, f"{model_name}.py") + config_file = os.path.join(model_dir, f"{model_name}.json") + + # Check both files exist + if not os.path.exists(model_file): + raise FileNotFoundError(f"Model file not found: {model_file}") + + if not os.path.exists(config_file): + raise FileNotFoundError(f"Config file not found: {config_file}") + + try: + # Load config + with open(config_file, "r") as f: + config = json.load(f) + + # Load model class dynamically + spec = importlib.util.spec_from_file_location(model_name, model_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find model class (must match model_name exactly) + if not hasattr(module, model_name): + raise RuntimeError(f"Model class '{model_name}' not found in {model_file}") + + model_class = getattr(module, model_name) + if not (isinstance(model_class, type) and hasattr(model_class, "forward")): + raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class") + + models.append({"name": model_name, "class": model_class, "config": config}) + logger.info(f"Loaded model: {model_name}") + + except Exception as e: + raise RuntimeError(f"Failed to load model {model_name}: {e}") + + if filter is not None and len(models) == 0: + raise ValueError(f"No models found matching filter: {filter}") + + return models + + +class ModelSuite: + """Model Suite for end-to-end model testing.""" + + def __init__( + self, + name: str = "model", + filter: Optional[List[str]] = None, + ): + """Initialize ModelSuite. + + Args: + name: Suite name (default: "model") + filter: Optional list of model names to load + """ + models_dir = os.path.join(os.path.dirname(__file__), "models") + + # Load models + models = load_models(models_dir=models_dir, filter=filter) + logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") + + # Store loaded models + self.models = models + self.name = name diff --git a/test/test_model_ops_configs.py b/test/test_model_ops_configs.py new file mode 100644 index 00000000..8b3f3dbe --- /dev/null +++ b/test/test_model_ops_configs.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that ModelSuite's operator filter correctly matches +the operators defined in model configs. + +This test validates that: +1. load_models correctly loads model configs from the models directory +2. load_model_ops extracts the correct set of operators from model configs +3. TorchBenchTestSuite initialized with those operators has matching optests +4. JSON config files have proper format with required fields +""" + +import json +import os +import unittest +from typing import Any, Dict, List, Set + +from BackendBench.suite.model import load_models +from BackendBench.suite.torchbench import TorchBenchTestSuite + + +def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: + """Extract unique set of operators from model configs. + + Args: + models: List of model dictionaries with 'name', 'class', and 'config' keys + + Returns: + Set of operator names defined across all model configs + """ + model_ops = set() + for model in models: + config_ops = model["config"].get("ops") + if not config_ops: + raise ValueError(f"Model {model['name']} has no 'ops' field in config") + assert "forward" in config_ops, f"Model {model['name']} has no 'forward' field in config" + assert "backward" in config_ops, f"Model {model['name']} has no 'backward' field in config" + ops_list = config_ops["forward"] + config_ops["backward"] + + model_ops.update(ops_list) + return model_ops + + +class TestModelOpsConfigs(unittest.TestCase): + """Test that model ops filter correctly initializes TorchBenchTestSuite.""" + + def test_model_ops_match_suite_optests(self): + """Test that suite's optests match the operators from model configs.""" + # Get the models directory path (same as ModelSuite does) + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load models using load_models + models = load_models(models_dir=models_dir) + + # Verify we loaded at least one model + self.assertGreater(len(models), 0, "Should load at least one model") + + # Extract operators from model configs using load_model_ops + model_ops = load_model_ops(models) + + # Verify we have operators + self.assertGreater(len(model_ops), 0, "Should have at least one operator") + + # Create filter list from model ops + ops_filter = list(model_ops) + + # Initialize TorchBenchTestSuite with the filter + suite = TorchBenchTestSuite( + name="test_model_ops", + filename=None, # Use default HuggingFace dataset + filter=ops_filter, + topn=None, + ) + + # Get the set of operators in the suite's optests + suite_ops = set(suite.optests.keys()) + + # The suite_ops should be a subset of model_ops because: + # - model_ops is the filter we requested + # - suite_ops contains only those operators that exist in the TorchBench dataset + # - Not all operators in model configs may be in the dataset + self.assertTrue( + suite_ops.issubset(model_ops), + f"Suite operators {suite_ops} should be subset of model operators {model_ops}", + ) + + # Verify that suite actually has some operators + self.assertGreater( + len(suite_ops), 0, "Suite should contain at least one operator from model configs" + ) + + def test_json_configs_have_required_fields(self): + """Test that all JSON config files have proper format with required fields.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load all models + models = load_models(models_dir=models_dir) + + for model in models: + model_name = model["name"] + config = model["config"] + + # Check required top-level fields + self.assertIn("ops", config, f"Model {model_name}: config must have 'ops' field") + self.assertIn( + "model_tests", config, f"Model {model_name}: config must have 'model_tests' field" + ) + + # Validate 'ops' field - can be list or dict + config_ops = config["ops"] + self.assertGreater( + len(config_ops["forward"] + config_ops["backward"]), + 0, + f"Model {model_name}: 'ops' list must not be empty", + ) + for op in config_ops["forward"] + config_ops["backward"]: + self.assertIsInstance( + op, str, f"Model {model_name}: each op in 'ops' must be a string" + ) + self.assertIsInstance( + config_ops["forward"], + list, + f"Model {model_name}: 'ops.forward' must be a list", + ) + for op in config_ops["forward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.forward' must be a string", + ) + self.assertIsInstance( + config_ops["backward"], + list, + f"Model {model_name}: 'ops.backward' must be a list", + ) + for op in config_ops["backward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.backward' must be a string", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + for test_name, test_args in config["model_tests"].items(): + self.assertIsInstance( + test_name, str, f"Model {model_name}: test names must be strings" + ) + self.assertIsInstance( + test_args, str, f"Model {model_name}: test args must be strings" + ) + + # Check optional but recommended fields + if "model_config" in config: + self.assertIsInstance( + config["model_config"], + dict, + f"Model {model_name}: 'model_config' must be a dictionary if present", + ) + + def test_json_files_are_valid_json(self): + """Test that all JSON config files are valid JSON and can be parsed.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Find all JSON files in the models directory + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + json_file = os.path.join(model_dir, f"{model_name}.json") + if not os.path.exists(json_file): + continue + + # Try to parse the JSON file + with open(json_file, "r") as f: + try: + config = json.load(f) + self.assertIsInstance( + config, + dict, + f"JSON file {json_file} must contain a dictionary at top level", + ) + except json.JSONDecodeError as e: + self.fail(f"JSON file {json_file} is not valid JSON: {e}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_suite.py b/test/test_model_suite.py new file mode 100644 index 00000000..12ddaf1b --- /dev/null +++ b/test/test_model_suite.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for Model Suite: Filtered TorchBench operators from model tracing + +This test suite validates: +1. Model loading from toy_models directory +2. Operator extraction via model tracing +3. ModelSuite creates filtered TorchBench suite +""" + +import logging +import unittest + +from BackendBench.suite.model import load_models + +# Setup logging +logging.basicConfig(level=logging.WARNING) + + +class TestModelLoading(unittest.TestCase): + """Test model loading functionality.""" + + def test_load_models(self): + """Test that models can be loaded from directory.""" + models = load_models(models_dir="BackendBench/suite/models") + self.assertGreater(len(models), 0, "Should load at least one model") + + # Verify model structure + for model in models: + self.assertIn("name", model) + self.assertIn("class", model) + self.assertIn("config", model) + + def test_load_specific_model(self): + """Test loading a specific model by name.""" + models = load_models(models_dir="BackendBench/suite/models", filter=["ToyCoreOpsModel"]) + self.assertEqual(len(models), 1) + self.assertEqual(models[0]["name"], "ToyCoreOpsModel") + + def test_invalid_filter(self): + """Test that invalid filter raises error.""" + with self.assertRaises(ValueError): + load_models(models_dir="BackendBench/suite/models", filter=["nonexistent"]) + + +if __name__ == "__main__": + # Run tests + unittest.main(verbosity=2) From ab23f8b4dfc017abbef2afe232516e691fe0c074 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:27:51 +0000 Subject: [PATCH 3/6] [ModelSuite] Add model ops coverage validation test This PR adds another unit test to the model loading / config system in the last PR. Specifically, here we ensure that the ops specified in the config are run in the model itself. This is important as updates to torch could change how backwards passes could work. Furthermore, if we are expecting folks to write kernels for a set of ops and then run the model, we should guarentee those ops are used. ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- test/test_model_ops_coverage.py | 201 ++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 test/test_model_ops_coverage.py diff --git a/test/test_model_ops_coverage.py b/test/test_model_ops_coverage.py new file mode 100644 index 00000000..99d8301b --- /dev/null +++ b/test/test_model_ops_coverage.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that models actually invoke all operators declared in their configs. + +This test validates that: +1. Forward pass invokes all operators in config["ops"]["forward"] +2. Backward pass invokes all operators in config["ops"]["backward"] +3. Clear error messages indicate which operators are missing per model +""" + +import os +import re +import unittest +from typing import Dict, Set + +import torch + +from BackendBench.suite.model import load_models + + +class OpTracker: + """Track operators called during forward/backward passes using torch profiler.""" + + def __init__(self): + self.called_ops: Set[str] = set() + self.profiler = None + + def __enter__(self): + self.called_ops.clear() + + # Use torch profiler to track ops + self.profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + record_shapes=False, + with_stack=False, + ) + self.profiler.__enter__() + return self + + def __exit__(self, *args): + self.profiler.__exit__(*args) + + # Extract op names from profiler events + for event in self.profiler.events(): + event_name = event.name + # Look for aten operations + if "::" in event_name: + # Handle format like "aten::convolution" or "aten::convolution.default" + parts = event_name.replace("::", ".").split(".") + + if len(parts) >= 2 and parts[0] == "aten": + if len(parts) == 2: + # No variant specified, add .default + op_name = f"{parts[0]}.{parts[1]}.default" + else: + # Keep as is + op_name = event_name.replace("::", ".") + + self.called_ops.add(op_name) + + +class TestModelOpsCoverage(unittest.TestCase): + """Test that models invoke all operators declared in their configs.""" + + def test_all_models_ops_coverage(self): + """Test that all models invoke their declared forward and backward ops.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "BackendBench", + "suite", + "models", + ) + + models = load_models(models_dir=models_dir) + self.assertGreater(len(models), 0, "Should load at least one model") + + failures = [] + + for model_dict in models: + model_name = model_dict["name"] + model_class = model_dict["class"] + config = model_dict["config"] + + # Get expected ops from config + config_ops = config.get("ops", {}) + expected_forward = set(config_ops.get("forward", [])) + expected_backward = set(config_ops.get("backward", [])) + + # Skip if no ops to check + if not expected_forward and not expected_backward: + continue + + try: + # Initialize model + model_config = config.get("model_config", {}) + init_args = model_config.get("init_args", {}) + + if model_config.get("requires_init_seed"): + torch.manual_seed(42) + + model = model_class(**init_args) + + # Get a test input from model_tests + model_tests = config.get("model_tests", {}) + if not model_tests: + failures.append(f"{model_name}: No model_tests in config") + continue + + # Use first test case + test_name = list(model_tests.keys())[0] + test_args_str = model_tests[test_name] + + # Parse test args (simple eval for now) + # Format: "([], {'x': T([2, 3, 32, 32], f32)})" + test_input = self._create_test_input_from_string(test_args_str) + + # Track forward pass + tracker = OpTracker() + with tracker: + output = model(**test_input) + + forward_ops = tracker.called_ops + + # Check forward ops coverage + missing_forward = expected_forward - forward_ops + if missing_forward: + failures.append( + f"{model_name} [FORWARD]: Missing ops: {sorted(missing_forward)}" + ) + + # Track backward pass + if expected_backward: + # Ensure output requires grad + for param in model.parameters(): + param.requires_grad = True + + # Create loss + if isinstance(output, torch.Tensor): + loss = output.sum() + else: + # Handle tuple/dict outputs + loss = sum(v.sum() for v in output.values() if isinstance(v, torch.Tensor)) + + tracker_backward = OpTracker() + with tracker_backward: + loss.backward() + + backward_ops = tracker_backward.called_ops + + # Check backward ops coverage + missing_backward = expected_backward - backward_ops + if missing_backward: + failures.append( + f"{model_name} [BACKWARD]: Missing ops: {sorted(missing_backward)}" + ) + + except Exception as e: + failures.append(f"{model_name}: Error during test: {e}") + + # Report all failures at once + if failures: + error_msg = "\n\nOperator Coverage Failures:\n" + "\n".join( + f" - {failure}" for failure in failures + ) + self.fail(error_msg) + + def _create_test_input_from_string(self, test_args_str: str) -> Dict[str, torch.Tensor]: + """Parse test input string into actual tensors. + + Format: "([], {'x': T([2, 3, 32, 32], f32)})" + """ + + # Extract tensor specs: T([shape], dtype) + tensor_pattern = r"'(\w+)':\s*T\(\[([\d,\s]+)\],\s*(\w+)\)" + matches = re.findall(tensor_pattern, test_args_str) + + inputs = {} + for name, shape_str, dtype_str in matches: + shape = [int(x.strip()) for x in shape_str.split(",")] + + # Map dtype string to torch dtype + dtype_map = { + "f32": torch.float32, + "f64": torch.float64, + "i32": torch.int32, + "i64": torch.int64, + } + dtype = dtype_map.get(dtype_str, torch.float32) + + inputs[name] = torch.randn(*shape, dtype=dtype) + + return inputs + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 4409ee278d5ad9755a8deefd0e3e0c9a65db52b4 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:27:51 +0000 Subject: [PATCH 4/6] [Model Suite] Add model correctness testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds end to end model correctness testing testing to Model suite by comparing the outputs and gradients (after a backwards pass) with 1 iteration of the model. We also integrate it into CI. ### Testing Running `uv run python BackendBench/scripts/main.py --suite model --backend directory` with a working mm kernel and a watermarked kernel for everything else yeilds ```bash [2025-10-02 07:16:13][INFO][main.py] ============================================================ [2025-10-02 07:16:13][INFO][main.py] MODEL EVALUATION RESULTS [2025-10-02 07:16:13][INFO][main.py] ============================================================ [2025-10-02 07:16:13][INFO][model.py] Model: ToyCoreOpsModel [2025-10-02 07:16:13][INFO][model.py] Status: ✗ Failed (0/3 tests) [2025-10-02 07:16:13][INFO][model.py] ✗ small_batch [2025-10-02 07:16:13][INFO][model.py] Error: Model ToyCoreOpsModel::small_batch failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [2, 3, 32, 32] and num_groups=8 [2025-10-02 07:16:13][INFO][model.py] ✗ medium_batch [2025-10-02 07:16:13][INFO][model.py] Error: Model ToyCoreOpsModel::medium_batch failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [4, 3, 64, 64] and num_groups=8 [2025-10-02 07:16:13][INFO][model.py] ✗ large_input [2025-10-02 07:16:13][INFO][model.py] Error: Model ToyCoreOpsModel::large_input failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [2, 3, 128, 128] and num_groups=8 [2025-10-02 07:16:13][INFO][model.py] Model: SmokeTestModel [2025-10-02 07:16:13][INFO][model.py] Status: ✓ Passed (3/3 tests) [2025-10-02 07:16:13][INFO][model.py] ✓ small_batch [2025-10-02 07:16:13][INFO][model.py] Output match: ✓ Gradients match: ✓ (4 gradients) [2025-10-02 07:16:13][INFO][model.py] ✓ medium_batch [2025-10-02 07:16:13][INFO][model.py] Output match: ✓ Gradients match: ✓ (4 gradients) [2025-10-02 07:16:13][INFO][model.py] ✓ large_batch [2025-10-02 07:16:13][INFO][model.py] Output match: ✓ Gradients match: ✓ (4 gradients) [2025-10-02 07:16:13][INFO][main.py] ============================================================ ``` ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- BackendBench/eval_model.py | 241 +++++++++++++++++++++++++++++++++++ BackendBench/scripts/main.py | 22 +++- BackendBench/suite/model.py | 95 ++++++++++++++ 3 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 BackendBench/eval_model.py diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py new file mode 100644 index 00000000..3cdcc534 --- /dev/null +++ b/BackendBench/eval_model.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Model-level evaluation utilities for testing full model correctness.""" + +import logging +import random +import traceback +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +import BackendBench +from BackendBench.eval import allclose +from BackendBench.utils import deserialize_args + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCorrectnessTestResult: + """Result from testing a model configuration.""" + + model_name: str + test_name: str + is_correct: bool = False + error_msg: str = "" + error_type: str = "" + traceback: str = "" + output_match: bool = False + gradients_match: bool = False + num_gradients: int = 0 + + +def eval_model_correctness_test( + model_name: str, + model_class: type, + model_config: Dict[str, Any], + test_name: str, + test_args: str, + kernel_dir: str = None, + atol: float = 1e-2, + rtol: float = 1e-2, +) -> ModelCorrectnessTestResult: + """Evaluate model correctness by comparing eager vs backend execution. + + Similar to eval_correctness_test in eval.py, but for full models instead of individual ops. + + Args: + model_name: Name of the model being tested + model_class: Model class to instantiate + model_config: Model configuration dict with init_args + test_name: Name of this test configuration + test_args: Serialized arguments string for forward pass + kernel_dir: Optional directory containing kernels for backend + atol: Absolute tolerance for allclose + rtol: Relative tolerance for allclose + + Returns: + ModelCorrectnessTestResult with detailed comparison results + """ + try: + # Generate a single seed to use for both eager and backend runs + # This ensures both runs use the same model initialization + seed = random.randint(0, 2**32 - 1) + + # Run in eager mode (reference) + eager_out, eager_grads = _run_model( + model_class, + model_config, + test_args, + backend_enabled=False, + kernel_dir=None, + seed=seed, + ) + + # Run with backend (implementation) + backend_out, backend_grads = _run_model( + model_class, + model_config, + test_args, + backend_enabled=True, + kernel_dir=kernel_dir, + seed=seed, + ) + + # Compare outputs + output_match = allclose(eager_out, backend_out, atol=atol, rtol=rtol) + + # Compare gradients + gradients_match = True + if len(eager_grads) != len(backend_grads): + gradients_match = False + else: + for eager_grad, backend_grad in zip(eager_grads, backend_grads): + if not allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): + gradients_match = False + break + + is_correct = output_match and gradients_match + + return ModelCorrectnessTestResult( + model_name=model_name, + test_name=test_name, + is_correct=is_correct, + output_match=output_match, + gradients_match=gradients_match, + num_gradients=len(eager_grads), + ) + + except Exception as e: + error_msg = f"Model {model_name}::{test_name} failed: {e}" + logger.error(error_msg) + return ModelCorrectnessTestResult( + model_name=model_name, + test_name=test_name, + is_correct=False, + error_msg=error_msg, + error_type=str(type(e)), + traceback=traceback.format_exc(), + ) + + +def _move_model_to_input_device( + model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any] +) -> torch.nn.Module: + """Move model to the same device as input tensor. + + Args: + model: Model to move + args: Positional arguments list + kwargs: Keyword arguments dict + + Returns: + Model on input device (or original model if no input tensor found) + """ + + # this is specific to our configs atm, we should generalize this + input_tensor = kwargs["x"] + if input_tensor is not None: + device = input_tensor.device + model = model.to(device) + return model + + +def _collect_gradients( + model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any] +) -> List[torch.Tensor]: + """Collect gradients from input and model parameters. + + Args: + model: Model with computed gradients + args: Positional arguments list + kwargs: Keyword arguments dict + + Returns: + List of gradient tensors [input_grad, param1_grad, ...] + """ + grads = [] + + # Input gradient - check both args and kwargs + input_grad = None + if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None: + input_grad = args[0].grad + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None: + input_grad = kwargs["x"].grad + + if input_grad is not None: + grads.append(input_grad.clone()) + + # Parameter gradients + for param in model.parameters(): + if param.grad is not None: + grads.append(param.grad.clone()) + + return grads + + +def _run_model( + model_class: type, + model_config: Dict[str, Any], + test_args: str, + backend_enabled: bool, + kernel_dir: str = "generated_kernels", + seed: int = None, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Run model with or without backend enabled. + + Args: + model_class: Model class to instantiate + model_config: Model configuration dict with init_args + test_args: Serialized arguments string for forward pass + backend_enabled: If True, use BackendBench context manager + kernel_dir: Optional directory containing kernels + seed: Random seed for reproducibility. If None, generates a random seed. + + Returns: + Tuple of (output, gradients) where: + - output: Model output tensor (detached) + - gradients: List of gradient tensors [input_grad, param1_grad, ...] + """ + + # Generate seed dynamically and set for deterministic behavior + # IMPORTANT: Must set seed BEFORE deserializing args, because deserialization + # may create random tensors! + if seed is None: + seed = random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + + # Deserialize test arguments (now uses the seed we just set) + args, kwargs = deserialize_args(test_args) + + # Extract model initialization args + init_args = model_config.get("init_args", {}).copy() + + # Create fresh model instance + model = model_class(**init_args) + model.train() + + # Move model to same device as input + model = _move_model_to_input_device(model, args, kwargs) + ctx = ( + BackendBench.BackendBench.enable(kernel_dir=kernel_dir) + if backend_enabled + else nullcontext() + ) + # Run forward + backward with or without backend + with ctx: + output = model(*args, **kwargs) + loss = output.sum() + loss.backward() + + # Collect gradients + grads = _collect_gradients(model, args, kwargs) + + return output.detach(), grads diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 2240e2d1..39f6c3aa 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -41,6 +41,21 @@ def setup_logging(log_level): ) +# Helper function as model suite gets fleshed out +def _test_full_models(suite, backend): + assert suite.name == "model" + all_results = [] + for model in suite.models: + results = suite.eval_model(model, backend) + all_results.append(results) + logger.info("=" * 60) + logger.info("MODEL EVALUATION RESULTS") + logger.info("=" * 60) + for result in all_results: + suite.print_results(result) + logger.info("=" * 60) + + @click.command() @click.option( "--log-level", @@ -179,8 +194,6 @@ def cli( raise ValueError( "--ops filter is not supported for model suite. Use --model-filter instead" ) - # remove this in later PR as model suite is supported - raise NotImplementedError("Model suite is not supported yet") if suite != "model" and model_filter is not None: raise ValueError("--model-filter is only supported for model suite") @@ -246,6 +259,11 @@ def cli( timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") log_dir = f"backendbench_output_{timestamp}" + if suite.name == "model": + _test_full_models(suite, backend) + # currently model suite does not support op testing so now we're done + return + overall_correctness = [] overall_performance = [] all_correctness_results = [] diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index 12f98256..39cbd094 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -14,6 +14,8 @@ import os from typing import Any, Dict, List, Optional +from BackendBench.eval_model import eval_model_correctness_test + logger = logging.getLogger(__name__) @@ -110,3 +112,96 @@ def __init__( # Store loaded models self.models = models self.name = name + + def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: + """Run evaluation on a single model. + + Args: + model_dict: Dictionary with keys 'name', 'class', 'config' + backend: Backend to use for evaluation + + Returns: + Dictionary with evaluation results including correctness and performance + """ + + model_class = model_dict["class"] + model_name = model_dict["name"] + config = model_dict["config"] + + # Extract model configuration and tests + model_config = config.get("model_config", {}) + model_tests = config.get("model_tests", {}) + + if not model_tests: + return { + "model_name": model_name, + "passed": False, + "error": "No model_tests found in config", + "test_results": [], + } + + # Get kernel_dir from backend if available + kernel_dir = getattr(backend, "ops_dir", None) + + # Run each test + test_results = [] + for test_name, test_args in model_tests.items(): + result = eval_model_correctness_test( + model_name=model_name, + model_class=model_class, + model_config=model_config, + test_name=test_name, + test_args=test_args, + kernel_dir=kernel_dir, + ) + test_results.append(result) + + # Aggregate results + all_passed = all(r.is_correct for r in test_results) + num_passed = sum(1 for r in test_results if r.is_correct) + num_total = len(test_results) + + return { + "model_name": model_name, + "passed": all_passed, + "num_passed": num_passed, + "num_total": num_total, + "test_results": test_results, + } + + def print_results(self, results: Dict[str, Any]) -> None: + """Print model evaluation results. + + Args: + results: Dictionary with evaluation results from eval_model + """ + model_name = results.get("model_name", "Unknown") + passed = results.get("passed", False) + num_passed = results.get("num_passed", 0) + num_total = results.get("num_total", 0) + + logger.info(f"\nModel: {model_name}") + logger.info( + f"Status: {'✓ Passed' if passed else '✗ Failed'} ({num_passed}/{num_total} tests)" + ) + + # Print details for each test + test_results = results.get("test_results", []) + for result in test_results: + status = "✓" if result.is_correct else "✗" + logger.info(f" {status} {result.test_name}") + + if not result.is_correct: + if result.error_msg: + logger.info(f" Error: {result.error_msg}") + else: + # Show what failed + if not result.output_match: + logger.info(" Output mismatch") + if not result.gradients_match: + logger.info(f" Gradient mismatch ({result.num_gradients} gradients)") + else: + # Show success details + logger.info( + f" Output match: ✓ Gradients match: ✓ ({result.num_gradients} gradients)" + ) From dda5832c5bbc45e8920a13dbd33fc300d33b246d Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:27:51 +0000 Subject: [PATCH 5/6] [ModelSuite] Refactor TorchBench for ModelSuite inheritance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR integrates operator benchmarking into the Model Suite by having it inherit from TorchBenchTestSuite. The suite now extracts operator lists from model configs and benchmarks those operators using TorchBench data before running end-to-end model tests. This approach aligns with the core goal of BackendBench: testing operators. The Model Suite is designed with the assumption that for a given set of ops, users can provide kernel implementations, and the suite will benchmark both the individual ops and the full model using those implementations. The long-term vision is to make this process seamless—allowing users to run both operator and model benchmarking with a single command. TorchBench is used here because it provides the strongest guarantee that running the suite benchmarks all operators required for a specific model configuration. Its dataset is easily extensible and includes realistic tensor shapes derived from actual models. The main design drawback is that this integration makes supporting kernel fusions with models more complex. However, it is preferable to handle kernel fusions in a separate suite regardless. ### Testing Running `uv run python BackendBench/scripts/main.py --suite model --backend directory --topn 1` with a working mm kernel and other kernels being watermakred yeilds the expected result (below) ```bash Successfully registered 36 custom operators [2025-10-02 07:21:23][INFO][main.py] ============================================================ [2025-10-02 07:21:23][INFO][main.py] MODEL EVALUATION RESULTS [2025-10-02 07:21:23][INFO][main.py] ============================================================ [2025-10-02 07:21:23][INFO][model.py] Model: ToyCoreOpsModel [2025-10-02 07:21:23][INFO][model.py] Status: ✗ Failed (0/3 tests) [2025-10-02 07:21:23][INFO][model.py] ✗ small_batch [2025-10-02 07:21:23][INFO][model.py] Error: Model ToyCoreOpsModel::small_batch failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [2, 3, 32, 32] and num_groups=8 [2025-10-02 07:21:23][INFO][model.py] ✗ medium_batch [2025-10-02 07:21:23][INFO][model.py] Error: Model ToyCoreOpsModel::medium_batch failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [4, 3, 64, 64] and num_groups=8 [2025-10-02 07:21:23][INFO][model.py] ✗ large_input [2025-10-02 07:21:23][INFO][model.py] Error: Model ToyCoreOpsModel::large_input failed: Expected number of channels in input to be divisible by num_groups, but got input of shape [2, 3, 128, 128] and num_groups=8 [2025-10-02 07:21:23][INFO][model.py] Model: SmokeTestModel [2025-10-02 07:21:23][INFO][model.py] Status: ✓ Passed (3/3 tests) [2025-10-02 07:21:23][INFO][model.py] ✓ small_batch [2025-10-02 07:21:23][INFO][model.py] Output match: ✓ Gradients match: ✓ (4 gradients) [2025-10-02 07:21:23][INFO][model.py] ✓ medium_batch [2025-10-02 07:21:23][INFO][model.py] Output match: ✓ Gradients match: ✓ (4 gradients) [2025-10-02 07:21:23][INFO][model.py] ✓ large_batch [2025-10-02 07:21:23][INFO][model.py] Output match: ✓ Gradients match: ✓ (4 gradients) [2025-10-02 07:21:23][INFO][main.py] ============================================================ [2025-10-02 07:21:23][INFO][output.py] Full results saved to generated_kernels/full_results.json [2025-10-02 07:21:23][INFO][output.py] Operator summary CSV saved to generated_kernels/operator_summary.csv [2025-10-02 07:21:23][INFO][output.py] Failed operations log saved to generated_kernels/failed_tests.json [2025-10-02 07:21:23][INFO][output.py] Overall summary saved to generated_kernels/OVERALL_SUMMARY.md [2025-10-02 07:21:23][INFO][output.py] Results saved to directory: /home/dev/sapling_repos/BackendBench/generated_kernels Results saved to directory: /home/dev/sapling_repos/BackendBench/generated_kernels Overall summary saved to: /home/dev/sapling_repos/BackendBench/generated_kernels/OVERALL_SUMMARY.md ``` ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- BackendBench/scripts/main.py | 16 ++++++------- BackendBench/suite/model.py | 41 +++++++++++++++++++++++++++----- BackendBench/suite/torchbench.py | 8 +++++++ 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 39f6c3aa..19b7c561 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -184,8 +184,6 @@ def cli( p, ): if suite != "torchbench": - if topn_inputs is not None: - raise ValueError("topn-inputs is only supported for torchbench suite") if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") @@ -198,6 +196,10 @@ def cli( if suite != "model" and model_filter is not None: raise ValueError("--model-filter is only supported for model suite") + if suite != "model" and suite != "torchbench": + if topn_inputs is not None: + raise ValueError("topn-inputs is only supported for torchbench suite") + setup_logging(log_level) if ops: ops = ops.split(",") @@ -225,7 +227,7 @@ def cli( torch.bfloat16, filter=ops, ), - "model": lambda: ModelSuite(filter=model_filter), + "model": lambda: ModelSuite(filter=model_filter, topn=topn_inputs), }[suite]() backend_name = backend @@ -259,11 +261,6 @@ def cli( timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") log_dir = f"backendbench_output_{timestamp}" - if suite.name == "model": - _test_full_models(suite, backend) - # currently model suite does not support op testing so now we're done - return - overall_correctness = [] overall_performance = [] all_correctness_results = [] @@ -332,6 +329,9 @@ def cli( f"perf@p score (rate of correct samples with a speedup greater than p, p={p}): {perf_at_p_score:.2f}" ) + if suite.name == "model": + _test_full_models(suite, backend) + command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) # Save results if not disabled diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index 39cbd094..7bc40511 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -Model Suite for testing models defined in configs. +Model Suite for testing operators defined in toy model configs. + +This suite extends TorchBenchTestSuite by reading operator lists from +model configs, validating they exist in the TorchBench dataset, then +filtering to include only those operators. """ import importlib.util @@ -16,6 +20,8 @@ from BackendBench.eval_model import eval_model_correctness_test +from .torchbench import TorchBenchTestSuite + logger = logging.getLogger(__name__) @@ -89,29 +95,52 @@ def load_models( return models -class ModelSuite: - """Model Suite for end-to-end model testing.""" +class ModelSuite(TorchBenchTestSuite): + """Model Suite that filters TorchBench operators based on model configs. + + This suite reads operator lists from model configs, validates they exist + in the TorchBench dataset, then creates a filtered suite containing only + those operators. + """ def __init__( self, name: str = "model", filter: Optional[List[str]] = None, + topn: Optional[int] = None, ): """Initialize ModelSuite. Args: name: Suite name (default: "model") filter: Optional list of model names to load + topn: Optional limit on number of tests per operator """ models_dir = os.path.join(os.path.dirname(__file__), "models") # Load models models = load_models(models_dir=models_dir, filter=filter) logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") - - # Store loaded models + model_ops = self.get_model_ops(models) + filter = list(model_ops) + # Store loaded models for evaluation self.models = models - self.name = name + + self._initialize_torchbench_suite(name, None, filter, topn, False) + + def get_model_ops(self, models: List[Dict[str, Any]]) -> List[str]: + # Extract operators from model configs + model_ops = set() + for model in models: + config_ops = model["config"]["ops"] + ops_list = config_ops["forward"] + ops_list.extend(config_ops["backward"]) + + model_ops.update(ops_list) + logger.info(f"Model {model['name']}: {len(ops_list)} operators defined in config") + + logger.info(f"ModelSuite: Total {len(model_ops)} unique operators across all models") + return model_ops def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: """Run evaluation on a single model. diff --git a/BackendBench/suite/torchbench.py b/BackendBench/suite/torchbench.py index 2ee3d698..b116d6b8 100644 --- a/BackendBench/suite/torchbench.py +++ b/BackendBench/suite/torchbench.py @@ -78,6 +78,13 @@ def __init__( filter=None, topn=None, check_overhead_dominated_ops=False, + ): + self._initialize_torchbench_suite( + name, filename, filter, topn, check_overhead_dominated_ops + ) + + def _initialize_torchbench_suite( + self, name, filename, filter, topn, check_overhead_dominated_ops ): self.name = name self.topn = topn @@ -87,6 +94,7 @@ def __init__( format="auto", # Auto-detect based on file extension filter=filter, ) + if check_overhead_dominated_ops: # Only include ops which are overhead dominated (this is useful as a performance canary) ops_list = [op for op in ops_list if op.get("is_overhead_dominated_op", False)] From 8598f5b5e398b69c447bad5600a2fd5a9ce62068 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 10:43:38 +0000 Subject: [PATCH 6/6] [WIP] [ModelSuite] Add Performace Testing Summary: Test Plan: --- BackendBench/eval_model.py | 137 +++++++++++++++++++++++++++++++++++ BackendBench/scripts/main.py | 24 +++++- BackendBench/suite/model.py | 96 ++++++++++++++++++------ 3 files changed, 232 insertions(+), 25 deletions(-) diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index 3cdcc534..81a63653 100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -37,6 +37,20 @@ class ModelCorrectnessTestResult: num_gradients: int = 0 +@dataclass +class ModelPerformanceTestResult: + """Result from benchmarking a model configuration.""" + + model_name: str + test_name: str + speedup: float + eager_time_ms: float + backend_time_ms: float + error_msg: str = "" + successfully_ran: bool = False + test_type: str = "performance" + + def eval_model_correctness_test( model_name: str, model_class: type, @@ -239,3 +253,126 @@ def _run_model( grads = _collect_gradients(model, args, kwargs) return output.detach(), grads + + +def _get_bench_function(): + """Get appropriate benchmarking function based on hardware availability.""" + try: + if torch.cuda.is_available(): + import triton.testing + + return triton.testing.do_bench + except ImportError: + pass + + # Fall back to CPU benchmarking + from BackendBench.eval import cpu_bench + + return cpu_bench + + +def eval_model_performance_test( + model_name: str, + model_class: type, + model_config: Dict[str, Any], + test_name: str, + test_args: str, + kernel_dir: str = None, + atol: float = 1e-2, + rtol: float = 1e-2, +) -> ModelPerformanceTestResult: + """Benchmark model performance comparing eager vs backend execution. + + Similar to eval_performance in eval.py, but for full models instead of individual ops. + + Args: + model_name: Name of the model being tested + model_class: Model class to instantiate + model_config: Model configuration dict with init_args + test_name: Name of this test configuration + test_args: Serialized arguments string for forward pass + kernel_dir: Optional directory containing kernels for backend + atol: Absolute tolerance for allclose + rtol: Relative tolerance for allclose + + Returns: + ModelPerformanceTestResult with timing and speedup information + """ + try: + # 1. Choose benchmarking function (CUDA vs CPU) + bench_fn = _get_bench_function() + + # 2. Generate seed for reproducibility + seed = random.randint(0, 2**32 - 1) + + # 3. First verify correctness (don't benchmark incorrect implementations!) + correctness_result = eval_model_correctness_test( + model_name=model_name, + model_class=model_class, + model_config=model_config, + test_name=test_name, + test_args=test_args, + kernel_dir=kernel_dir, + atol=atol, + rtol=rtol, + ) + + if not correctness_result.is_correct: + return ModelPerformanceTestResult( + model_name=model_name, + test_name=test_name, + speedup=0.0, + eager_time_ms=0.0, + backend_time_ms=0.0, + successfully_ran=False, + error_msg=f"Correctness check failed: {correctness_result.error_msg}", + ) + + # 4. Benchmark eager mode + eager_time = bench_fn( + lambda: _run_model( + model_class, + model_config, + test_args, + backend_enabled=False, + kernel_dir=None, + seed=seed, + ) + ) + + # 5. Benchmark backend mode + backend_time = bench_fn( + lambda: _run_model( + model_class, + model_config, + test_args, + backend_enabled=True, + kernel_dir=kernel_dir, + seed=seed, + ) + ) + + # 6. Calculate speedup (eager_time / backend_time) + speedup = eager_time / backend_time if backend_time > 0 else 0.0 + + return ModelPerformanceTestResult( + model_name=model_name, + test_name=test_name, + speedup=speedup, + eager_time_ms=eager_time * 1000, # Convert to ms + backend_time_ms=backend_time * 1000, + successfully_ran=True, + ) + + except Exception as e: + error_msg = f"Model {model_name}::{test_name} benchmark failed: {e}" + logger.error(error_msg) + return ModelPerformanceTestResult( + model_name=model_name, + test_name=test_name, + speedup=0.0, + eager_time_ms=0.0, + backend_time_ms=0.0, + successfully_ran=False, + error_msg=error_msg, + ) diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 19b7c561..4994149a 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -43,11 +43,13 @@ def setup_logging(log_level): # Helper function as model suite gets fleshed out def _test_full_models(suite, backend): + """Test full model evaluation including both correctness and performance.""" assert suite.name == "model" all_results = [] for model in suite.models: results = suite.eval_model(model, backend) all_results.append(results) + logger.info("=" * 60) logger.info("MODEL EVALUATION RESULTS") logger.info("=" * 60) @@ -55,6 +57,20 @@ def _test_full_models(suite, backend): suite.print_results(result) logger.info("=" * 60) + # Calculate overall model suite performance + all_speedups = [] + for result in all_results: + for perf in result.get("performance_results", []): + if perf.successfully_ran: + all_speedups.append(perf.speedup) + + if all_speedups: + overall_speedup = torch.tensor(all_speedups).log().mean().exp().item() + logger.info(f"Overall Model Suite Performance: {overall_speedup:.2f}x geomean speedup") + logger.info("=" * 60) + + return all_results + @click.command() @click.option( @@ -330,7 +346,13 @@ def cli( ) if suite.name == "model": - _test_full_models(suite, backend) + model_results = _test_full_models(suite, backend) + # Extract performance results for saving + for result in model_results: + for perf in result.get("performance_results", []): + all_performance_results.append(perf) + for corr in result.get("correctness_results", []): + all_correctness_results.append(corr) command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index 7bc40511..4579d221 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -18,7 +18,12 @@ import os from typing import Any, Dict, List, Optional -from BackendBench.eval_model import eval_model_correctness_test +import torch + +from BackendBench.eval_model import ( + eval_model_correctness_test, + eval_model_performance_test, +) from .torchbench import TorchBenchTestSuite @@ -166,16 +171,20 @@ def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: "model_name": model_name, "passed": False, "error": "No model_tests found in config", - "test_results": [], + "correctness_results": [], + "performance_results": [], } # Get kernel_dir from backend if available kernel_dir = getattr(backend, "ops_dir", None) - # Run each test - test_results = [] + # Run both correctness and performance tests + correctness_results = [] + performance_results = [] + for test_name, test_args in model_tests.items(): - result = eval_model_correctness_test( + # Correctness test + corr_result = eval_model_correctness_test( model_name=model_name, model_class=model_class, model_config=model_config, @@ -183,19 +192,36 @@ def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: test_args=test_args, kernel_dir=kernel_dir, ) - test_results.append(result) + correctness_results.append(corr_result) + + # Performance test + perf_result = eval_model_performance_test( + model_name=model_name, + model_class=model_class, + model_config=model_config, + test_name=test_name, + test_args=test_args, + kernel_dir=kernel_dir, + ) + performance_results.append(perf_result) # Aggregate results - all_passed = all(r.is_correct for r in test_results) - num_passed = sum(1 for r in test_results if r.is_correct) - num_total = len(test_results) + all_passed = all(r.is_correct for r in correctness_results) + num_passed = sum(1 for r in correctness_results if r.is_correct) + num_total = len(correctness_results) + + # Calculate average speedup (geometric mean like operators) + speedups = [r.speedup for r in performance_results if r.successfully_ran] + avg_speedup = torch.tensor(speedups).log().mean().exp().item() if speedups else 0.0 return { "model_name": model_name, "passed": all_passed, "num_passed": num_passed, "num_total": num_total, - "test_results": test_results, + "correctness_results": correctness_results, + "performance_results": performance_results, + "avg_speedup": avg_speedup, } def print_results(self, results: Dict[str, Any]) -> None: @@ -208,29 +234,51 @@ def print_results(self, results: Dict[str, Any]) -> None: passed = results.get("passed", False) num_passed = results.get("num_passed", 0) num_total = results.get("num_total", 0) + avg_speedup = results.get("avg_speedup", 0.0) logger.info(f"\nModel: {model_name}") logger.info( f"Status: {'✓ Passed' if passed else '✗ Failed'} ({num_passed}/{num_total} tests)" ) + logger.info(f"Performance: {avg_speedup:.2f}x average speedup") # Print details for each test - test_results = results.get("test_results", []) - for result in test_results: - status = "✓" if result.is_correct else "✗" - logger.info(f" {status} {result.test_name}") - - if not result.is_correct: - if result.error_msg: - logger.info(f" Error: {result.error_msg}") + correctness_results = results.get("correctness_results", []) + performance_results = results.get("performance_results", []) + + # Handle backward compatibility + if not correctness_results: + correctness_results = results.get("test_results", []) + + for corr_result in correctness_results: + # Find corresponding performance result + perf_result = None + for p in performance_results: + if p.test_name == corr_result.test_name: + perf_result = p + break + + status = "✓" if corr_result.is_correct else "✗" + speedup_str = "" + if perf_result and perf_result.successfully_ran: + speedup_str = f" ({perf_result.speedup:.2f}x speedup)" + + logger.info(f" {status} {corr_result.test_name}{speedup_str}") + + if not corr_result.is_correct: + if corr_result.error_msg: + logger.info(f" Error: {corr_result.error_msg}") else: # Show what failed - if not result.output_match: + if not corr_result.output_match: logger.info(" Output mismatch") - if not result.gradients_match: - logger.info(f" Gradient mismatch ({result.num_gradients} gradients)") + if not corr_result.gradients_match: + logger.info( + f" Gradient mismatch ({corr_result.num_gradients} gradients)" + ) else: # Show success details - logger.info( - f" Output match: ✓ Gradients match: ✓ ({result.num_gradients} gradients)" - ) + details = f" Output match: ✓ Gradients match: ✓ ({corr_result.num_gradients} gradients)" + if perf_result and perf_result.successfully_ran: + details += f" Time: eager={perf_result.eager_time_ms:.2f}ms, backend={perf_result.backend_time_ms:.2f}ms" + logger.info(details)