diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py new file mode 100644 index 00000000..3cdcc534 --- /dev/null +++ b/BackendBench/eval_model.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Model-level evaluation utilities for testing full model correctness.""" + +import logging +import random +import traceback +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +import BackendBench +from BackendBench.eval import allclose +from BackendBench.utils import deserialize_args + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCorrectnessTestResult: + """Result from testing a model configuration.""" + + model_name: str + test_name: str + is_correct: bool = False + error_msg: str = "" + error_type: str = "" + traceback: str = "" + output_match: bool = False + gradients_match: bool = False + num_gradients: int = 0 + + +def eval_model_correctness_test( + model_name: str, + model_class: type, + model_config: Dict[str, Any], + test_name: str, + test_args: str, + kernel_dir: str = None, + atol: float = 1e-2, + rtol: float = 1e-2, +) -> ModelCorrectnessTestResult: + """Evaluate model correctness by comparing eager vs backend execution. + + Similar to eval_correctness_test in eval.py, but for full models instead of individual ops. + + Args: + model_name: Name of the model being tested + model_class: Model class to instantiate + model_config: Model configuration dict with init_args + test_name: Name of this test configuration + test_args: Serialized arguments string for forward pass + kernel_dir: Optional directory containing kernels for backend + atol: Absolute tolerance for allclose + rtol: Relative tolerance for allclose + + Returns: + ModelCorrectnessTestResult with detailed comparison results + """ + try: + # Generate a single seed to use for both eager and backend runs + # This ensures both runs use the same model initialization + seed = random.randint(0, 2**32 - 1) + + # Run in eager mode (reference) + eager_out, eager_grads = _run_model( + model_class, + model_config, + test_args, + backend_enabled=False, + kernel_dir=None, + seed=seed, + ) + + # Run with backend (implementation) + backend_out, backend_grads = _run_model( + model_class, + model_config, + test_args, + backend_enabled=True, + kernel_dir=kernel_dir, + seed=seed, + ) + + # Compare outputs + output_match = allclose(eager_out, backend_out, atol=atol, rtol=rtol) + + # Compare gradients + gradients_match = True + if len(eager_grads) != len(backend_grads): + gradients_match = False + else: + for eager_grad, backend_grad in zip(eager_grads, backend_grads): + if not allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): + gradients_match = False + break + + is_correct = output_match and gradients_match + + return ModelCorrectnessTestResult( + model_name=model_name, + test_name=test_name, + is_correct=is_correct, + output_match=output_match, + gradients_match=gradients_match, + num_gradients=len(eager_grads), + ) + + except Exception as e: + error_msg = f"Model {model_name}::{test_name} failed: {e}" + logger.error(error_msg) + return ModelCorrectnessTestResult( + model_name=model_name, + test_name=test_name, + is_correct=False, + error_msg=error_msg, + error_type=str(type(e)), + traceback=traceback.format_exc(), + ) + + +def 
_move_model_to_input_device(
+    model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
+) -> torch.nn.Module:
+    """Move model to the same device as the input tensor.
+
+    Args:
+        model: Model to move
+        args: Positional arguments list
+        kwargs: Keyword arguments dict
+
+    Returns:
+        Model on the input device (or the original model if no input tensor is found)
+    """
+
+    # This is specific to our current configs (the input is passed as kwargs["x"]);
+    # we should generalize this.
+    input_tensor = kwargs.get("x")
+    if input_tensor is not None:
+        device = input_tensor.device
+        model = model.to(device)
+    return model
+
+
+def _collect_gradients(
+    model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
+) -> List[torch.Tensor]:
+    """Collect gradients from input and model parameters.
+
+    Args:
+        model: Model with computed gradients
+        args: Positional arguments list
+        kwargs: Keyword arguments dict
+
+    Returns:
+        List of gradient tensors [input_grad, param1_grad, ...]
+    """
+    grads = []
+
+    # Input gradient - check both args and kwargs
+    input_grad = None
+    if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None:
+        input_grad = args[0].grad
+    elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None:
+        input_grad = kwargs["x"].grad
+
+    if input_grad is not None:
+        grads.append(input_grad.clone())
+
+    # Parameter gradients
+    for param in model.parameters():
+        if param.grad is not None:
+            grads.append(param.grad.clone())
+
+    return grads
+
+
+def _run_model(
+    model_class: type,
+    model_config: Dict[str, Any],
+    test_args: str,
+    backend_enabled: bool,
+    kernel_dir: str = "generated_kernels",
+    seed: int = None,
+) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    """Run model with or without backend enabled.
+
+    Args:
+        model_class: Model class to instantiate
+        model_config: Model configuration dict with init_args
+        test_args: Serialized arguments string for forward pass
+        backend_enabled: If True, use BackendBench context manager
+        kernel_dir: Optional directory containing kernels
+        seed: Random seed for reproducibility. If None, generates a random seed.
+
+    Returns:
+        Tuple of (output, gradients) where:
+        - output: Model output tensor (detached)
+        - gradients: List of gradient tensors [input_grad, param1_grad, ...]
+    """
+
+    # Set the seed for deterministic behavior (generate one if not provided).
+    # IMPORTANT: Must set seed BEFORE deserializing args, because deserialization
+    # may create random tensors!
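+    # Callers like eval_model_correctness_test pass an explicit seed so that the
+    # eager and backend runs share identical model initialization and inputs; the
+    # fallback below is only for standalone use of this helper.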
+ if seed is None: + seed = random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + + # Deserialize test arguments (now uses the seed we just set) + args, kwargs = deserialize_args(test_args) + + # Extract model initialization args + init_args = model_config.get("init_args", {}).copy() + + # Create fresh model instance + model = model_class(**init_args) + model.train() + + # Move model to same device as input + model = _move_model_to_input_device(model, args, kwargs) + ctx = ( + BackendBench.BackendBench.enable(kernel_dir=kernel_dir) + if backend_enabled + else nullcontext() + ) + # Run forward + backward with or without backend + with ctx: + output = model(*args, **kwargs) + loss = output.sum() + loss.backward() + + # Collect gradients + grads = _collect_gradients(model, args, kwargs) + + return output.detach(), grads diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 479e5805..39f6c3aa 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -19,6 +19,7 @@ from BackendBench.output import save_results from BackendBench.suite import ( FactoTestSuite, + ModelSuite, OpInfoTestSuite, SmokeTestSuite, TorchBenchTestSuite, @@ -40,6 +41,21 @@ def setup_logging(log_level): ) +# Helper function as model suite gets fleshed out +def _test_full_models(suite, backend): + assert suite.name == "model" + all_results = [] + for model in suite.models: + results = suite.eval_model(model, backend) + all_results.append(results) + logger.info("=" * 60) + logger.info("MODEL EVALUATION RESULTS") + logger.info("=" * 60) + for result in all_results: + suite.print_results(result) + logger.info("=" * 60) + + @click.command() @click.option( "--log-level", @@ -50,7 +66,7 @@ def setup_logging(log_level): @click.option( "--suite", default="smoke", - type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]), + type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]), help="Which suite to run", ) @click.option( @@ -63,7 +79,13 @@ def setup_logging(log_level): "--ops", default=None, type=str, - help="Comma-separated list of ops to run", + help="Comma-separated list of ops to run (not supported for model suite)", +) +@click.option( + "--model-filter", + default=None, + type=str, + help="Comma-separated list of models to run (only for model suite)", ) @click.option( "--topn-inputs", @@ -147,6 +169,7 @@ def cli( suite, backend, ops, + model_filter, topn_inputs, llm_attempts, llm_model, @@ -166,9 +189,20 @@ def cli( if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") + if suite == "model": + if ops is not None: + raise ValueError( + "--ops filter is not supported for model suite. 
Use --model-filter instead" + ) + + if suite != "model" and model_filter is not None: + raise ValueError("--model-filter is only supported for model suite") + setup_logging(log_level) if ops: ops = ops.split(",") + if model_filter: + model_filter = model_filter.split(",") suite = { "smoke": lambda: SmokeTestSuite, @@ -191,6 +225,7 @@ def cli( torch.bfloat16, filter=ops, ), + "model": lambda: ModelSuite(filter=model_filter), }[suite]() backend_name = backend @@ -224,6 +259,11 @@ def cli( timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") log_dir = f"backendbench_output_{timestamp}" + if suite.name == "model": + _test_full_models(suite, backend) + # currently model suite does not support op testing so now we're done + return + overall_correctness = [] overall_performance = [] all_correctness_results = [] diff --git a/BackendBench/suite/__init__.py b/BackendBench/suite/__init__.py index 410a5d6e..e9c332a9 100644 --- a/BackendBench/suite/__init__.py +++ b/BackendBench/suite/__init__.py @@ -15,6 +15,7 @@ from .base import OpTest, Test, TestSuite from .facto import FactoTestSuite +from .model import ModelSuite from .opinfo import OpInfoTestSuite from .smoke import randn, SmokeTestSuite from .torchbench import TorchBenchOpTest, TorchBenchTestSuite @@ -24,6 +25,7 @@ "OpTest", "TestSuite", "FactoTestSuite", + "ModelSuite", "OpInfoTestSuite", "SmokeTestSuite", "randn", diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py new file mode 100644 index 00000000..39cbd094 --- /dev/null +++ b/BackendBench/suite/model.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Model Suite for testing models defined in configs. +""" + +import importlib.util +import json +import logging +import os +from typing import Any, Dict, List, Optional + +from BackendBench.eval_model import eval_model_correctness_test + +logger = logging.getLogger(__name__) + + +def load_models( + models_dir: str = "models", filter: Optional[List[str]] = None +) -> List[Dict[str, Any]]: + """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json + + Args: + models_dir: Directory containing models (default: "models") + filter: Optional list of model names to load. If None, loads all models. 
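+            Names must match the model directory names exactly (case-sensitive).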
+ + Returns: + List of dictionaries with keys: + - name: Model name (str) + - class: Model class (type) + - config: Configuration dictionary from JSON file + """ + models = [] + + if not os.path.exists(models_dir): + raise FileNotFoundError(f"Models directory not found: {models_dir}") + + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + # Skip if not in filter + if filter is not None and model_name not in filter: + continue + + # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json + model_file = os.path.join(model_dir, f"{model_name}.py") + config_file = os.path.join(model_dir, f"{model_name}.json") + + # Check both files exist + if not os.path.exists(model_file): + raise FileNotFoundError(f"Model file not found: {model_file}") + + if not os.path.exists(config_file): + raise FileNotFoundError(f"Config file not found: {config_file}") + + try: + # Load config + with open(config_file, "r") as f: + config = json.load(f) + + # Load model class dynamically + spec = importlib.util.spec_from_file_location(model_name, model_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find model class (must match model_name exactly) + if not hasattr(module, model_name): + raise RuntimeError(f"Model class '{model_name}' not found in {model_file}") + + model_class = getattr(module, model_name) + if not (isinstance(model_class, type) and hasattr(model_class, "forward")): + raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class") + + models.append({"name": model_name, "class": model_class, "config": config}) + logger.info(f"Loaded model: {model_name}") + + except Exception as e: + raise RuntimeError(f"Failed to load model {model_name}: {e}") + + if filter is not None and len(models) == 0: + raise ValueError(f"No models found matching filter: {filter}") + + return models + + +class ModelSuite: + """Model Suite for end-to-end model testing.""" + + def __init__( + self, + name: str = "model", + filter: Optional[List[str]] = None, + ): + """Initialize ModelSuite. + + Args: + name: Suite name (default: "model") + filter: Optional list of model names to load + """ + models_dir = os.path.join(os.path.dirname(__file__), "models") + + # Load models + models = load_models(models_dir=models_dir, filter=filter) + logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") + + # Store loaded models + self.models = models + self.name = name + + def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: + """Run evaluation on a single model. 
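+
+        Each test in the model's "model_tests" config is evaluated with
+        eval_model_correctness_test, which compares outputs and gradients
+        against eager execution; the model passes only if every test is correct.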
+ + Args: + model_dict: Dictionary with keys 'name', 'class', 'config' + backend: Backend to use for evaluation + + Returns: + Dictionary with evaluation results including correctness and performance + """ + + model_class = model_dict["class"] + model_name = model_dict["name"] + config = model_dict["config"] + + # Extract model configuration and tests + model_config = config.get("model_config", {}) + model_tests = config.get("model_tests", {}) + + if not model_tests: + return { + "model_name": model_name, + "passed": False, + "error": "No model_tests found in config", + "test_results": [], + } + + # Get kernel_dir from backend if available + kernel_dir = getattr(backend, "ops_dir", None) + + # Run each test + test_results = [] + for test_name, test_args in model_tests.items(): + result = eval_model_correctness_test( + model_name=model_name, + model_class=model_class, + model_config=model_config, + test_name=test_name, + test_args=test_args, + kernel_dir=kernel_dir, + ) + test_results.append(result) + + # Aggregate results + all_passed = all(r.is_correct for r in test_results) + num_passed = sum(1 for r in test_results if r.is_correct) + num_total = len(test_results) + + return { + "model_name": model_name, + "passed": all_passed, + "num_passed": num_passed, + "num_total": num_total, + "test_results": test_results, + } + + def print_results(self, results: Dict[str, Any]) -> None: + """Print model evaluation results. + + Args: + results: Dictionary with evaluation results from eval_model + """ + model_name = results.get("model_name", "Unknown") + passed = results.get("passed", False) + num_passed = results.get("num_passed", 0) + num_total = results.get("num_total", 0) + + logger.info(f"\nModel: {model_name}") + logger.info( + f"Status: {'✓ Passed' if passed else '✗ Failed'} ({num_passed}/{num_total} tests)" + ) + + # Print details for each test + test_results = results.get("test_results", []) + for result in test_results: + status = "✓" if result.is_correct else "✗" + logger.info(f" {status} {result.test_name}") + + if not result.is_correct: + if result.error_msg: + logger.info(f" Error: {result.error_msg}") + else: + # Show what failed + if not result.output_match: + logger.info(" Output mismatch") + if not result.gradients_match: + logger.info(f" Gradient mismatch ({result.num_gradients} gradients)") + else: + # Show success details + logger.info( + f" Output match: ✓ Gradients match: ✓ ({result.num_gradients} gradients)" + ) diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md new file mode 100644 index 00000000..57e707dc --- /dev/null +++ b/BackendBench/suite/models/README.md @@ -0,0 +1,80 @@ +# Adding Models to BackendBench + +## Quick Start + +Models define operator lists and validate that custom backends work correctly in full model execution. Two files required: + +``` +BackendBench/suite/models/YourModel/ +├── YourModel.py # nn.Module class +└── YourModel.json # Configuration +``` + +**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive) + +## Adding a Model + +### 1. Create Directory and Files + +```bash +cd BackendBench/suite/models +mkdir MyModel +cd MyModel +touch MyModel.py MyModel.json +``` + +### 2. Write Model Class (`MyModel.py`) + +**Requirements:** +- Class name = filename (exact match) +- All `__init__` params need defaults +- Add a main() / runner if you are inclined for sanity checking +- Keep it simple - focus on specific operators you're testing +- Look in this directory for examples + +### 3. 
Write Config (`MyModel.json`)
+
+**Key Fields:**
+- `model_config.init_args` - Args for `__init__()`, must match your defaults
+- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten.<op>.default"`, e.g. `"aten.mm.default"`)
+- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"`. The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench)
+  - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc.
+- `metadata.description` - What this model tests
+- Look in this directory for examples
+
+**Finding operator names:**
+```python
+from torch.profiler import profile, ProfilerActivity
+
+with profile(activities=[ProfilerActivity.CPU]) as prof:
+    output = model(x)
+    loss = output.sum()
+    loss.backward()
+
+for event in prof.key_averages():
+    if "aten::" in event.key:
+        print(event.key)
+```
+
+### 4. Test Your Model
+
+```bash
+# Test standalone
+cd BackendBench/suite/models/MyModel
+python MyModel.py # Add main() for standalone testing
+
+# Test with suite
+python -m BackendBench.scripts.main \
+    --suite model \
+    --backend aten \
+    --model-filter MyModel
+
+# Expected output:
+# Model: MyModel
+# Status: ✓ Passed (2/2 tests)
+# ✓ small
+# ✓ large
+```
+
+### 5. Validation
+`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` validate that all models are loadable and correctly formatted.
diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json
new file mode 100644
index 00000000..b7d286ae
--- /dev/null
+++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json
@@ -0,0 +1,25 @@
+{
+    "model_config": {
+        "init_args": {
+            "input_dim": 128,
+            "hidden_dim": 128,
+            "output_dim": 128
+        }
+    },
+    "ops": {
+        "forward": [
+            "aten.mm.default"
+        ],
+        "backward": [
+            "aten.mm.default"
+        ]
+    },
+    "model_tests": {
+        "small_batch": "([], {'x': T([2, 128], f32)})",
+        "medium_batch": "([], {'x': T([16, 128], f32)})",
+        "large_batch": "([], {'x': T([32, 128], f32)})"
+    },
+    "metadata": {
+        "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes"
+    }
+}
diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py
new file mode 100644
index 00000000..3bf627e4
--- /dev/null
+++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Simple model that tests matrix multiplication operations using explicit
+torch.mm calls.
+"""
+
+import torch
+import torch.nn as nn
+
+
+class SmokeTestModel(nn.Module):
+    """
+    Model that uses explicit torch.mm operations to test aten.mm.default
+    in forward/backward.
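+
+    Computes relu(x @ weight1 + bias1) @ weight2 + bias2. torch.mm is used
+    explicitly so that both the forward and backward passes dispatch to
+    aten.mm.default.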
+ """ + + def __init__( + self, + input_dim: int = 128, + hidden_dim: int = 128, + output_dim: int = 128, + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim)) + self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim)) + self.bias1 = nn.Parameter(torch.randn(hidden_dim)) + self.bias2 = nn.Parameter(torch.randn(output_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2) + """ + x = torch.mm(x, self.weight1) + self.bias1 + x = torch.relu(x) + x = torch.mm(x, self.weight2) + self.bias2 + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128) + batch_size = 4 + input_tensor = torch.randn(batch_size, 128, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + + +if __name__ == "__main__": + main() diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json new file mode 100644 index 00000000..1586273e --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -0,0 +1,34 @@ +{ + "model_config": { + "init_args": { + "in_channels": 3, + "hidden_channels": 32, + "out_channels": 8, + "num_groups": 8 + } + }, + "ops": { + "forward": [ + "aten.convolution.default", + "aten.native_group_norm.default", + "aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], + "backward": [ + "aten.convolution_backward.default", + "aten.native_group_norm_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d_backward.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", + "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", + "large_input": "([], {'x': T([2, 3, 128, 128], f32)})" + }, + "metadata": { + "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" + } +} \ No newline at end of file diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py new file mode 100644 index 00000000..410e4c4f --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +CNN model that triggers core PyTorch backward operators: +- convolution_backward +- native_group_norm_backward +- max_pool2d_with_indices_backward +- avg_pool2d_backward +- _adaptive_avg_pool2d_backward +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ToyCoreOpsModel(nn.Module): + """CNN that uses conv, group norm, max pool, avg pool, and adaptive avg pool.""" + + def __init__( + self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + ): + super().__init__() + + if hidden_channels % num_groups != 0: + raise ValueError( + f"hidden_channels ({hidden_channels}) must be divisible by " + f"num_groups ({num_groups})" + ) + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_groups = num_groups + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm1 = nn.GroupNorm(num_groups, hidden_channels) + self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm2 = nn.GroupNorm(num_groups, hidden_channels) + self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through: Conv->GroupNorm->ReLU->MaxPool->Conv-> + GroupNorm->ReLU->AvgPool->AdaptiveAvgPool->Conv + Output is always (batch, out_channels, 4, 4) regardless of + input size. + """ + x = F.relu(self.group_norm1(self.conv1(x))) + x, _ = F.max_pool2d(x, kernel_size=2, return_indices=True) + x = F.relu(self.group_norm2(self.conv2(x))) + x = F.avg_pool2d(x, kernel_size=2) + x = F.adaptive_avg_pool2d(x, output_size=(4, 4)) + x = self.conv_out(x) + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = ToyCoreOpsModel(in_channels=3, hidden_channels=32, out_channels=8, num_groups=8) + batch_size = 2 + input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + return model + + +if __name__ == "__main__": + main() diff --git a/test/test_model_ops_configs.py b/test/test_model_ops_configs.py new file mode 100644 index 00000000..8b3f3dbe --- /dev/null +++ b/test/test_model_ops_configs.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that ModelSuite's operator filter correctly matches +the operators defined in model configs. + +This test validates that: +1. load_models correctly loads model configs from the models directory +2. load_model_ops extracts the correct set of operators from model configs +3. TorchBenchTestSuite initialized with those operators has matching optests +4. 
JSON config files have proper format with required fields +""" + +import json +import os +import unittest +from typing import Any, Dict, List, Set + +from BackendBench.suite.model import load_models +from BackendBench.suite.torchbench import TorchBenchTestSuite + + +def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: + """Extract unique set of operators from model configs. + + Args: + models: List of model dictionaries with 'name', 'class', and 'config' keys + + Returns: + Set of operator names defined across all model configs + """ + model_ops = set() + for model in models: + config_ops = model["config"].get("ops") + if not config_ops: + raise ValueError(f"Model {model['name']} has no 'ops' field in config") + assert "forward" in config_ops, f"Model {model['name']} has no 'forward' field in config" + assert "backward" in config_ops, f"Model {model['name']} has no 'backward' field in config" + ops_list = config_ops["forward"] + config_ops["backward"] + + model_ops.update(ops_list) + return model_ops + + +class TestModelOpsConfigs(unittest.TestCase): + """Test that model ops filter correctly initializes TorchBenchTestSuite.""" + + def test_model_ops_match_suite_optests(self): + """Test that suite's optests match the operators from model configs.""" + # Get the models directory path (same as ModelSuite does) + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load models using load_models + models = load_models(models_dir=models_dir) + + # Verify we loaded at least one model + self.assertGreater(len(models), 0, "Should load at least one model") + + # Extract operators from model configs using load_model_ops + model_ops = load_model_ops(models) + + # Verify we have operators + self.assertGreater(len(model_ops), 0, "Should have at least one operator") + + # Create filter list from model ops + ops_filter = list(model_ops) + + # Initialize TorchBenchTestSuite with the filter + suite = TorchBenchTestSuite( + name="test_model_ops", + filename=None, # Use default HuggingFace dataset + filter=ops_filter, + topn=None, + ) + + # Get the set of operators in the suite's optests + suite_ops = set(suite.optests.keys()) + + # The suite_ops should be a subset of model_ops because: + # - model_ops is the filter we requested + # - suite_ops contains only those operators that exist in the TorchBench dataset + # - Not all operators in model configs may be in the dataset + self.assertTrue( + suite_ops.issubset(model_ops), + f"Suite operators {suite_ops} should be subset of model operators {model_ops}", + ) + + # Verify that suite actually has some operators + self.assertGreater( + len(suite_ops), 0, "Suite should contain at least one operator from model configs" + ) + + def test_json_configs_have_required_fields(self): + """Test that all JSON config files have proper format with required fields.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load all models + models = load_models(models_dir=models_dir) + + for model in models: + model_name = model["name"] + config = model["config"] + + # Check required top-level fields + self.assertIn("ops", config, f"Model {model_name}: config must have 'ops' field") + self.assertIn( + "model_tests", config, f"Model {model_name}: config must have 'model_tests' field" + ) + + # Validate 'ops' field - can be list or dict + config_ops = config["ops"] + self.assertGreater( + len(config_ops["forward"] + config_ops["backward"]), + 0, 
+ f"Model {model_name}: 'ops' list must not be empty", + ) + for op in config_ops["forward"] + config_ops["backward"]: + self.assertIsInstance( + op, str, f"Model {model_name}: each op in 'ops' must be a string" + ) + self.assertIsInstance( + config_ops["forward"], + list, + f"Model {model_name}: 'ops.forward' must be a list", + ) + for op in config_ops["forward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.forward' must be a string", + ) + self.assertIsInstance( + config_ops["backward"], + list, + f"Model {model_name}: 'ops.backward' must be a list", + ) + for op in config_ops["backward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.backward' must be a string", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + for test_name, test_args in config["model_tests"].items(): + self.assertIsInstance( + test_name, str, f"Model {model_name}: test names must be strings" + ) + self.assertIsInstance( + test_args, str, f"Model {model_name}: test args must be strings" + ) + + # Check optional but recommended fields + if "model_config" in config: + self.assertIsInstance( + config["model_config"], + dict, + f"Model {model_name}: 'model_config' must be a dictionary if present", + ) + + def test_json_files_are_valid_json(self): + """Test that all JSON config files are valid JSON and can be parsed.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Find all JSON files in the models directory + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + json_file = os.path.join(model_dir, f"{model_name}.json") + if not os.path.exists(json_file): + continue + + # Try to parse the JSON file + with open(json_file, "r") as f: + try: + config = json.load(f) + self.assertIsInstance( + config, + dict, + f"JSON file {json_file} must contain a dictionary at top level", + ) + except json.JSONDecodeError as e: + self.fail(f"JSON file {json_file} is not valid JSON: {e}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_ops_coverage.py b/test/test_model_ops_coverage.py new file mode 100644 index 00000000..99d8301b --- /dev/null +++ b/test/test_model_ops_coverage.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that models actually invoke all operators declared in their configs. + +This test validates that: +1. Forward pass invokes all operators in config["ops"]["forward"] +2. Backward pass invokes all operators in config["ops"]["backward"] +3. 
Clear error messages indicate which operators are missing per model +""" + +import os +import re +import unittest +from typing import Dict, Set + +import torch + +from BackendBench.suite.model import load_models + + +class OpTracker: + """Track operators called during forward/backward passes using torch profiler.""" + + def __init__(self): + self.called_ops: Set[str] = set() + self.profiler = None + + def __enter__(self): + self.called_ops.clear() + + # Use torch profiler to track ops + self.profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + record_shapes=False, + with_stack=False, + ) + self.profiler.__enter__() + return self + + def __exit__(self, *args): + self.profiler.__exit__(*args) + + # Extract op names from profiler events + for event in self.profiler.events(): + event_name = event.name + # Look for aten operations + if "::" in event_name: + # Handle format like "aten::convolution" or "aten::convolution.default" + parts = event_name.replace("::", ".").split(".") + + if len(parts) >= 2 and parts[0] == "aten": + if len(parts) == 2: + # No variant specified, add .default + op_name = f"{parts[0]}.{parts[1]}.default" + else: + # Keep as is + op_name = event_name.replace("::", ".") + + self.called_ops.add(op_name) + + +class TestModelOpsCoverage(unittest.TestCase): + """Test that models invoke all operators declared in their configs.""" + + def test_all_models_ops_coverage(self): + """Test that all models invoke their declared forward and backward ops.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "BackendBench", + "suite", + "models", + ) + + models = load_models(models_dir=models_dir) + self.assertGreater(len(models), 0, "Should load at least one model") + + failures = [] + + for model_dict in models: + model_name = model_dict["name"] + model_class = model_dict["class"] + config = model_dict["config"] + + # Get expected ops from config + config_ops = config.get("ops", {}) + expected_forward = set(config_ops.get("forward", [])) + expected_backward = set(config_ops.get("backward", [])) + + # Skip if no ops to check + if not expected_forward and not expected_backward: + continue + + try: + # Initialize model + model_config = config.get("model_config", {}) + init_args = model_config.get("init_args", {}) + + if model_config.get("requires_init_seed"): + torch.manual_seed(42) + + model = model_class(**init_args) + + # Get a test input from model_tests + model_tests = config.get("model_tests", {}) + if not model_tests: + failures.append(f"{model_name}: No model_tests in config") + continue + + # Use first test case + test_name = list(model_tests.keys())[0] + test_args_str = model_tests[test_name] + + # Parse test args (simple eval for now) + # Format: "([], {'x': T([2, 3, 32, 32], f32)})" + test_input = self._create_test_input_from_string(test_args_str) + + # Track forward pass + tracker = OpTracker() + with tracker: + output = model(**test_input) + + forward_ops = tracker.called_ops + + # Check forward ops coverage + missing_forward = expected_forward - forward_ops + if missing_forward: + failures.append( + f"{model_name} [FORWARD]: Missing ops: {sorted(missing_forward)}" + ) + + # Track backward pass + if expected_backward: + # Ensure output requires grad + for param in model.parameters(): + param.requires_grad = True + + # Create loss + if isinstance(output, torch.Tensor): + loss = output.sum() + else: + # Handle tuple/dict outputs + loss = sum(v.sum() for v in output.values() if isinstance(v, torch.Tensor)) + + 
tracker_backward = OpTracker() + with tracker_backward: + loss.backward() + + backward_ops = tracker_backward.called_ops + + # Check backward ops coverage + missing_backward = expected_backward - backward_ops + if missing_backward: + failures.append( + f"{model_name} [BACKWARD]: Missing ops: {sorted(missing_backward)}" + ) + + except Exception as e: + failures.append(f"{model_name}: Error during test: {e}") + + # Report all failures at once + if failures: + error_msg = "\n\nOperator Coverage Failures:\n" + "\n".join( + f" - {failure}" for failure in failures + ) + self.fail(error_msg) + + def _create_test_input_from_string(self, test_args_str: str) -> Dict[str, torch.Tensor]: + """Parse test input string into actual tensors. + + Format: "([], {'x': T([2, 3, 32, 32], f32)})" + """ + + # Extract tensor specs: T([shape], dtype) + tensor_pattern = r"'(\w+)':\s*T\(\[([\d,\s]+)\],\s*(\w+)\)" + matches = re.findall(tensor_pattern, test_args_str) + + inputs = {} + for name, shape_str, dtype_str in matches: + shape = [int(x.strip()) for x in shape_str.split(",")] + + # Map dtype string to torch dtype + dtype_map = { + "f32": torch.float32, + "f64": torch.float64, + "i32": torch.int32, + "i64": torch.int64, + } + dtype = dtype_map.get(dtype_str, torch.float32) + + inputs[name] = torch.randn(*shape, dtype=dtype) + + return inputs + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_suite.py b/test/test_model_suite.py new file mode 100644 index 00000000..12ddaf1b --- /dev/null +++ b/test/test_model_suite.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for Model Suite: Filtered TorchBench operators from model tracing + +This test suite validates: +1. Model loading from toy_models directory +2. Operator extraction via model tracing +3. ModelSuite creates filtered TorchBench suite +""" + +import logging +import unittest + +from BackendBench.suite.model import load_models + +# Setup logging +logging.basicConfig(level=logging.WARNING) + + +class TestModelLoading(unittest.TestCase): + """Test model loading functionality.""" + + def test_load_models(self): + """Test that models can be loaded from directory.""" + models = load_models(models_dir="BackendBench/suite/models") + self.assertGreater(len(models), 0, "Should load at least one model") + + # Verify model structure + for model in models: + self.assertIn("name", model) + self.assertIn("class", model) + self.assertIn("config", model) + + def test_load_specific_model(self): + """Test loading a specific model by name.""" + models = load_models(models_dir="BackendBench/suite/models", filter=["ToyCoreOpsModel"]) + self.assertEqual(len(models), 1) + self.assertEqual(models[0]["name"], "ToyCoreOpsModel") + + def test_invalid_filter(self): + """Test that invalid filter raises error.""" + with self.assertRaises(ValueError): + load_models(models_dir="BackendBench/suite/models", filter=["nonexistent"]) + + +if __name__ == "__main__": + # Run tests + unittest.main(verbosity=2)