From 8aa5f7c998f18b10c40f3d277166bf15ec8211c8 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 30 Sep 2025 07:52:33 +0000 Subject: [PATCH 01/16] Add model backend --- BackendBench/scripts/main.py | 43 +- BackendBench/suite/__init__.py | 2 + BackendBench/suite/model.py | 382 ++++++++++++++++++ .../models/toy_core_ops/toy_core_ops.json | 41 ++ .../suite/models/toy_core_ops/toy_core_ops.py | 233 +++++++++++ 5 files changed, 700 insertions(+), 1 deletion(-) create mode 100644 BackendBench/suite/model.py create mode 100644 BackendBench/suite/models/toy_core_ops/toy_core_ops.json create mode 100644 BackendBench/suite/models/toy_core_ops/toy_core_ops.py diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 479e5805..9422d0db 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -19,6 +19,7 @@ from BackendBench.output import save_results from BackendBench.suite import ( FactoTestSuite, + ModelSuite, OpInfoTestSuite, SmokeTestSuite, TorchBenchTestSuite, @@ -50,7 +51,7 @@ def setup_logging(log_level): @click.option( "--suite", default="smoke", - type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]), + type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]), help="Which suite to run", ) @click.option( @@ -166,6 +167,9 @@ def cli( if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") + if suite == "model" and backend != "directory": + raise ValueError("model suite only supports directory backend") + setup_logging(log_level) if ops: ops = ops.split(",") @@ -191,6 +195,7 @@ def cli( torch.bfloat16, filter=ops, ), + "model": lambda: ModelSuite(filter=ops), }[suite]() backend_name = backend @@ -292,6 +297,42 @@ def cli( f"perf@p score (rate of correct samples with a speedup greater than p, p={p}): {perf_at_p_score:.2f}" ) + # Add full model testing for model suite + if suite.name == "model" and hasattr(suite, 'test_model_correctness'): + print("\n" + "="*80) + print("FULL MODEL TESTING") + print("="*80) + + # Pass ops_directory as kernel_dir for directory backend + kernel_dir = ops_directory if backend_name == "directory" else None + model_results = suite.test_model_correctness(kernel_dir=kernel_dir) + + # Print results + print("\nModel Correctness Results:") + print("-"*80) + total_passed = 0 + total_tests = 0 + for model_name, test_results in model_results.items(): + passed = sum(1 for result in test_results.values() if result) + total = len(test_results) + total_passed += passed + total_tests += total + percentage = (passed / total * 100) if total > 0 else 0 + print(f" {model_name}: {passed}/{total} configs passed ({percentage:.1f}%)") + + # Show individual config results + for config_name, is_correct in test_results.items(): + status = "✓ PASS" if is_correct else "✗ FAIL" + print(f" {config_name}: {status}") + + print("-"*80) + if total_tests > 0: + overall_percentage = total_passed / total_tests * 100 + print(f"\nModel Suite Score: {total_passed}/{total_tests} ({overall_percentage:.1f}%)") + else: + print("\nNo model tests were run") + print("="*80) + command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) # Save results if not disabled diff --git a/BackendBench/suite/__init__.py b/BackendBench/suite/__init__.py index 410a5d6e..e9c332a9 100644 --- a/BackendBench/suite/__init__.py +++ b/BackendBench/suite/__init__.py @@ -15,6 +15,7 @@ from .base import OpTest, Test, TestSuite from .facto import FactoTestSuite +from .model import 
ModelSuite from .opinfo import OpInfoTestSuite from .smoke import randn, SmokeTestSuite from .torchbench import TorchBenchOpTest, TorchBenchTestSuite @@ -24,6 +25,7 @@ "OpTest", "TestSuite", "FactoTestSuite", + "ModelSuite", "OpInfoTestSuite", "SmokeTestSuite", "randn", diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py new file mode 100644 index 00000000..d9ca3050 --- /dev/null +++ b/BackendBench/suite/model.py @@ -0,0 +1,382 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Model Suite for testing toy models against backends. + +This suite extends TorchBenchTestSuite to provide two testing approaches: +1. Operator-level testing (via __iter__, inherited infrastructure) +2. Model-level correctness testing (via test_model_correctness, new functionality) +""" + +import json +import os +import importlib.util +import logging +import torch +from typing import Dict, List, Any, Optional + +from BackendBench.utils import get_pytorch_op +from .torchbench import TorchBenchTestSuite, TorchBenchTest + +logger = logging.getLogger(__name__) + + +def load_toy_models(toy_models_dir: str = "models", filter: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json + + Args: + toy_models_dir: Directory containing toy models (default: "models") + filter: Optional list of model names to load. If None, loads all models. + + Returns: + List of dictionaries with keys: + - name: Model name (str) + - class: Model class (type) + - config: Configuration dictionary from JSON file + """ + models = [] + + if not os.path.exists(toy_models_dir): + logger.warning(f"Toy models directory not found: {toy_models_dir}") + return models + + for model_name in os.listdir(toy_models_dir): + # Apply filter if specified + if filter is not None and model_name not in filter: + continue + + model_dir = os.path.join(toy_models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json + model_file = os.path.join(model_dir, f"{model_name}.py") + config_file = os.path.join(model_dir, f"{model_name}.json") + + # Check both files exist + if not os.path.exists(model_file): + logger.warning(f"Model file not found: {model_file}") + continue + + if not os.path.exists(config_file): + logger.warning(f"Config file not found: {config_file}") + continue + + try: + # Load config + with open(config_file, 'r') as f: + config = json.load(f) + + # Load model class dynamically + spec = importlib.util.spec_from_file_location(model_name, model_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find model class (ends with "Model") + model_class = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if (isinstance(attr, type) and + attr_name.endswith("Model") and + hasattr(attr, "forward")): + model_class = attr + break + + if model_class is None: + logger.error(f"No model class found in {model_file}") + continue + + models.append({ + "name": model_name, + "class": model_class, + "config": config + }) + logger.info(f"Loaded model: {model_name}") + + except Exception as e: + logger.error(f"Failed to load {model_name}: {e}") + continue + + return models + + +def _create_torchbench_op_test(model_name: str, model_class, 
config: Dict[str, Any], op_name: str): + """Create a TorchBenchOpTest for a specific operator from a toy model. + + Args: + model_name: Name of the model + model_class: Model class + config: Configuration dictionary + op_name: Operator name (e.g., "conv2d", "relu") + + Returns: + TorchBenchOpTest instance compatible with existing evaluation infrastructure + """ + from .torchbench import TorchBenchOpTest + from BackendBench.utils import serialize_args + + # Generate test inputs from model configs + inputs = [] + for test_config in config["test_configs"]: + # Extract input shape from test config + forward_args = test_config["forward_args"] + batch_size = forward_args["batch_size"] + input_shape = forward_args["input_shape"] + + # Create input tensor + full_shape = [batch_size] + input_shape + input_tensor = torch.randn(*full_shape) + + # Serialize the input for TorchBenchOpTest + # TorchBenchOpTest expects serialized inputs (strings) + serialized = serialize_args([input_tensor], {}) + inputs.append(serialized) + + # Create TorchBenchOpTest - it expects op name and inputs + # We need to convert op_name to full torch.ops format + op = get_pytorch_op(op_name) + if op is None: + raise ValueError(f"Could not find PyTorch operation for {op_name}") + + # Get the full op string (e.g., "aten.conv2d.default") + op_str = str(op).replace("torch.ops.", "") + + # Create the test with serialized inputs + return TorchBenchOpTest(op_str, inputs, topn=None) + + +class FullModelTest: + """Complete model forward/backward testing. + + This class handles running a model with a specific test configuration + in both eager mode and backend mode, then comparing the results. + """ + + def __init__(self, model_name: str, model_class, config: Dict[str, Any], test_config: Dict[str, Any]): + """Initialize FullModelTest. + + Args: + model_name: Name of the model being tested + model_class: Model class to instantiate + config: Full model configuration including model_config + test_config: Specific test configuration with forward_args + """ + self.model_name = model_name + self.model_class = model_class + self.config = config + self.test_config = test_config + + def run_with_backend(self, backend_enabled: bool, kernel_dir: str = None) -> tuple: + """Run model with backend enabled or disabled. + + Args: + backend_enabled: If True, use BackendBench context manager to enable backend + kernel_dir: Optional directory containing kernels (for backend mode) + + Returns: + Tuple of (output, gradients) where: + - output: Model output tensor (detached) + - gradients: List of gradient tensors [input_grad, param1_grad, param2_grad, ...] 
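+
+        Example (illustrative; this mirrors how test_correctness consumes the tuple):
+            eager_out, eager_grads = self.run_with_backend(backend_enabled=False)
+            backend_out, backend_grads = self.run_with_backend(backend_enabled=True)
+            torch.allclose(eager_out, backend_out)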
+ """ + import BackendBench + + # Extract model configuration + model_config = self.config["model_config"]["init_args"] + + # Extract input configuration + forward_args = self.test_config["forward_args"] + batch_size = forward_args["batch_size"] + input_shape = forward_args["input_shape"] + + # Create full input shape: [batch_size, *input_shape] + full_shape = [batch_size] + input_shape + + # Set seed for deterministic behavior + seed = model_config.get("seed", 42) + torch.manual_seed(seed) + + # Create fresh model instance + model = self.model_class(**model_config) + model.train() + + # Create input tensor with requires_grad for input gradient + x = torch.randn(*full_shape, requires_grad=True) + + # Run forward + backward with or without backend + if backend_enabled: + # Use context manager to enable backend + if kernel_dir is None: + # Default to generated_kernels directory + kernel_dir = os.path.join(os.getcwd(), "generated_kernels") + + with BackendBench.BackendBench.enable(kernel_dir=kernel_dir): + output = model(x) + loss = output.sum() + loss.backward() + else: + # Run in eager mode (no backend) + output = model(x) + loss = output.sum() + loss.backward() + + # Collect gradients: [input_grad, param1_grad, param2_grad, ...] + grads = [] + + # Input gradient + if x.grad is not None: + grads.append(x.grad.clone()) + + # Parameter gradients + for param in model.parameters(): + if param.grad is not None: + grads.append(param.grad.clone()) + + return output.detach(), grads + + def test_correctness(self, atol=1e-6, rtol=1e-5, kernel_dir: str = None) -> bool: + """Test numerical correctness by comparing eager vs backend execution. + + Args: + atol: Absolute tolerance for torch.allclose + rtol: Relative tolerance for torch.allclose + kernel_dir: Optional directory containing kernels + + Returns: + True if eager and backend produce matching results, False otherwise + """ + try: + # Run in eager mode + eager_out, eager_grads = self.run_with_backend(False, kernel_dir=kernel_dir) + + # Run with backend + backend_out, backend_grads = self.run_with_backend(True, kernel_dir=kernel_dir) + + # Compare outputs + if not torch.allclose(eager_out, backend_out, atol=atol, rtol=rtol): + logger.debug(f"{self.model_name}::{self.test_config['name']}: Output mismatch") + return False + + # Compare number of gradients + if len(eager_grads) != len(backend_grads): + logger.debug( + f"{self.model_name}::{self.test_config['name']}: " + f"Gradient count mismatch ({len(eager_grads)} vs {len(backend_grads)})" + ) + return False + + # Compare each gradient + for i, (eager_grad, backend_grad) in enumerate(zip(eager_grads, backend_grads)): + if not torch.allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): + logger.debug( + f"{self.model_name}::{self.test_config['name']}: " + f"Gradient {i} mismatch" + ) + return False + + return True + + except Exception as e: + logger.error(f"{self.model_name}::{self.test_config['name']}: Correctness test failed: {e}") + return False + + +class ModelSuite(TorchBenchTestSuite): + """Model Suite extending TorchBenchTestSuite. + + Provides two testing approaches: + 1. Operator-level testing via __iter__() (inherited infrastructure) + 2. Model-level correctness testing via test_model_correctness() (Model Suite specific) + """ + + def __init__(self, name: str = "model", filter: Optional[List[str]] = None, models_dir: str = None): + """Initialize ModelSuite. 
+ + Args: + name: Suite name (default: "model") + filter: Optional list of model names to test + models_dir: Optional directory for models (default: "BackendBench/suite/models") + """ + # Don't call super().__init__() with parameters since TorchBenchTestSuite + # expects different arguments. Just initialize the base object. + super(TorchBenchTestSuite, self).__init__() + + self.name = name + + # Default to models under suite/models + if models_dir is None: + models_dir = os.path.join(os.path.dirname(__file__), "models") + + self.models = load_toy_models(toy_models_dir=models_dir, filter=filter) + logger.info(f"ModelSuite: {len(self.models)} models loaded from {models_dir}") + + def __iter__(self): + """Yield operator tests from all models (TorchBench approach). + + This method enables operator-level testing using the inherited + TorchBench infrastructure. Returns TorchBenchOpTest instances. + """ + for model in self.models: + # Extract operators from config + if "expected_operators" not in model["config"]: + logger.warning(f"Model {model['name']} has no expected_operators in config") + continue + + expected_ops = model["config"]["expected_operators"] + + # Yield forward pass operators + if "forward_pass" in expected_ops: + for op_name in expected_ops["forward_pass"]: + try: + yield _create_torchbench_op_test(model["name"], model["class"], model["config"], op_name) + except Exception as e: + logger.error(f"Failed to create test for forward op {op_name}: {e}") + + # Yield backward pass operators + if "backward_pass" in expected_ops: + for op_name in expected_ops["backward_pass"]: + try: + yield _create_torchbench_op_test(model["name"], model["class"], model["config"], op_name) + except Exception as e: + logger.error(f"Failed to create test for backward op {op_name}: {e}") + + def test_model_correctness(self, kernel_dir: str = None) -> Dict[str, Dict[str, bool]]: + """Test full model correctness for all models and configurations. + + This method runs each model with each test configuration, comparing + eager mode vs backend mode to verify numerical correctness. 
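+
+        For example, with the bundled toy_core_ops model the returned mapping
+        might look like (illustrative):
+            {"toy_core_ops": {"small_batch": True, "large_input": False}}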
+ + Args: + kernel_dir: Optional directory containing kernels for backend + + Returns: + Dictionary mapping model_name -> {config_name -> bool} + where bool indicates if the test passed + """ + results = {} + + for model in self.models: + model_results = {} + + # Test each configuration for this model + for test_config in model["config"]["test_configs"]: + test = FullModelTest( + model_name=model["name"], + model_class=model["class"], + config=model["config"], + test_config=test_config + ) + + test_name = test_config["name"] + is_correct = test.test_correctness(kernel_dir=kernel_dir) + model_results[test_name] = is_correct + + status = "PASS" if is_correct else "FAIL" + logger.info(f"{model['name']}::{test_name}: {status}") + + results[model["name"]] = model_results + + return results diff --git a/BackendBench/suite/models/toy_core_ops/toy_core_ops.json b/BackendBench/suite/models/toy_core_ops/toy_core_ops.json new file mode 100644 index 00000000..7cad404e --- /dev/null +++ b/BackendBench/suite/models/toy_core_ops/toy_core_ops.json @@ -0,0 +1,41 @@ +{ + "model_config": { + "init_args": { + "in_channels": 3, + "hidden_channels": 32, + "out_channels": 8, + "num_groups": 8, + "seed": 42 + } + }, + "test_configs": [ + { + "name": "small_batch", + "forward_args": { + "batch_size": 2, + "input_shape": [3, 32, 32] + } + }, + { + "name": "medium_batch", + "forward_args": { + "batch_size": 4, + "input_shape": [3, 64, 64] + } + }, + { + "name": "large_input", + "forward_args": { + "batch_size": 2, + "input_shape": [3, 128, 128] + } + } + ], + "expected_operators": { + "forward_pass": ["conv2d", "native_group_norm", "relu", "max_pool2d_with_indices", "avg_pool2d", "adaptive_avg_pool2d"], + "backward_pass": ["convolution_backward", "native_group_norm_backward", "threshold_backward", "max_pool2d_with_indices_backward", "avg_pool2d_backward", "_adaptive_avg_pool2d_backward"] + }, + "metadata": { + "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" + } +} \ No newline at end of file diff --git a/BackendBench/suite/models/toy_core_ops/toy_core_ops.py b/BackendBench/suite/models/toy_core_ops/toy_core_ops.py new file mode 100644 index 00000000..142774c5 --- /dev/null +++ b/BackendBench/suite/models/toy_core_ops/toy_core_ops.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 + +""" +Toy model that uses core PyTorch operators during training. + +This model is designed to trigger all of the following backward operators +when performing backpropagation: + +- ConvolutionBackward0 (convolution_backward) +- NativeGroupNormBackward0 (native_group_norm_backward) +- MaxPool2DWithIndicesBackward0 (max_pool2d_with_indices_backward) +- AvgPool2DBackward0 (avg_pool2d_backward) +- AdaptiveAvgPool2DBackward0 (_adaptive_avg_pool2d_backward) + +The model implements a CNN architecture with the following structure: +1. Conv2d -> GroupNorm -> ReLU (triggers convolution_backward, native_group_norm_backward) +2. MaxPool2d with indices (triggers max_pool2d_with_indices_backward) +3. Conv2d -> GroupNorm -> ReLU (triggers convolution_backward, native_group_norm_backward again) +4. AvgPool2d (triggers avg_pool2d_backward) +5. AdaptiveAvgPool2d (triggers _adaptive_avg_pool2d_backward) +6. Final Conv2d (triggers convolution_backward again) + +Usage: + python toy_core_ops.py + +This will create a model with default configuration and run a simple forward/backward pass +to demonstrate that all required backward operators are used. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ToyCoreOpsModel(nn.Module): + """ + Toy CNN model designed to test core PyTorch operators during training. + + The model uses a strategic combination of operations to ensure all target + backward operators are invoked during backpropagation: + + - Convolution layers for convolution_backward + - Group normalization for native_group_norm_backward + - Max pooling with indices for max_pool2d_with_indices_backward + - Average pooling for avg_pool2d_backward + - Adaptive average pooling for _adaptive_avg_pool2d_backward + """ + + def __init__(self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + seed: int = 42): + """ + Initialize the ToyCoreOpsModel. + + Args: + in_channels: Number of input channels (default: 3 for RGB) + hidden_channels: Number of hidden channels in conv layers + out_channels: Number of output channels + num_groups: Number of groups for GroupNorm (must divide hidden_channels) + seed: Random seed for deterministic weight initialization + + Raises: + ValueError: If hidden_channels is not divisible by num_groups + """ + super().__init__() + + # Validate group normalization constraints + if hidden_channels % num_groups != 0: + raise ValueError( + f"hidden_channels ({hidden_channels}) must be divisible " + f"by num_groups ({num_groups})" + ) + + # Store configuration + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_groups = num_groups + + # First convolution block (triggers convolution_backward) + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1 + ) + + # First group normalization (triggers native_group_norm_backward) + self.group_norm1 = nn.GroupNorm( + num_groups=num_groups, + num_channels=hidden_channels + ) + + # Second convolution block (triggers convolution_backward again) + self.conv2 = nn.Conv2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1 + ) + + # Second group normalization (triggers native_group_norm_backward again) + self.group_norm2 = nn.GroupNorm( + num_groups=num_groups, + num_channels=hidden_channels + ) + + # Final convolution for output (triggers convolution_backward again) + self.conv_out = nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=1 + ) + + # Initialize weights deterministically + self._initialize_weights(seed) + + def _initialize_weights(self, seed: int): + """ + Initialize model weights deterministically using the given seed. + + Args: + seed: Random seed for reproducible initialization + """ + # Set random seed for deterministic initialization + torch.manual_seed(seed) + + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.GroupNorm): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass that sets up the computational graph to trigger all target backward operators. + + Args: + x: Input tensor of shape (batch_size, in_channels, height, width) + + Returns: + Output tensor of shape (batch_size, out_channels, 4, 4) + + Note: + The output is always 4x4 regardless of input size due to the adaptive pooling layer. 
+ """ + + # First conv block: Conv2d -> GroupNorm -> ReLU + # This will trigger: ConvolutionBackward0, NativeGroupNormBackward0 + x = self.conv1(x) + x = self.group_norm1(x) + x = F.relu(x) + + # Max pooling with indices (triggers MaxPool2DWithIndicesBackward0) + # We need to use return_indices=True to get the specific backward operator + x, indices = F.max_pool2d(x, kernel_size=2, return_indices=True) + + # Second conv block: Conv2d -> GroupNorm -> ReLU + # This will trigger: ConvolutionBackward0, NativeGroupNormBackward0 (again) + x = self.conv2(x) + x = self.group_norm2(x) + x = F.relu(x) + + # Average pooling (triggers AvgPool2DBackward0) + x = F.avg_pool2d(x, kernel_size=2) + + # Adaptive average pooling (triggers AdaptiveAvgPool2DBackward0) + # This ensures consistent output size regardless of input dimensions + x = F.adaptive_avg_pool2d(x, output_size=(4, 4)) + + # Final convolution (triggers ConvolutionBackward0 again) + x = self.conv_out(x) + + return x + + +def main(): + """ + Demonstrate the ToyCoreOpsModel with a simple forward/backward pass. + """ + print("ToyCoreOpsModel Demonstration") + print("=" * 50) + + # Create model with default configuration + model = ToyCoreOpsModel( + in_channels=3, + hidden_channels=32, + out_channels=8, + num_groups=8, + seed=42 # Deterministic initialization + ) + + # Create sample input + batch_size = 2 + input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True) + + print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters") + print(f"Input shape: {input_tensor.shape}") + + # Forward pass + model.train() + output = model(input_tensor) + expected_shape = torch.Size([batch_size, 8, 4, 4]) # Expected output shape + + print(f"Output shape: {output.shape}") + print(f"Expected shape: {expected_shape}") + print(f"Shape matches: {output.shape == expected_shape}") + + # Perform backward pass to actually trigger the operations + print("\nPerforming backward pass...") + loss = output.sum() + loss.backward() + print("✓ Backward pass completed successfully") + + # Check gradients were computed + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f"✓ Gradients computed for {grad_count}/{total_params} parameters") + + print(f"\n✓ Model demonstration completed successfully!") + print("This model is ready to be used with the Model Suite for testing core operators.") + + return model + + +if __name__ == "__main__": + main() \ No newline at end of file From 843cc97891ffd2d9e6e45c2fedde13abc4b045a1 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 30 Sep 2025 08:28:45 +0000 Subject: [PATCH 02/16] add tests --- test/test_model_suite_correctness.py | 180 +++++++++++++++++++++ test/test_model_suite_integration.py | 223 +++++++++++++++++++++++++++ 2 files changed, 403 insertions(+) create mode 100644 test/test_model_suite_correctness.py create mode 100644 test/test_model_suite_integration.py diff --git a/test/test_model_suite_correctness.py b/test/test_model_suite_correctness.py new file mode 100644 index 00000000..4e557022 --- /dev/null +++ b/test/test_model_suite_correctness.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Essential tests for Model Suite PR #2: Full Model Testing & Results + +This test suite validates the core functionality: +1. 
FullModelTest class with eager/backend execution +2. Numerical correctness comparison +3. ModelSuite.test_model_correctness() integration +""" + +import logging +import os +import sys +import unittest + +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from BackendBench.suite.model import FullModelTest, load_toy_models, ModelSuite + +# Setup logging +logging.basicConfig(level=logging.WARNING) + + +class TestFullModelTest(unittest.TestCase): + """Test FullModelTest class functionality.""" + + @classmethod + def setUpClass(cls): + """Load toy models once for all tests.""" + cls.models = load_toy_models(toy_models_dir="BackendBench/suite/models") + assert len(cls.models) > 0, "Should load at least one model" + cls.model = next(m for m in cls.models if m["name"] == "toy_core_ops") + + def test_initialization(self): + """Test FullModelTest can be instantiated correctly.""" + test_config = self.model["config"]["test_configs"][0] + full_test = FullModelTest( + model_name=self.model["name"], + model_class=self.model["class"], + config=self.model["config"], + test_config=test_config, + ) + + self.assertEqual(full_test.model_name, self.model["name"]) + self.assertEqual(full_test.model_class, self.model["class"]) + + def test_eager_execution(self): + """Test model runs correctly in eager mode.""" + test_config = self.model["config"]["test_configs"][0] + full_test = FullModelTest( + self.model["name"], self.model["class"], self.model["config"], test_config + ) + + output, grads = full_test.run_with_backend(backend_enabled=False) + + # Verify output shape + batch_size = test_config["forward_args"]["batch_size"] + expected_shape = torch.Size([batch_size, 8, 4, 4]) + self.assertEqual(output.shape, expected_shape) + + # Verify gradients computed + self.assertGreater(len(grads), 0, "Should compute gradients") + + # Verify all gradients are valid + for grad in grads: + self.assertIsInstance(grad, torch.Tensor) + self.assertFalse(torch.isnan(grad).any(), "No NaN gradients") + self.assertFalse(torch.isinf(grad).any(), "No Inf gradients") + + def test_backend_execution(self): + """Test model runs with backend enabled.""" + test_config = self.model["config"]["test_configs"][0] + full_test = FullModelTest( + self.model["name"], self.model["class"], self.model["config"], test_config + ) + + output, grads = full_test.run_with_backend(backend_enabled=True) + + # Verify output shape + batch_size = test_config["forward_args"]["batch_size"] + expected_shape = torch.Size([batch_size, 8, 4, 4]) + self.assertEqual(output.shape, expected_shape) + + # Verify gradients computed + self.assertGreater(len(grads), 0, "Should compute gradients") + + def test_correctness_comparison(self): + """Test correctness comparison between eager and backend.""" + test_config = self.model["config"]["test_configs"][0] + full_test = FullModelTest( + self.model["name"], self.model["class"], self.model["config"], test_config + ) + + is_correct = full_test.test_correctness() + + # Result should be a boolean + self.assertIsInstance(is_correct, bool) + + # With existing kernels, test should pass + self.assertTrue(is_correct, "Backend should produce correct results") + + def test_multiple_configs(self): + """Test all model configurations run correctly.""" + for test_config in self.model["config"]["test_configs"]: + full_test = FullModelTest( + self.model["name"], self.model["class"], self.model["config"], test_config + ) + + output, grads = full_test.run_with_backend(backend_enabled=False) + + batch_size = 
test_config["forward_args"]["batch_size"] + expected_shape = torch.Size([batch_size, 8, 4, 4]) + self.assertEqual(output.shape, expected_shape, f"Config {test_config['name']} failed") + self.assertGreater(len(grads), 0, f"Config {test_config['name']} has no gradients") + + +class TestModelSuite(unittest.TestCase): + """Test ModelSuite.test_model_correctness() integration.""" + + def test_model_correctness_method_exists(self): + """Test that test_model_correctness method exists.""" + suite = ModelSuite() + self.assertTrue(hasattr(suite, "test_model_correctness")) + + def test_model_correctness_integration(self): + """Test ModelSuite.test_model_correctness() returns proper results.""" + suite = ModelSuite() + results = suite.test_model_correctness() + + # Verify results structure + self.assertIsInstance(results, dict, "Results should be a dictionary") + self.assertGreater(len(results), 0, "Should have results for at least one model") + + # Verify each model has config results + for model_name, model_results in results.items(): + self.assertIsInstance(model_results, dict, f"{model_name} results should be dict") + self.assertGreater(len(model_results), 0, f"{model_name} should have test configs") + + # Verify each config result is a boolean + for config_name, is_correct in model_results.items(): + self.assertIsInstance( + is_correct, bool, f"{model_name}::{config_name} should be bool" + ) + + def test_results_aggregation(self): + """Test that results can be aggregated for scoring.""" + suite = ModelSuite() + results = suite.test_model_correctness() + + # Calculate aggregate statistics + total_tests = sum(len(model_results) for model_results in results.values()) + total_passed = sum( + sum(1 for result in model_results.values() if result) + for model_results in results.values() + ) + + self.assertGreater(total_tests, 0, "Should have at least one test") + self.assertLessEqual(total_passed, total_tests, "Passed <= Total") + self.assertGreaterEqual(total_passed, 0, "Passed >= 0") + + def test_empty_filter(self): + """Test suite handles empty model list gracefully.""" + suite = ModelSuite(filter=["nonexistent_model"]) + self.assertEqual(len(suite.models), 0, "Should have no models") + + # Should not crash when running on empty list + results = suite.test_model_correctness() + self.assertEqual(len(results), 0, "Should have no results") + + +if __name__ == "__main__": + # Run tests + unittest.main(verbosity=2) diff --git a/test/test_model_suite_integration.py b/test/test_model_suite_integration.py new file mode 100644 index 00000000..43d04c4a --- /dev/null +++ b/test/test_model_suite_integration.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Essential integration tests for Model Suite PR #3: Final Polish + +This test suite validates: +1. Complete CLI workflow with model suite +2. Filtering functionality +3. Error handling for invalid backends +4. 
Operator-level and model-level testing integration +""" + +import os +import subprocess +import sys +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from BackendBench.suite.model import ModelSuite + + +class TestModelSuiteCLI(unittest.TestCase): + """Test CLI integration for model suite.""" + + def test_complete_workflow(self): + """Test complete workflow from CLI.""" + result = subprocess.run( + [ + "python", + "-m", + "BackendBench.scripts.main", + "--suite", + "model", + "--backend", + "directory", + "--ops-directory", + "generated_kernels", + "--disable-output-logs", + "--log-level", + "ERROR", + ], + capture_output=True, + text=True, + timeout=60, + ) + + self.assertEqual(result.returncode, 0, "CLI should succeed") + self.assertIn("FULL MODEL TESTING", result.stdout) + self.assertIn("Model Suite Score", result.stdout) + self.assertIn("toy_core_ops", result.stdout) + + def test_filtering_by_model(self): + """Test filtering models by name.""" + result = subprocess.run( + [ + "python", + "-m", + "BackendBench.scripts.main", + "--suite", + "model", + "--backend", + "directory", + "--ops-directory", + "generated_kernels", + "--ops", + "toy_core_ops", + "--disable-output-logs", + "--log-level", + "ERROR", + ], + capture_output=True, + text=True, + timeout=60, + ) + + self.assertEqual(result.returncode, 0, "Filtered run should succeed") + self.assertIn("toy_core_ops", result.stdout) + + def test_invalid_backend_error(self): + """Test that model suite rejects invalid backends.""" + result = subprocess.run( + [ + "python", + "-m", + "BackendBench.scripts.main", + "--suite", + "model", + "--backend", + "aten", + "--disable-output-logs", + "--log-level", + "ERROR", + ], + capture_output=True, + text=True, + timeout=30, + ) + + self.assertNotEqual(result.returncode, 0, "Should fail with invalid backend") + self.assertIn("model suite only supports directory backend", result.stderr.lower()) + + def test_empty_filter(self): + """Test handling of nonexistent model filter.""" + result = subprocess.run( + [ + "python", + "-m", + "BackendBench.scripts.main", + "--suite", + "model", + "--backend", + "directory", + "--ops-directory", + "generated_kernels", + "--ops", + "nonexistent_model", + "--disable-output-logs", + "--log-level", + "ERROR", + ], + capture_output=True, + text=True, + timeout=60, + ) + + # Should succeed but with no models + self.assertEqual(result.returncode, 0, "Should succeed with empty filter") + self.assertIn("No model tests were run", result.stdout) + + +class TestModelSuiteIntegration(unittest.TestCase): + """Test ModelSuite integration and initialization.""" + + def test_initialization_variants(self): + """Test ModelSuite initialization with various options.""" + # Default initialization + suite1 = ModelSuite() + self.assertGreater(len(suite1.models), 0, "Should load models by default") + + # With filter + suite2 = ModelSuite(filter=["toy_core_ops"]) + self.assertEqual(len(suite2.models), 1, "Should load exactly 1 model") + self.assertEqual(suite2.models[0]["name"], "toy_core_ops") + + # Empty filter + suite3 = ModelSuite(filter=["nonexistent"]) + self.assertEqual(len(suite3.models), 0, "Should load no models with invalid filter") + + def test_operator_level_integration(self): + """Test that operator-level testing works via __iter__.""" + suite = ModelSuite() + op_tests = list(suite) + + self.assertGreater(len(op_tests), 0, "Should generate operator tests") + + # Verify first test structure + if len(op_tests) > 0: + first_test = op_tests[0] 
+ self.assertTrue(hasattr(first_test, "op")) + self.assertTrue(hasattr(first_test, "correctness_tests")) + self.assertTrue(hasattr(first_test, "performance_tests")) + + def test_model_level_integration(self): + """Test that model-level testing works.""" + suite = ModelSuite() + results = suite.test_model_correctness() + + self.assertIsInstance(results, dict) + self.assertGreater(len(results), 0, "Should have results") + + # Verify structure + for model_name, config_results in results.items(): + self.assertIsInstance(config_results, dict) + for config_name, is_correct in config_results.items(): + self.assertIsInstance(is_correct, bool) + + def test_output_format(self): + """Test that CLI output is properly formatted.""" + result = subprocess.run( + [ + "python", + "-m", + "BackendBench.scripts.main", + "--suite", + "model", + "--backend", + "directory", + "--ops-directory", + "generated_kernels", + "--disable-output-logs", + "--log-level", + "ERROR", + ], + capture_output=True, + text=True, + timeout=60, + ) + + output = result.stdout + + # Check for expected sections + self.assertIn("correctness score", output.lower()) + self.assertIn("performance score", output.lower()) + self.assertIn("FULL MODEL TESTING", output) + self.assertIn("Model Correctness Results:", output) + self.assertIn("Model Suite Score:", output) + + # Check for formatting + self.assertIn("=" * 80, output) + self.assertIn("-" * 80, output) + + # Check for pass/fail indicators + has_pass_fail = "✓ PASS" in output or "✗ FAIL" in output + self.assertTrue(has_pass_fail, "Should show pass/fail indicators") + + +if __name__ == "__main__": + unittest.main(verbosity=2) From f2e1f619afcd73529d57a05a871428b031fec0d1 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 30 Sep 2025 09:23:41 +0000 Subject: [PATCH 03/16] edits --- BackendBench/eval_model.py | 215 +++++++++++++ BackendBench/scripts/main.py | 58 ++-- BackendBench/suite/model.py | 299 +++++++++++------- .../models/toy_core_ops/toy_core_ops.json | 34 +- test/test_model_suite_correctness.py | 79 +++-- test/test_model_suite_integration.py | 56 +++- 6 files changed, 524 insertions(+), 217 deletions(-) create mode 100644 BackendBench/eval_model.py diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py new file mode 100644 index 00000000..7783a5c3 --- /dev/null +++ b/BackendBench/eval_model.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Model-level evaluation utilities for testing full model correctness.""" + +import logging +import os +import traceback +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +from BackendBench.utils import deserialize_args + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCorrectnessTestResult: + """Result from testing a model configuration.""" + + model_name: str + test_name: str + is_correct: bool = False + error_msg: str = "" + error_type: str = "" + traceback: str = "" + output_match: bool = False + gradients_match: bool = False + num_gradients: int = 0 + + +def eval_model_correctness_test( + model_name: str, + model_class: type, + model_config: Dict[str, Any], + test_name: str, + test_args: str, + kernel_dir: str = None, + atol: float = 1e-6, + rtol: float = 1e-5, +) -> ModelCorrectnessTestResult: + """Evaluate model correctness by comparing eager vs backend execution. 
+ + Similar to eval_correctness_test in eval.py, but for full models instead of individual ops. + + Args: + model_name: Name of the model being tested + model_class: Model class to instantiate + model_config: Model configuration dict with init_args + test_name: Name of this test configuration + test_args: Serialized arguments string for forward pass + kernel_dir: Optional directory containing kernels for backend + atol: Absolute tolerance for torch.allclose + rtol: Relative tolerance for torch.allclose + + Returns: + ModelCorrectnessTestResult with detailed comparison results + """ + try: + # Run in eager mode (reference) + eager_out, eager_grads = _run_model( + model_class, model_config, test_args, backend_enabled=False, kernel_dir=kernel_dir + ) + + # Run with backend (implementation) + backend_out, backend_grads = _run_model( + model_class, model_config, test_args, backend_enabled=True, kernel_dir=kernel_dir + ) + + # Compare outputs + output_match = torch.allclose(eager_out, backend_out, atol=atol, rtol=rtol) + + # Compare gradients + gradients_match = True + if len(eager_grads) != len(backend_grads): + gradients_match = False + else: + for eager_grad, backend_grad in zip(eager_grads, backend_grads): + if not torch.allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): + gradients_match = False + break + + is_correct = output_match and gradients_match + + return ModelCorrectnessTestResult( + model_name=model_name, + test_name=test_name, + is_correct=is_correct, + output_match=output_match, + gradients_match=gradients_match, + num_gradients=len(eager_grads), + ) + + except Exception as e: + error_msg = f"Model {model_name}::{test_name} failed: {e}" + logger.error(error_msg) + return ModelCorrectnessTestResult( + model_name=model_name, + test_name=test_name, + is_correct=False, + error_msg=error_msg, + error_type=str(type(e)), + traceback=traceback.format_exc(), + ) + + +def _run_model( + model_class: type, + model_config: Dict[str, Any], + test_args: str, + backend_enabled: bool, + kernel_dir: str = None, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Run model with or without backend enabled. + + Args: + model_class: Model class to instantiate + model_config: Model configuration dict with init_args + test_args: Serialized arguments string for forward pass + backend_enabled: If True, use BackendBench context manager + kernel_dir: Optional directory containing kernels + + Returns: + Tuple of (output, gradients) where: + - output: Model output tensor (detached) + - gradients: List of gradient tensors [input_grad, param1_grad, ...] 
+ """ + import BackendBench + + # Deserialize test arguments + args, kwargs = deserialize_args(test_args) + + # Extract model initialization args + init_args = model_config.get("init_args", {}).copy() + + # Handle seed: use runtime_seed if required, otherwise use seed from init_args + if model_config.get("requires_init_seed", False): + # Use the generated runtime seed + seed = model_config["runtime_seed"] + init_args["seed"] = seed + else: + # Use seed from init_args or default + seed = init_args.get("seed", 42) + + # Set seed for deterministic behavior + torch.manual_seed(seed) + + # Create fresh model instance + model = model_class(**init_args) + model.train() + + # Move model to same device as input (typically CUDA) + # Check both args and kwargs for tensor + input_tensor = None + if args and isinstance(args[0], torch.Tensor): + input_tensor = args[0] + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): + input_tensor = kwargs["x"] + + if input_tensor is not None: + device = input_tensor.device + model = model.to(device) + + # Ensure input has requires_grad for gradient computation + if args and isinstance(args[0], torch.Tensor): + x = args[0] + if not x.requires_grad: + x = x.clone().detach().requires_grad_(True) + args = [x] + list(args[1:]) + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): + x = kwargs["x"] + if not x.requires_grad: + x = x.clone().detach().requires_grad_(True) + kwargs["x"] = x + + # Run forward + backward with or without backend + if backend_enabled: + # Use context manager to enable backend + if kernel_dir is None: + kernel_dir = os.path.join(os.getcwd(), "generated_kernels") + + with BackendBench.BackendBench.enable(kernel_dir=kernel_dir): + output = model(*args, **kwargs) + loss = output.sum() + loss.backward() + else: + # Run in eager mode (no backend) + output = model(*args, **kwargs) + loss = output.sum() + loss.backward() + + # Collect gradients: [input_grad, param1_grad, ...] 
+ grads = [] + + # Input gradient - check both args and kwargs + input_grad = None + if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None: + input_grad = args[0].grad + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None: + input_grad = kwargs["x"].grad + + if input_grad is not None: + grads.append(input_grad.clone()) + + # Parameter gradients + for param in model.parameters(): + if param.grad is not None: + grads.append(param.grad.clone()) + + return output.detach(), grads diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 9422d0db..4f5d2bb4 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -64,7 +64,13 @@ def setup_logging(log_level): "--ops", default=None, type=str, - help="Comma-separated list of ops to run", + help="Comma-separated list of ops to run (not supported for model suite)", +) +@click.option( + "--model-filter", + default=None, + type=str, + help="Comma-separated list of models to run (only for model suite)", ) @click.option( "--topn-inputs", @@ -148,6 +154,7 @@ def cli( suite, backend, ops, + model_filter, topn_inputs, llm_attempts, llm_model, @@ -167,12 +174,22 @@ def cli( if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") - if suite == "model" and backend != "directory": - raise ValueError("model suite only supports directory backend") + if suite == "model": + if backend != "directory": + raise ValueError("model suite only supports directory backend") + if ops is not None: + raise ValueError( + "--ops filter is not supported for model suite. Use --model-filter instead" + ) + + if suite != "model" and model_filter is not None: + raise ValueError("--model-filter is only supported for model suite") setup_logging(log_level) if ops: ops = ops.split(",") + if model_filter: + model_filter = model_filter.split(",") suite = { "smoke": lambda: SmokeTestSuite, @@ -195,7 +212,7 @@ def cli( torch.bfloat16, filter=ops, ), - "model": lambda: ModelSuite(filter=ops), + "model": lambda: ModelSuite(filter=model_filter), }[suite]() backend_name = backend @@ -298,40 +315,11 @@ def cli( ) # Add full model testing for model suite - if suite.name == "model" and hasattr(suite, 'test_model_correctness'): - print("\n" + "="*80) - print("FULL MODEL TESTING") - print("="*80) - + if suite.name == "model": # Pass ops_directory as kernel_dir for directory backend kernel_dir = ops_directory if backend_name == "directory" else None model_results = suite.test_model_correctness(kernel_dir=kernel_dir) - - # Print results - print("\nModel Correctness Results:") - print("-"*80) - total_passed = 0 - total_tests = 0 - for model_name, test_results in model_results.items(): - passed = sum(1 for result in test_results.values() if result) - total = len(test_results) - total_passed += passed - total_tests += total - percentage = (passed / total * 100) if total > 0 else 0 - print(f" {model_name}: {passed}/{total} configs passed ({percentage:.1f}%)") - - # Show individual config results - for config_name, is_correct in test_results.items(): - status = "✓ PASS" if is_correct else "✗ FAIL" - print(f" {config_name}: {status}") - - print("-"*80) - if total_tests > 0: - overall_percentage = total_passed / total_tests * 100 - print(f"\nModel Suite Score: {total_passed}/{total_tests} ({overall_percentage:.1f}%)") - else: - print("\nNo model tests were run") - print("="*80) + suite.print_model_correctness_results(model_results) command = 
"python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index d9ca3050..fa502e75 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -12,20 +12,37 @@ 2. Model-level correctness testing (via test_model_correctness, new functionality) """ -import json -import os import importlib.util +import json import logging +import os +from typing import Any, Dict, List, Optional + import torch -from typing import Dict, List, Any, Optional -from BackendBench.utils import get_pytorch_op -from .torchbench import TorchBenchTestSuite, TorchBenchTest +from BackendBench.data_loaders import load_ops_from_source, op_list_to_benchmark_dict +from BackendBench.utils import deserialize_args + +from .torchbench import TorchBenchOpTest, TorchBenchTestSuite logger = logging.getLogger(__name__) +# Cache for torchbench ops to avoid reloading +_TORCHBENCH_OPS_CACHE = None + -def load_toy_models(toy_models_dir: str = "models", filter: Optional[List[str]] = None) -> List[Dict[str, Any]]: +def _get_torchbench_ops(): + """Get list of available ops from torchbench dataset (cached).""" + global _TORCHBENCH_OPS_CACHE + if _TORCHBENCH_OPS_CACHE is None: + ops_list = load_ops_from_source(source=None, format="parquet") + _TORCHBENCH_OPS_CACHE = set(op_list_to_benchmark_dict(ops_list).keys()) + return _TORCHBENCH_OPS_CACHE + + +def load_toy_models( + toy_models_dir: str = "models", filter: Optional[List[str]] = None +) -> List[Dict[str, Any]]: """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json Args: @@ -41,8 +58,7 @@ def load_toy_models(toy_models_dir: str = "models", filter: Optional[List[str]] models = [] if not os.path.exists(toy_models_dir): - logger.warning(f"Toy models directory not found: {toy_models_dir}") - return models + raise FileNotFoundError(f"Toy models directory not found: {toy_models_dir}") for model_name in os.listdir(toy_models_dir): # Apply filter if specified @@ -59,16 +75,22 @@ def load_toy_models(toy_models_dir: str = "models", filter: Optional[List[str]] # Check both files exist if not os.path.exists(model_file): + if filter is not None and model_name in filter: + # If the model was explicitly requested but not found, raise an error + raise FileNotFoundError(f"Model file not found: {model_file}") logger.warning(f"Model file not found: {model_file}") continue if not os.path.exists(config_file): + if filter is not None and model_name in filter: + # If the model was explicitly requested but not found, raise an error + raise FileNotFoundError(f"Config file not found: {config_file}") logger.warning(f"Config file not found: {config_file}") continue try: # Load config - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = json.load(f) # Load model class dynamically @@ -80,9 +102,11 @@ def load_toy_models(toy_models_dir: str = "models", filter: Optional[List[str]] model_class = None for attr_name in dir(module): attr = getattr(module, attr_name) - if (isinstance(attr, type) and - attr_name.endswith("Model") and - hasattr(attr, "forward")): + if ( + isinstance(attr, type) + and attr_name.endswith("Model") + and hasattr(attr, "forward") + ): model_class = attr break @@ -90,63 +114,54 @@ def load_toy_models(toy_models_dir: str = "models", filter: Optional[List[str]] logger.error(f"No model class found in {model_file}") continue - models.append({ - "name": model_name, - "class": model_class, - "config": config - }) + # Generate runtime seed if 
required + if config.get("model_config", {}).get("requires_init_seed", False): + import random + + runtime_seed = random.randint(0, 2**31 - 1) + config["model_config"]["runtime_seed"] = runtime_seed + logger.debug(f"Generated runtime seed {runtime_seed} for {model_name}") + + models.append({"name": model_name, "class": model_class, "config": config}) logger.info(f"Loaded model: {model_name}") except Exception as e: + if filter is not None and model_name in filter: + # If the model was explicitly requested but failed to load, raise an error + raise RuntimeError(f"Failed to load model {model_name}: {e}") logger.error(f"Failed to load {model_name}: {e}") continue + # If a filter was specified but no models were loaded, raise an error + if filter is not None and len(models) == 0: + raise ValueError(f"No models found matching filter: {filter}") + return models -def _create_torchbench_op_test(model_name: str, model_class, config: Dict[str, Any], op_name: str): - """Create a TorchBenchOpTest for a specific operator from a toy model. +def _create_op_test(op_name: str, inputs: List[str]): + """Create a TorchBenchOpTest for a specific operator. Args: - model_name: Name of the model - model_class: Model class - config: Configuration dictionary - op_name: Operator name (e.g., "conv2d", "relu") + op_name: Operator name in aten format (e.g., "aten.conv2d.default") + inputs: List of serialized input strings Returns: TorchBenchOpTest instance compatible with existing evaluation infrastructure + + Raises: + ValueError: If the op is not in the torchbench dataset """ - from .torchbench import TorchBenchOpTest - from BackendBench.utils import serialize_args - - # Generate test inputs from model configs - inputs = [] - for test_config in config["test_configs"]: - # Extract input shape from test config - forward_args = test_config["forward_args"] - batch_size = forward_args["batch_size"] - input_shape = forward_args["input_shape"] - - # Create input tensor - full_shape = [batch_size] + input_shape - input_tensor = torch.randn(*full_shape) - - # Serialize the input for TorchBenchOpTest - # TorchBenchOpTest expects serialized inputs (strings) - serialized = serialize_args([input_tensor], {}) - inputs.append(serialized) - - # Create TorchBenchOpTest - it expects op name and inputs - # We need to convert op_name to full torch.ops format - op = get_pytorch_op(op_name) - if op is None: - raise ValueError(f"Could not find PyTorch operation for {op_name}") - - # Get the full op string (e.g., "aten.conv2d.default") - op_str = str(op).replace("torch.ops.", "") + # Check that the op is in the torchbench dataset + torchbench_ops = _get_torchbench_ops() + if op_name not in torchbench_ops: + raise ValueError( + f"Operator {op_name} is not in the torchbench dataset. " + f"Only ops from the torchbench dataset can be tested." + ) # Create the test with serialized inputs - return TorchBenchOpTest(op_str, inputs, topn=None) + return TorchBenchOpTest(op_name, inputs, topn=None) class FullModelTest: @@ -156,19 +171,28 @@ class FullModelTest: in both eager mode and backend mode, then comparing the results. """ - def __init__(self, model_name: str, model_class, config: Dict[str, Any], test_config: Dict[str, Any]): + def __init__( + self, + model_name: str, + model_class, + model_config: Dict[str, Any], + test_name: str, + test_args: str, + ): """Initialize FullModelTest. 
Args: model_name: Name of the model being tested model_class: Model class to instantiate - config: Full model configuration including model_config - test_config: Specific test configuration with forward_args + model_config: Model configuration dict with init_args + test_name: Name of this test configuration + test_args: Serialized arguments string for forward pass """ self.model_name = model_name self.model_class = model_class - self.config = config - self.test_config = test_config + self.model_config = model_config + self.test_name = test_name + self.test_args = test_args def run_with_backend(self, backend_enabled: bool, kernel_dir: str = None) -> tuple: """Run model with backend enabled or disabled. @@ -184,27 +208,51 @@ def run_with_backend(self, backend_enabled: bool, kernel_dir: str = None) -> tup """ import BackendBench - # Extract model configuration - model_config = self.config["model_config"]["init_args"] + # Deserialize test arguments + args, kwargs = deserialize_args(self.test_args) - # Extract input configuration - forward_args = self.test_config["forward_args"] - batch_size = forward_args["batch_size"] - input_shape = forward_args["input_shape"] + # Extract model initialization args + init_args = self.model_config.get("init_args", {}).copy() - # Create full input shape: [batch_size, *input_shape] - full_shape = [batch_size] + input_shape + # Handle seed: use runtime_seed if required, otherwise use seed from init_args + if self.model_config.get("requires_init_seed", False): + # Use the generated runtime seed + seed = self.model_config["runtime_seed"] + init_args["seed"] = seed + else: + # Use seed from init_args or default + seed = init_args.get("seed", 42) # Set seed for deterministic behavior - seed = model_config.get("seed", 42) torch.manual_seed(seed) # Create fresh model instance - model = self.model_class(**model_config) + model = self.model_class(**init_args) model.train() - # Create input tensor with requires_grad for input gradient - x = torch.randn(*full_shape, requires_grad=True) + # Move model to same device as input (typically CUDA) + # Check both args and kwargs for tensor + input_tensor = None + if args and isinstance(args[0], torch.Tensor): + input_tensor = args[0] + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): + input_tensor = kwargs["x"] + + if input_tensor is not None: + device = input_tensor.device + model = model.to(device) + + # Ensure input has requires_grad for gradient computation + if args and isinstance(args[0], torch.Tensor): + x = args[0] + if not x.requires_grad: + x = x.clone().detach().requires_grad_(True) + args = [x] + list(args[1:]) + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): + x = kwargs["x"] + if not x.requires_grad: + x = x.clone().detach().requires_grad_(True) + kwargs["x"] = x # Run forward + backward with or without backend if backend_enabled: @@ -214,21 +262,29 @@ def run_with_backend(self, backend_enabled: bool, kernel_dir: str = None) -> tup kernel_dir = os.path.join(os.getcwd(), "generated_kernels") with BackendBench.BackendBench.enable(kernel_dir=kernel_dir): - output = model(x) + output = model(*args, **kwargs) loss = output.sum() loss.backward() else: # Run in eager mode (no backend) - output = model(x) + output = model(*args, **kwargs) loss = output.sum() loss.backward() # Collect gradients: [input_grad, param1_grad, param2_grad, ...] 
grads = [] - # Input gradient - if x.grad is not None: - grads.append(x.grad.clone()) + # Input gradient - check both args and kwargs + input_grad = None + if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None: + input_grad = args[0].grad + elif ( + "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None + ): + input_grad = kwargs["x"].grad + + if input_grad is not None: + grads.append(input_grad.clone()) # Parameter gradients for param in model.parameters(): @@ -257,13 +313,13 @@ def test_correctness(self, atol=1e-6, rtol=1e-5, kernel_dir: str = None) -> bool # Compare outputs if not torch.allclose(eager_out, backend_out, atol=atol, rtol=rtol): - logger.debug(f"{self.model_name}::{self.test_config['name']}: Output mismatch") + logger.debug(f"{self.model_name}::{self.test_name}: Output mismatch") return False # Compare number of gradients if len(eager_grads) != len(backend_grads): logger.debug( - f"{self.model_name}::{self.test_config['name']}: " + f"{self.model_name}::{self.test_name}: " f"Gradient count mismatch ({len(eager_grads)} vs {len(backend_grads)})" ) return False @@ -271,16 +327,13 @@ def test_correctness(self, atol=1e-6, rtol=1e-5, kernel_dir: str = None) -> bool # Compare each gradient for i, (eager_grad, backend_grad) in enumerate(zip(eager_grads, backend_grads)): if not torch.allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): - logger.debug( - f"{self.model_name}::{self.test_config['name']}: " - f"Gradient {i} mismatch" - ) + logger.debug(f"{self.model_name}::{self.test_name}: Gradient {i} mismatch") return False return True except Exception as e: - logger.error(f"{self.model_name}::{self.test_config['name']}: Correctness test failed: {e}") + logger.error(f"{self.model_name}::{self.test_name}: Correctness test failed: {e}") return False @@ -292,7 +345,9 @@ class ModelSuite(TorchBenchTestSuite): 2. Model-level correctness testing via test_model_correctness() (Model Suite specific) """ - def __init__(self, name: str = "model", filter: Optional[List[str]] = None, models_dir: str = None): + def __init__( + self, name: str = "model", filter: Optional[List[str]] = None, models_dir: str = None + ): """Initialize ModelSuite. Args: @@ -316,32 +371,15 @@ def __init__(self, name: str = "model", filter: Optional[List[str]] = None, mode def __iter__(self): """Yield operator tests from all models (TorchBench approach). - This method enables operator-level testing using the inherited - TorchBench infrastructure. Returns TorchBenchOpTest instances. + This method extracts operators by tracing model execution, then creates + TorchBenchOpTest instances for testing via the inherited infrastructure. 
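        A rough usage sketch (mirroring how scripts/main.py drives the suite in this
        patch; note that the operator-level path below is still a stub and yields
        nothing):

            suite = ModelSuite(filter=["toy_core_ops"])
            op_tests = list(suite)  # empty for now
            results = suite.test_model_correctness(kernel_dir="generated_kernels")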
""" - for model in self.models: - # Extract operators from config - if "expected_operators" not in model["config"]: - logger.warning(f"Model {model['name']} has no expected_operators in config") - continue + # Trace each model to extract operators - expected_ops = model["config"]["expected_operators"] - - # Yield forward pass operators - if "forward_pass" in expected_ops: - for op_name in expected_ops["forward_pass"]: - try: - yield _create_torchbench_op_test(model["name"], model["class"], model["config"], op_name) - except Exception as e: - logger.error(f"Failed to create test for forward op {op_name}: {e}") - - # Yield backward pass operators - if "backward_pass" in expected_ops: - for op_name in expected_ops["backward_pass"]: - try: - yield _create_torchbench_op_test(model["name"], model["class"], model["config"], op_name) - except Exception as e: - logger.error(f"Failed to create test for backward op {op_name}: {e}") + # For now, return empty iterator since operator extraction from model tracing + # is not yet implemented. The model suite focuses on full model testing via + # test_model_correctness() method. + return iter([]) def test_model_correctness(self, kernel_dir: str = None) -> Dict[str, Dict[str, bool]]: """Test full model correctness for all models and configurations. @@ -362,15 +400,20 @@ def test_model_correctness(self, kernel_dir: str = None) -> Dict[str, Dict[str, model_results = {} # Test each configuration for this model - for test_config in model["config"]["test_configs"]: + if "model_tests" not in model["config"]: + logger.warning(f"Model {model['name']} has no model_tests in config") + continue + + # model_tests is a dict mapping test_name -> serialized_args + for test_name, test_args in model["config"]["model_tests"].items(): test = FullModelTest( model_name=model["name"], model_class=model["class"], - config=model["config"], - test_config=test_config + model_config=model["config"].get("model_config", {}), + test_name=test_name, + test_args=test_args, ) - test_name = test_config["name"] is_correct = test.test_correctness(kernel_dir=kernel_dir) model_results[test_name] = is_correct @@ -380,3 +423,39 @@ def test_model_correctness(self, kernel_dir: str = None) -> Dict[str, Dict[str, results[model["name"]] = model_results return results + + def print_model_correctness_results(self, results: Dict[str, Dict[str, bool]]): + """Print formatted model correctness results. 
+ + Args: + results: Dictionary mapping model_name -> {test_name -> bool} + """ + print("\n" + "=" * 80) + print("FULL MODEL TESTING") + print("=" * 80) + print("\nModel Correctness Results:") + print("-" * 80) + + total_passed = 0 + total_tests = 0 + + for model_name, test_results in results.items(): + passed = sum(1 for result in test_results.values() if result) + total = len(test_results) + total_passed += passed + total_tests += total + percentage = (passed / total * 100) if total > 0 else 0 + print(f" {model_name}: {passed}/{total} configs passed ({percentage:.1f}%)") + + # Show individual config results + for config_name, is_correct in test_results.items(): + status = "✓ PASS" if is_correct else "✗ FAIL" + print(f" {config_name}: {status}") + + print("-" * 80) + if total_tests > 0: + overall_percentage = total_passed / total_tests * 100 + print(f"\nModel Suite Score: {total_passed}/{total_tests} ({overall_percentage:.1f}%)") + else: + print("\nNo model tests were run") + print("=" * 80) diff --git a/BackendBench/suite/models/toy_core_ops/toy_core_ops.json b/BackendBench/suite/models/toy_core_ops/toy_core_ops.json index 7cad404e..b9bab0e6 100644 --- a/BackendBench/suite/models/toy_core_ops/toy_core_ops.json +++ b/BackendBench/suite/models/toy_core_ops/toy_core_ops.json @@ -1,39 +1,17 @@ { "model_config": { + "requires_init_seed": true, "init_args": { "in_channels": 3, "hidden_channels": 32, "out_channels": 8, - "num_groups": 8, - "seed": 42 + "num_groups": 8 } }, - "test_configs": [ - { - "name": "small_batch", - "forward_args": { - "batch_size": 2, - "input_shape": [3, 32, 32] - } - }, - { - "name": "medium_batch", - "forward_args": { - "batch_size": 4, - "input_shape": [3, 64, 64] - } - }, - { - "name": "large_input", - "forward_args": { - "batch_size": 2, - "input_shape": [3, 128, 128] - } - } - ], - "expected_operators": { - "forward_pass": ["conv2d", "native_group_norm", "relu", "max_pool2d_with_indices", "avg_pool2d", "adaptive_avg_pool2d"], - "backward_pass": ["convolution_backward", "native_group_norm_backward", "threshold_backward", "max_pool2d_with_indices_backward", "avg_pool2d_backward", "_adaptive_avg_pool2d_backward"] + "model_tests": { + "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", + "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", + "large_input": "([], {'x': T([2, 3, 128, 128], f32)})" }, "metadata": { "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" diff --git a/test/test_model_suite_correctness.py b/test/test_model_suite_correctness.py index 4e557022..7508c2ac 100644 --- a/test/test_model_suite_correctness.py +++ b/test/test_model_suite_correctness.py @@ -40,12 +40,14 @@ def setUpClass(cls): def test_initialization(self): """Test FullModelTest can be instantiated correctly.""" - test_config = self.model["config"]["test_configs"][0] + test_name = list(self.model["config"]["model_tests"].keys())[0] + test_args = self.model["config"]["model_tests"][test_name] full_test = FullModelTest( model_name=self.model["name"], model_class=self.model["class"], - config=self.model["config"], - test_config=test_config, + model_config=self.model["config"]["model_config"], + test_name=test_name, + test_args=test_args, ) self.assertEqual(full_test.model_name, self.model["name"]) @@ -53,16 +55,20 @@ def test_initialization(self): def test_eager_execution(self): """Test model runs correctly in eager mode.""" - test_config = self.model["config"]["test_configs"][0] + test_name = 
"small_batch" + test_args = self.model["config"]["model_tests"][test_name] full_test = FullModelTest( - self.model["name"], self.model["class"], self.model["config"], test_config + self.model["name"], + self.model["class"], + self.model["config"]["model_config"], + test_name, + test_args, ) output, grads = full_test.run_with_backend(backend_enabled=False) - # Verify output shape - batch_size = test_config["forward_args"]["batch_size"] - expected_shape = torch.Size([batch_size, 8, 4, 4]) + # Verify output shape (batch_size=2 from small_batch config) + expected_shape = torch.Size([2, 8, 4, 4]) self.assertEqual(output.shape, expected_shape) # Verify gradients computed @@ -76,16 +82,20 @@ def test_eager_execution(self): def test_backend_execution(self): """Test model runs with backend enabled.""" - test_config = self.model["config"]["test_configs"][0] + test_name = "small_batch" + test_args = self.model["config"]["model_tests"][test_name] full_test = FullModelTest( - self.model["name"], self.model["class"], self.model["config"], test_config + self.model["name"], + self.model["class"], + self.model["config"]["model_config"], + test_name, + test_args, ) output, grads = full_test.run_with_backend(backend_enabled=True) - # Verify output shape - batch_size = test_config["forward_args"]["batch_size"] - expected_shape = torch.Size([batch_size, 8, 4, 4]) + # Verify output shape (batch_size=2 from small_batch config) + expected_shape = torch.Size([2, 8, 4, 4]) self.assertEqual(output.shape, expected_shape) # Verify gradients computed @@ -93,9 +103,14 @@ def test_backend_execution(self): def test_correctness_comparison(self): """Test correctness comparison between eager and backend.""" - test_config = self.model["config"]["test_configs"][0] + test_name = "small_batch" + test_args = self.model["config"]["model_tests"][test_name] full_test = FullModelTest( - self.model["name"], self.model["class"], self.model["config"], test_config + self.model["name"], + self.model["class"], + self.model["config"]["model_config"], + test_name, + test_args, ) is_correct = full_test.test_correctness() @@ -108,17 +123,28 @@ def test_correctness_comparison(self): def test_multiple_configs(self): """Test all model configurations run correctly.""" - for test_config in self.model["config"]["test_configs"]: + # Expected shapes for each config + # Note: Output size is always 4x4 due to adaptive_avg_pool2d([4, 4]) + expected_shapes = { + "small_batch": torch.Size([2, 8, 4, 4]), + "medium_batch": torch.Size([4, 8, 4, 4]), + "large_input": torch.Size([2, 8, 4, 4]), + } + + for test_name, test_args in self.model["config"]["model_tests"].items(): full_test = FullModelTest( - self.model["name"], self.model["class"], self.model["config"], test_config + self.model["name"], + self.model["class"], + self.model["config"]["model_config"], + test_name, + test_args, ) output, grads = full_test.run_with_backend(backend_enabled=False) - batch_size = test_config["forward_args"]["batch_size"] - expected_shape = torch.Size([batch_size, 8, 4, 4]) - self.assertEqual(output.shape, expected_shape, f"Config {test_config['name']} failed") - self.assertGreater(len(grads), 0, f"Config {test_config['name']} has no gradients") + expected_shape = expected_shapes[test_name] + self.assertEqual(output.shape, expected_shape, f"Config {test_name} failed") + self.assertGreater(len(grads), 0, f"Config {test_name} has no gradients") class TestModelSuite(unittest.TestCase): @@ -166,13 +192,10 @@ def test_results_aggregation(self): self.assertGreaterEqual(total_passed, 
0, "Passed >= 0") def test_empty_filter(self): - """Test suite handles empty model list gracefully.""" - suite = ModelSuite(filter=["nonexistent_model"]) - self.assertEqual(len(suite.models), 0, "Should have no models") - - # Should not crash when running on empty list - results = suite.test_model_correctness() - self.assertEqual(len(results), 0, "Should have no results") + """Test suite raises error for nonexistent model.""" + with self.assertRaises(ValueError) as context: + _ = ModelSuite(filter=["nonexistent_model"]) + self.assertIn("No models found", str(context.exception)) if __name__ == "__main__": diff --git a/test/test_model_suite_integration.py b/test/test_model_suite_integration.py index 43d04c4a..87bb6422 100644 --- a/test/test_model_suite_integration.py +++ b/test/test_model_suite_integration.py @@ -67,7 +67,7 @@ def test_filtering_by_model(self): "directory", "--ops-directory", "generated_kernels", - "--ops", + "--model-filter", "toy_core_ops", "--disable-output-logs", "--log-level", @@ -117,7 +117,7 @@ def test_empty_filter(self): "directory", "--ops-directory", "generated_kernels", - "--ops", + "--model-filter", "nonexistent_model", "--disable-output-logs", "--log-level", @@ -128,9 +128,36 @@ def test_empty_filter(self): timeout=60, ) - # Should succeed but with no models - self.assertEqual(result.returncode, 0, "Should succeed with empty filter") - self.assertIn("No model tests were run", result.stdout) + # Should fail because explicitly requested model not found + self.assertNotEqual(result.returncode, 0, "Should fail with nonexistent filter") + + def test_ops_filter_rejected(self): + """Test that --ops filter is rejected for model suite.""" + result = subprocess.run( + [ + "python", + "-m", + "BackendBench.scripts.main", + "--suite", + "model", + "--backend", + "directory", + "--ops-directory", + "generated_kernels", + "--ops", + "toy_core_ops", + "--disable-output-logs", + "--log-level", + "ERROR", + ], + capture_output=True, + text=True, + timeout=30, + ) + + # Should fail with error message about --ops not supported + self.assertNotEqual(result.returncode, 0, "Should fail with --ops") + self.assertIn("--ops filter is not supported for model suite", result.stderr) class TestModelSuiteIntegration(unittest.TestCase): @@ -147,23 +174,20 @@ def test_initialization_variants(self): self.assertEqual(len(suite2.models), 1, "Should load exactly 1 model") self.assertEqual(suite2.models[0]["name"], "toy_core_ops") - # Empty filter - suite3 = ModelSuite(filter=["nonexistent"]) - self.assertEqual(len(suite3.models), 0, "Should load no models with invalid filter") + # Empty filter - should raise error + with self.assertRaises(ValueError) as context: + _ = ModelSuite(filter=["nonexistent"]) + self.assertIn("No models found", str(context.exception)) def test_operator_level_integration(self): """Test that operator-level testing works via __iter__.""" suite = ModelSuite() op_tests = list(suite) - self.assertGreater(len(op_tests), 0, "Should generate operator tests") - - # Verify first test structure - if len(op_tests) > 0: - first_test = op_tests[0] - self.assertTrue(hasattr(first_test, "op")) - self.assertTrue(hasattr(first_test, "correctness_tests")) - self.assertTrue(hasattr(first_test, "performance_tests")) + # Model suite currently returns empty iterator + # Operator extraction from model tracing is not yet implemented + # The suite focuses on full model testing via test_model_correctness() + self.assertEqual(len(op_tests), 0, "Operator extraction not yet implemented") def 
test_model_level_integration(self): """Test that model-level testing works.""" From 20f18899b4b8e151ebcf5885e9579465bb230991 Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 30 Sep 2025 09:41:13 +0000 Subject: [PATCH 04/16] edits --- BackendBench/eval_model.py | 139 +++++++--- BackendBench/suite/model.py | 400 +++++++-------------------- test/test_model_suite_correctness.py | 200 +++----------- 3 files changed, 237 insertions(+), 502 deletions(-) diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index 7783a5c3..d01fb9e1 100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -109,6 +109,102 @@ def eval_model_correctness_test( ) +def _get_input_tensor(args: List[Any], kwargs: Dict[str, Any]) -> torch.Tensor: + """Extract input tensor from args or kwargs. + + Args: + args: Positional arguments list + kwargs: Keyword arguments dict + + Returns: + Input tensor if found, None otherwise + """ + if args and isinstance(args[0], torch.Tensor): + return args[0] + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): + return kwargs["x"] + return None + + +def _move_model_to_input_device( + model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any] +) -> torch.nn.Module: + """Move model to the same device as input tensor. + + Args: + model: Model to move + args: Positional arguments list + kwargs: Keyword arguments dict + + Returns: + Model on input device (or original model if no input tensor found) + """ + input_tensor = _get_input_tensor(args, kwargs) + if input_tensor is not None: + device = input_tensor.device + model = model.to(device) + return model + + +def _ensure_input_requires_grad( + args: List[Any], kwargs: Dict[str, Any] +) -> Tuple[List[Any], Dict[str, Any]]: + """Ensure input tensor has requires_grad=True for gradient computation. + + Args: + args: Positional arguments list + kwargs: Keyword arguments dict + + Returns: + Updated (args, kwargs) with input tensor requiring gradients + """ + if args and isinstance(args[0], torch.Tensor): + x = args[0] + if not x.requires_grad: + x = x.clone().detach().requires_grad_(True) + args = [x] + list(args[1:]) + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): + x = kwargs["x"] + if not x.requires_grad: + x = x.clone().detach().requires_grad_(True) + kwargs["x"] = x + + return args, kwargs + + +def _collect_gradients( + model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any] +) -> List[torch.Tensor]: + """Collect gradients from input and model parameters. + + Args: + model: Model with computed gradients + args: Positional arguments list + kwargs: Keyword arguments dict + + Returns: + List of gradient tensors [input_grad, param1_grad, ...] 
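        A minimal sketch of how these helpers fit together (mirroring _run_model
        below; model_class, model_config and test_args stand in for the caller's
        values):

            args, kwargs = deserialize_args(test_args)
            model = model_class(**model_config.get("init_args", {})).train()
            model = _move_model_to_input_device(model, args, kwargs)
            args, kwargs = _ensure_input_requires_grad(args, kwargs)
            output = model(*args, **kwargs)
            output.sum().backward()
            grads = _collect_gradients(model, args, kwargs)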
+ """ + grads = [] + + # Input gradient - check both args and kwargs + input_grad = None + if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None: + input_grad = args[0].grad + elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None: + input_grad = kwargs["x"].grad + + if input_grad is not None: + grads.append(input_grad.clone()) + + # Parameter gradients + for param in model.parameters(): + if param.grad is not None: + grads.append(param.grad.clone()) + + return grads + + def _run_model( model_class: type, model_config: Dict[str, Any], @@ -154,29 +250,11 @@ def _run_model( model = model_class(**init_args) model.train() - # Move model to same device as input (typically CUDA) - # Check both args and kwargs for tensor - input_tensor = None - if args and isinstance(args[0], torch.Tensor): - input_tensor = args[0] - elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): - input_tensor = kwargs["x"] - - if input_tensor is not None: - device = input_tensor.device - model = model.to(device) + # Move model to same device as input + model = _move_model_to_input_device(model, args, kwargs) # Ensure input has requires_grad for gradient computation - if args and isinstance(args[0], torch.Tensor): - x = args[0] - if not x.requires_grad: - x = x.clone().detach().requires_grad_(True) - args = [x] + list(args[1:]) - elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): - x = kwargs["x"] - if not x.requires_grad: - x = x.clone().detach().requires_grad_(True) - kwargs["x"] = x + args, kwargs = _ensure_input_requires_grad(args, kwargs) # Run forward + backward with or without backend if backend_enabled: @@ -194,22 +272,7 @@ def _run_model( loss = output.sum() loss.backward() - # Collect gradients: [input_grad, param1_grad, ...] - grads = [] - - # Input gradient - check both args and kwargs - input_grad = None - if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None: - input_grad = args[0].grad - elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None: - input_grad = kwargs["x"].grad - - if input_grad is not None: - grads.append(input_grad.clone()) - - # Parameter gradients - for param in model.parameters(): - if param.grad is not None: - grads.append(param.grad.clone()) + # Collect gradients + grads = _collect_gradients(model, args, kwargs) return output.detach(), grads diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index fa502e75..f67782a0 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -5,25 +5,24 @@ # LICENSE file in the root directory of this source tree. """ -Model Suite for testing toy models against backends. +Model Suite for testing operators traced from toy models. -This suite extends TorchBenchTestSuite to provide two testing approaches: -1. Operator-level testing (via __iter__, inherited infrastructure) -2. Model-level correctness testing (via test_model_correctness, new functionality) +This suite extends TorchBenchTestSuite by tracing model execution +to extract operators, then filtering the TorchBench dataset to only +include those operators. 
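A rough sketch of the resulting flow, using only the helpers defined below
(paths as in the test suite):

    models = load_toy_models("BackendBench/suite/models")
    traced = {op for m in models for op in _trace_model_ops(m["class"], m["config"])}
    optests = {name: inputs for name, inputs in _get_torchbench_ops().items()
               if any(op in name for op in traced)}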
""" import importlib.util import json import logging import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import torch from BackendBench.data_loaders import load_ops_from_source, op_list_to_benchmark_dict -from BackendBench.utils import deserialize_args -from .torchbench import TorchBenchOpTest, TorchBenchTestSuite +from .torchbench import TorchBenchTestSuite logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ def _get_torchbench_ops(): global _TORCHBENCH_OPS_CACHE if _TORCHBENCH_OPS_CACHE is None: ops_list = load_ops_from_source(source=None, format="parquet") - _TORCHBENCH_OPS_CACHE = set(op_list_to_benchmark_dict(ops_list).keys()) + _TORCHBENCH_OPS_CACHE = op_list_to_benchmark_dict(ops_list) return _TORCHBENCH_OPS_CACHE @@ -66,7 +65,7 @@ def load_toy_models( continue model_dir = os.path.join(toy_models_dir, model_name) - if not os.path.isdir(model_dir): + if not os.isdir(model_dir): continue # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json @@ -76,14 +75,12 @@ def load_toy_models( # Check both files exist if not os.path.exists(model_file): if filter is not None and model_name in filter: - # If the model was explicitly requested but not found, raise an error raise FileNotFoundError(f"Model file not found: {model_file}") logger.warning(f"Model file not found: {model_file}") continue if not os.path.exists(config_file): if filter is not None and model_name in filter: - # If the model was explicitly requested but not found, raise an error raise FileNotFoundError(f"Config file not found: {config_file}") logger.warning(f"Config file not found: {config_file}") continue @@ -114,20 +111,11 @@ def load_toy_models( logger.error(f"No model class found in {model_file}") continue - # Generate runtime seed if required - if config.get("model_config", {}).get("requires_init_seed", False): - import random - - runtime_seed = random.randint(0, 2**31 - 1) - config["model_config"]["runtime_seed"] = runtime_seed - logger.debug(f"Generated runtime seed {runtime_seed} for {model_name}") - models.append({"name": model_name, "class": model_class, "config": config}) logger.info(f"Loaded model: {model_name}") except Exception as e: if filter is not None and model_name in filter: - # If the model was explicitly requested but failed to load, raise an error raise RuntimeError(f"Failed to load model {model_name}: {e}") logger.error(f"Failed to load {model_name}: {e}") continue @@ -139,323 +127,123 @@ def load_toy_models( return models -def _create_op_test(op_name: str, inputs: List[str]): - """Create a TorchBenchOpTest for a specific operator. +def _trace_model_ops(model_class, model_config: Dict[str, Any]) -> Set[str]: + """Trace model execution to extract operator names. Args: - op_name: Operator name in aten format (e.g., "aten.conv2d.default") - inputs: List of serialized input strings + model_class: Model class to instantiate + model_config: Model configuration dict with init_args and model_tests Returns: - TorchBenchOpTest instance compatible with existing evaluation infrastructure - - Raises: - ValueError: If the op is not in the torchbench dataset + Set of operator names in aten format (e.g., "aten.conv2d.default") """ - # Check that the op is in the torchbench dataset - torchbench_ops = _get_torchbench_ops() - if op_name not in torchbench_ops: - raise ValueError( - f"Operator {op_name} is not in the torchbench dataset. " - f"Only ops from the torchbench dataset can be tested." 
- ) + import torch._dynamo as dynamo - # Create the test with serialized inputs - return TorchBenchOpTest(op_name, inputs, topn=None) + from BackendBench.utils import deserialize_args + init_args = model_config.get("init_args", {}) + model = model_class(**init_args) + model.eval() -class FullModelTest: - """Complete model forward/backward testing. - - This class handles running a model with a specific test configuration - in both eager mode and backend mode, then comparing the results. - """ - - def __init__( - self, - model_name: str, - model_class, - model_config: Dict[str, Any], - test_name: str, - test_args: str, - ): - """Initialize FullModelTest. - - Args: - model_name: Name of the model being tested - model_class: Model class to instantiate - model_config: Model configuration dict with init_args - test_name: Name of this test configuration - test_args: Serialized arguments string for forward pass - """ - self.model_name = model_name - self.model_class = model_class - self.model_config = model_config - self.test_name = test_name - self.test_args = test_args - - def run_with_backend(self, backend_enabled: bool, kernel_dir: str = None) -> tuple: - """Run model with backend enabled or disabled. + # Get first test input to trace with + model_tests = model_config.get("model_tests", {}) + if not model_tests: + raise ValueError("No model_tests found in config") - Args: - backend_enabled: If True, use BackendBench context manager to enable backend - kernel_dir: Optional directory containing kernels (for backend mode) + first_test = next(iter(model_tests.values())) + args, kwargs = deserialize_args(first_test) - Returns: - Tuple of (output, gradients) where: - - output: Model output tensor (detached) - - gradients: List of gradient tensors [input_grad, param1_grad, param2_grad, ...] 
- """ - import BackendBench - - # Deserialize test arguments - args, kwargs = deserialize_args(self.test_args) - - # Extract model initialization args - init_args = self.model_config.get("init_args", {}).copy() - - # Handle seed: use runtime_seed if required, otherwise use seed from init_args - if self.model_config.get("requires_init_seed", False): - # Use the generated runtime seed - seed = self.model_config["runtime_seed"] - init_args["seed"] = seed - else: - # Use seed from init_args or default - seed = init_args.get("seed", 42) - - # Set seed for deterministic behavior - torch.manual_seed(seed) - - # Create fresh model instance - model = self.model_class(**init_args) - model.train() - - # Move model to same device as input (typically CUDA) - # Check both args and kwargs for tensor - input_tensor = None - if args and isinstance(args[0], torch.Tensor): - input_tensor = args[0] - elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): - input_tensor = kwargs["x"] - - if input_tensor is not None: - device = input_tensor.device - model = model.to(device) - - # Ensure input has requires_grad for gradient computation - if args and isinstance(args[0], torch.Tensor): - x = args[0] - if not x.requires_grad: - x = x.clone().detach().requires_grad_(True) - args = [x] + list(args[1:]) - elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): - x = kwargs["x"] - if not x.requires_grad: - x = x.clone().detach().requires_grad_(True) - kwargs["x"] = x - - # Run forward + backward with or without backend - if backend_enabled: - # Use context manager to enable backend - if kernel_dir is None: - # Default to generated_kernels directory - kernel_dir = os.path.join(os.getcwd(), "generated_kernels") - - with BackendBench.BackendBench.enable(kernel_dir=kernel_dir): - output = model(*args, **kwargs) - loss = output.sum() - loss.backward() - else: - # Run in eager mode (no backend) - output = model(*args, **kwargs) - loss = output.sum() - loss.backward() - - # Collect gradients: [input_grad, param1_grad, param2_grad, ...] - grads = [] - - # Input gradient - check both args and kwargs - input_grad = None - if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None: - input_grad = args[0].grad - elif ( - "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None - ): - input_grad = kwargs["x"].grad - - if input_grad is not None: - grads.append(input_grad.clone()) - - # Parameter gradients - for param in model.parameters(): - if param.grad is not None: - grads.append(param.grad.clone()) - - return output.detach(), grads - - def test_correctness(self, atol=1e-6, rtol=1e-5, kernel_dir: str = None) -> bool: - """Test numerical correctness by comparing eager vs backend execution. 
+ # Trace the model to extract ops + ops = set() - Args: - atol: Absolute tolerance for torch.allclose - rtol: Relative tolerance for torch.allclose - kernel_dir: Optional directory containing kernels + def capture_ops(gm, example_inputs): + for node in gm.graph.nodes: + if node.op == "call_function": + target = node.target + if hasattr(target, "__module__") and "torch.ops" in target.__module__: + ops.add(str(target)) + return gm - Returns: - True if eager and backend produce matching results, False otherwise - """ + with torch.no_grad(): try: - # Run in eager mode - eager_out, eager_grads = self.run_with_backend(False, kernel_dir=kernel_dir) - - # Run with backend - backend_out, backend_grads = self.run_with_backend(True, kernel_dir=kernel_dir) - - # Compare outputs - if not torch.allclose(eager_out, backend_out, atol=atol, rtol=rtol): - logger.debug(f"{self.model_name}::{self.test_name}: Output mismatch") - return False - - # Compare number of gradients - if len(eager_grads) != len(backend_grads): - logger.debug( - f"{self.model_name}::{self.test_name}: " - f"Gradient count mismatch ({len(eager_grads)} vs {len(backend_grads)})" - ) - return False - - # Compare each gradient - for i, (eager_grad, backend_grad) in enumerate(zip(eager_grads, backend_grads)): - if not torch.allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): - logger.debug(f"{self.model_name}::{self.test_name}: Gradient {i} mismatch") - return False - - return True - + compiled_model = dynamo.optimize(capture_ops)(model) + compiled_model(*args, **kwargs) except Exception as e: - logger.error(f"{self.model_name}::{self.test_name}: Correctness test failed: {e}") - return False + logger.warning(f"Failed to trace model: {e}") + + return ops class ModelSuite(TorchBenchTestSuite): - """Model Suite extending TorchBenchTestSuite. + """Model Suite that filters TorchBench operators based on model tracing. - Provides two testing approaches: - 1. Operator-level testing via __iter__() (inherited infrastructure) - 2. Model-level correctness testing via test_model_correctness() (Model Suite specific) + This suite traces model execution to extract operators, then creates + a filtered TorchBench suite containing only those operators. """ def __init__( - self, name: str = "model", filter: Optional[List[str]] = None, models_dir: str = None + self, + name: str = "model", + filter: Optional[List[str]] = None, + models_dir: str = None, + topn: Optional[int] = None, ): """Initialize ModelSuite. Args: name: Suite name (default: "model") - filter: Optional list of model names to test + filter: Optional list of model names to load models_dir: Optional directory for models (default: "BackendBench/suite/models") + topn: Optional limit on number of tests per operator """ - # Don't call super().__init__() with parameters since TorchBenchTestSuite - # expects different arguments. Just initialize the base object. - super(TorchBenchTestSuite, self).__init__() - - self.name = name - # Default to models under suite/models if models_dir is None: models_dir = os.path.join(os.path.dirname(__file__), "models") - self.models = load_toy_models(toy_models_dir=models_dir, filter=filter) - logger.info(f"ModelSuite: {len(self.models)} models loaded from {models_dir}") - - def __iter__(self): - """Yield operator tests from all models (TorchBench approach). - - This method extracts operators by tracing model execution, then creates - TorchBenchOpTest instances for testing via the inherited infrastructure. 
- """ - # Trace each model to extract operators - - # For now, return empty iterator since operator extraction from model tracing - # is not yet implemented. The model suite focuses on full model testing via - # test_model_correctness() method. - return iter([]) - - def test_model_correctness(self, kernel_dir: str = None) -> Dict[str, Dict[str, bool]]: - """Test full model correctness for all models and configurations. - - This method runs each model with each test configuration, comparing - eager mode vs backend mode to verify numerical correctness. - - Args: - kernel_dir: Optional directory containing kernels for backend - - Returns: - Dictionary mapping model_name -> {config_name -> bool} - where bool indicates if the test passed - """ - results = {} - - for model in self.models: - model_results = {} - - # Test each configuration for this model - if "model_tests" not in model["config"]: - logger.warning(f"Model {model['name']} has no model_tests in config") - continue - - # model_tests is a dict mapping test_name -> serialized_args - for test_name, test_args in model["config"]["model_tests"].items(): - test = FullModelTest( - model_name=model["name"], - model_class=model["class"], - model_config=model["config"].get("model_config", {}), - test_name=test_name, - test_args=test_args, - ) - - is_correct = test.test_correctness(kernel_dir=kernel_dir) - model_results[test_name] = is_correct - - status = "PASS" if is_correct else "FAIL" - logger.info(f"{model['name']}::{test_name}: {status}") - - results[model["name"]] = model_results - - return results + # Load models + models = load_toy_models(toy_models_dir=models_dir, filter=filter) + logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") + + # Trace models to extract operators + model_ops = set() + for model in models: + try: + ops = _trace_model_ops(model["class"], model["config"]) + model_ops.update(ops) + logger.info(f"Model {model['name']}: Found {len(ops)} operators") + except Exception as e: + logger.warning(f"Failed to trace model {model['name']}: {e}") + + logger.info(f"ModelSuite: Total {len(model_ops)} unique operators across all models") + + # Get torchbench ops and filter + torchbench_ops = _get_torchbench_ops() + + # Convert model ops to the format used in torchbench (strip ) + # Example: "" -> "aten.conv2d.default" + filtered_ops = {} + for op_name, op_inputs in torchbench_ops.items(): + # Check if any model op matches this torchbench op + # Model ops from dynamo are like "aten.conv2d.default" + if any(model_op in op_name for model_op in model_ops): + filtered_ops[op_name] = op_inputs + + if not filtered_ops: + raise ValueError( + f"No operators from models found in TorchBench dataset. " + f"Model operators: {model_ops}" + ) + + logger.info( + f"ModelSuite: Filtered to {len(filtered_ops)} operators " + f"(from {len(torchbench_ops)} total)" + ) - def print_model_correctness_results(self, results: Dict[str, Dict[str, bool]]): - """Print formatted model correctness results. 
+ # Initialize parent class with filtered ops + self.name = name + self.topn = topn + self.optests = filtered_ops - Args: - results: Dictionary mapping model_name -> {test_name -> bool} - """ - print("\n" + "=" * 80) - print("FULL MODEL TESTING") - print("=" * 80) - print("\nModel Correctness Results:") - print("-" * 80) - - total_passed = 0 - total_tests = 0 - - for model_name, test_results in results.items(): - passed = sum(1 for result in test_results.values() if result) - total = len(test_results) - total_passed += passed - total_tests += total - percentage = (passed / total * 100) if total > 0 else 0 - print(f" {model_name}: {passed}/{total} configs passed ({percentage:.1f}%)") - - # Show individual config results - for config_name, is_correct in test_results.items(): - status = "✓ PASS" if is_correct else "✗ FAIL" - print(f" {config_name}: {status}") - - print("-" * 80) - if total_tests > 0: - overall_percentage = total_passed / total_tests * 100 - print(f"\nModel Suite Score: {total_passed}/{total_tests} ({overall_percentage:.1f}%)") - else: - print("\nNo model tests were run") - print("=" * 80) + # Deduplicate strings in self.optests + for op in self.optests: + self.optests[op] = list(set(self.optests[op])) diff --git a/test/test_model_suite_correctness.py b/test/test_model_suite_correctness.py index 7508c2ac..d5b3e37f 100644 --- a/test/test_model_suite_correctness.py +++ b/test/test_model_suite_correctness.py @@ -5,12 +5,12 @@ # LICENSE file in the root directory of this source tree. """ -Essential tests for Model Suite PR #2: Full Model Testing & Results +Tests for Model Suite: Filtered TorchBench operators from model tracing -This test suite validates the core functionality: -1. FullModelTest class with eager/backend execution -2. Numerical correctness comparison -3. ModelSuite.test_model_correctness() integration +This test suite validates: +1. Model loading from toy_models directory +2. Operator extraction via model tracing +3. 
ModelSuite creates filtered TorchBench suite """ import logging @@ -18,184 +18,68 @@ import sys import unittest -import torch - sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -from BackendBench.suite.model import FullModelTest, load_toy_models, ModelSuite +from BackendBench.suite.model import load_toy_models, ModelSuite # Setup logging logging.basicConfig(level=logging.WARNING) -class TestFullModelTest(unittest.TestCase): - """Test FullModelTest class functionality.""" - - @classmethod - def setUpClass(cls): - """Load toy models once for all tests.""" - cls.models = load_toy_models(toy_models_dir="BackendBench/suite/models") - assert len(cls.models) > 0, "Should load at least one model" - cls.model = next(m for m in cls.models if m["name"] == "toy_core_ops") - - def test_initialization(self): - """Test FullModelTest can be instantiated correctly.""" - test_name = list(self.model["config"]["model_tests"].keys())[0] - test_args = self.model["config"]["model_tests"][test_name] - full_test = FullModelTest( - model_name=self.model["name"], - model_class=self.model["class"], - model_config=self.model["config"]["model_config"], - test_name=test_name, - test_args=test_args, - ) +class TestModelLoading(unittest.TestCase): + """Test toy model loading functionality.""" - self.assertEqual(full_test.model_name, self.model["name"]) - self.assertEqual(full_test.model_class, self.model["class"]) - - def test_eager_execution(self): - """Test model runs correctly in eager mode.""" - test_name = "small_batch" - test_args = self.model["config"]["model_tests"][test_name] - full_test = FullModelTest( - self.model["name"], - self.model["class"], - self.model["config"]["model_config"], - test_name, - test_args, - ) + def test_load_models(self): + """Test that models can be loaded from directory.""" + models = load_toy_models(toy_models_dir="BackendBench/suite/models") + self.assertGreater(len(models), 0, "Should load at least one model") - output, grads = full_test.run_with_backend(backend_enabled=False) - - # Verify output shape (batch_size=2 from small_batch config) - expected_shape = torch.Size([2, 8, 4, 4]) - self.assertEqual(output.shape, expected_shape) - - # Verify gradients computed - self.assertGreater(len(grads), 0, "Should compute gradients") - - # Verify all gradients are valid - for grad in grads: - self.assertIsInstance(grad, torch.Tensor) - self.assertFalse(torch.isnan(grad).any(), "No NaN gradients") - self.assertFalse(torch.isinf(grad).any(), "No Inf gradients") - - def test_backend_execution(self): - """Test model runs with backend enabled.""" - test_name = "small_batch" - test_args = self.model["config"]["model_tests"][test_name] - full_test = FullModelTest( - self.model["name"], - self.model["class"], - self.model["config"]["model_config"], - test_name, - test_args, - ) + # Verify model structure + for model in models: + self.assertIn("name", model) + self.assertIn("class", model) + self.assertIn("config", model) - output, grads = full_test.run_with_backend(backend_enabled=True) - - # Verify output shape (batch_size=2 from small_batch config) - expected_shape = torch.Size([2, 8, 4, 4]) - self.assertEqual(output.shape, expected_shape) - - # Verify gradients computed - self.assertGreater(len(grads), 0, "Should compute gradients") - - def test_correctness_comparison(self): - """Test correctness comparison between eager and backend.""" - test_name = "small_batch" - test_args = self.model["config"]["model_tests"][test_name] - full_test = FullModelTest( - self.model["name"], - 
self.model["class"], - self.model["config"]["model_config"], - test_name, - test_args, + def test_load_specific_model(self): + """Test loading a specific model by name.""" + models = load_toy_models( + toy_models_dir="BackendBench/suite/models", filter=["toy_core_ops"] ) + self.assertEqual(len(models), 1) + self.assertEqual(models[0]["name"], "toy_core_ops") - is_correct = full_test.test_correctness() - - # Result should be a boolean - self.assertIsInstance(is_correct, bool) - - # With existing kernels, test should pass - self.assertTrue(is_correct, "Backend should produce correct results") - - def test_multiple_configs(self): - """Test all model configurations run correctly.""" - # Expected shapes for each config - # Note: Output size is always 4x4 due to adaptive_avg_pool2d([4, 4]) - expected_shapes = { - "small_batch": torch.Size([2, 8, 4, 4]), - "medium_batch": torch.Size([4, 8, 4, 4]), - "large_input": torch.Size([2, 8, 4, 4]), - } - - for test_name, test_args in self.model["config"]["model_tests"].items(): - full_test = FullModelTest( - self.model["name"], - self.model["class"], - self.model["config"]["model_config"], - test_name, - test_args, - ) - - output, grads = full_test.run_with_backend(backend_enabled=False) - - expected_shape = expected_shapes[test_name] - self.assertEqual(output.shape, expected_shape, f"Config {test_name} failed") - self.assertGreater(len(grads), 0, f"Config {test_name} has no gradients") + def test_invalid_filter(self): + """Test that invalid filter raises error.""" + with self.assertRaises(ValueError): + load_toy_models(toy_models_dir="BackendBench/suite/models", filter=["nonexistent"]) class TestModelSuite(unittest.TestCase): - """Test ModelSuite.test_model_correctness() integration.""" + """Test ModelSuite integration with TorchBench.""" - def test_model_correctness_method_exists(self): - """Test that test_model_correctness method exists.""" + def test_suite_initialization(self): + """Test that ModelSuite can be initialized.""" suite = ModelSuite() - self.assertTrue(hasattr(suite, "test_model_correctness")) + self.assertEqual(suite.name, "model") + self.assertIsNotNone(suite.optests) - def test_model_correctness_integration(self): - """Test ModelSuite.test_model_correctness() returns proper results.""" + def test_suite_has_operators(self): + """Test that suite extracts operators from models.""" suite = ModelSuite() - results = suite.test_model_correctness() - - # Verify results structure - self.assertIsInstance(results, dict, "Results should be a dictionary") - self.assertGreater(len(results), 0, "Should have results for at least one model") - - # Verify each model has config results - for model_name, model_results in results.items(): - self.assertIsInstance(model_results, dict, f"{model_name} results should be dict") - self.assertGreater(len(model_results), 0, f"{model_name} should have test configs") + # Should have extracted and filtered operators + self.assertGreater(len(suite.optests), 0, "Should have at least one operator") - # Verify each config result is a boolean - for config_name, is_correct in model_results.items(): - self.assertIsInstance( - is_correct, bool, f"{model_name}::{config_name} should be bool" - ) - - def test_results_aggregation(self): - """Test that results can be aggregated for scoring.""" + def test_suite_iteration(self): + """Test that suite can be iterated (TorchBench interface).""" suite = ModelSuite() - results = suite.test_model_correctness() - - # Calculate aggregate statistics - total_tests = sum(len(model_results) for 
model_results in results.values()) - total_passed = sum( - sum(1 for result in model_results.values() if result) - for model_results in results.values() - ) - - self.assertGreater(total_tests, 0, "Should have at least one test") - self.assertLessEqual(total_passed, total_tests, "Passed <= Total") - self.assertGreaterEqual(total_passed, 0, "Passed >= 0") + op_tests = list(suite) + # Should have at least one operator test + self.assertGreater(len(op_tests), 0, "Should have at least one operator test") def test_empty_filter(self): """Test suite raises error for nonexistent model.""" - with self.assertRaises(ValueError) as context: + with self.assertRaises(ValueError): _ = ModelSuite(filter=["nonexistent_model"]) - self.assertIn("No models found", str(context.exception)) if __name__ == "__main__": From 2ab91b75e7791e210945c935c057baa62c6bc4ae Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 30 Sep 2025 09:57:16 +0000 Subject: [PATCH 05/16] edits --- BackendBench/suite/model.py | 148 +++++------------- .../ToyCoreOpsModel.json} | 7 + .../ToyCoreOpsModel.py} | 52 +++--- 3 files changed, 71 insertions(+), 136 deletions(-) rename BackendBench/suite/models/{toy_core_ops/toy_core_ops.json => ToyCoreOpsModel/ToyCoreOpsModel.json} (73%) rename BackendBench/suite/models/{toy_core_ops/toy_core_ops.py => ToyCoreOpsModel/ToyCoreOpsModel.py} (86%) diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index f67782a0..612242ce 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -5,20 +5,18 @@ # LICENSE file in the root directory of this source tree. """ -Model Suite for testing operators traced from toy models. +Model Suite for testing operators defined in toy model configs. -This suite extends TorchBenchTestSuite by tracing model execution -to extract operators, then filtering the TorchBench dataset to only -include those operators. +This suite extends TorchBenchTestSuite by reading operator lists from +model configs, validating they exist in the TorchBench dataset, then +filtering to include only those operators. 
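A sketch of the intended flow (ToyCoreOpsModel and its "ops" list come from the
config added later in this patch):

    suite = ModelSuite()
    # suite.optests is the TorchBench benchmark dict filtered down to the
    # operators listed in each model's JSON, e.g. aten.convolution.default,
    # aten.avg_pool2d.default and aten._adaptive_avg_pool2d.default
    op_names = list(suite.optests)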
""" import importlib.util import json import logging import os -from typing import Any, Dict, List, Optional, Set - -import torch +from typing import Any, Dict, List, Optional from BackendBench.data_loaders import load_ops_from_source, op_list_to_benchmark_dict @@ -60,10 +58,6 @@ def load_toy_models( raise FileNotFoundError(f"Toy models directory not found: {toy_models_dir}") for model_name in os.listdir(toy_models_dir): - # Apply filter if specified - if filter is not None and model_name not in filter: - continue - model_dir = os.path.join(toy_models_dir, model_name) if not os.isdir(model_dir): continue @@ -74,16 +68,10 @@ def load_toy_models( # Check both files exist if not os.path.exists(model_file): - if filter is not None and model_name in filter: - raise FileNotFoundError(f"Model file not found: {model_file}") - logger.warning(f"Model file not found: {model_file}") - continue + raise FileNotFoundError(f"Model file not found: {model_file}") if not os.path.exists(config_file): - if filter is not None and model_name in filter: - raise FileNotFoundError(f"Config file not found: {config_file}") - logger.warning(f"Config file not found: {config_file}") - continue + raise FileNotFoundError(f"Config file not found: {config_file}") try: # Load config @@ -95,90 +83,29 @@ def load_toy_models( module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - # Find model class (ends with "Model") - model_class = None - for attr_name in dir(module): - attr = getattr(module, attr_name) - if ( - isinstance(attr, type) - and attr_name.endswith("Model") - and hasattr(attr, "forward") - ): - model_class = attr - break - - if model_class is None: - logger.error(f"No model class found in {model_file}") - continue + # Find model class (must match model_name exactly) + if not hasattr(module, model_name): + raise RuntimeError(f"Model class '{model_name}' not found in {model_file}") + + model_class = getattr(module, model_name) + if not (isinstance(model_class, type) and hasattr(model_class, "forward")): + raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class") models.append({"name": model_name, "class": model_class, "config": config}) logger.info(f"Loaded model: {model_name}") except Exception as e: - if filter is not None and model_name in filter: - raise RuntimeError(f"Failed to load model {model_name}: {e}") - logger.error(f"Failed to load {model_name}: {e}") - continue - - # If a filter was specified but no models were loaded, raise an error - if filter is not None and len(models) == 0: - raise ValueError(f"No models found matching filter: {filter}") + raise RuntimeError(f"Failed to load model {model_name}: {e}") return models -def _trace_model_ops(model_class, model_config: Dict[str, Any]) -> Set[str]: - """Trace model execution to extract operator names. 
- - Args: - model_class: Model class to instantiate - model_config: Model configuration dict with init_args and model_tests - - Returns: - Set of operator names in aten format (e.g., "aten.conv2d.default") - """ - import torch._dynamo as dynamo - - from BackendBench.utils import deserialize_args - - init_args = model_config.get("init_args", {}) - model = model_class(**init_args) - model.eval() - - # Get first test input to trace with - model_tests = model_config.get("model_tests", {}) - if not model_tests: - raise ValueError("No model_tests found in config") - - first_test = next(iter(model_tests.values())) - args, kwargs = deserialize_args(first_test) - - # Trace the model to extract ops - ops = set() - - def capture_ops(gm, example_inputs): - for node in gm.graph.nodes: - if node.op == "call_function": - target = node.target - if hasattr(target, "__module__") and "torch.ops" in target.__module__: - ops.add(str(target)) - return gm - - with torch.no_grad(): - try: - compiled_model = dynamo.optimize(capture_ops)(model) - compiled_model(*args, **kwargs) - except Exception as e: - logger.warning(f"Failed to trace model: {e}") - - return ops - - class ModelSuite(TorchBenchTestSuite): - """Model Suite that filters TorchBench operators based on model tracing. + """Model Suite that filters TorchBench operators based on model configs. - This suite traces model execution to extract operators, then creates - a filtered TorchBench suite containing only those operators. + This suite reads operator lists from model configs, validates they exist + in the TorchBench dataset, then creates a filtered suite containing only + those operators. """ def __init__( @@ -204,29 +131,38 @@ def __init__( models = load_toy_models(toy_models_dir=models_dir, filter=filter) logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") - # Trace models to extract operators + # Extract operators from model configs model_ops = set() for model in models: - try: - ops = _trace_model_ops(model["class"], model["config"]) - model_ops.update(ops) - logger.info(f"Model {model['name']}: Found {len(ops)} operators") - except Exception as e: - logger.warning(f"Failed to trace model {model['name']}: {e}") + config_ops = model["config"].get("ops", []) + if not config_ops: + raise ValueError(f"Model {model['name']} has no 'ops' field in config") + model_ops.update(config_ops) + logger.info(f"Model {model['name']}: {len(config_ops)} operators defined in config") logger.info(f"ModelSuite: Total {len(model_ops)} unique operators across all models") # Get torchbench ops and filter torchbench_ops = _get_torchbench_ops() - # Convert model ops to the format used in torchbench (strip ) - # Example: "" -> "aten.conv2d.default" + # Filter torchbench ops to only include those in model configs filtered_ops = {} - for op_name, op_inputs in torchbench_ops.items(): - # Check if any model op matches this torchbench op - # Model ops from dynamo are like "aten.conv2d.default" - if any(model_op in op_name for model_op in model_ops): - filtered_ops[op_name] = op_inputs + unsupported_ops = [] + for model_op in model_ops: + # Find matching torchbench ops + matched = False + for op_name, op_inputs in torchbench_ops.items(): + if model_op in op_name: + filtered_ops[op_name] = op_inputs + matched = True + if not matched: + unsupported_ops.append(model_op) + + # Error out if any ops are not supported by torchbench + if unsupported_ops: + raise ValueError( + f"The following operators are not supported by TorchBench: {unsupported_ops}" + ) if not 
filtered_ops: raise ValueError( diff --git a/BackendBench/suite/models/toy_core_ops/toy_core_ops.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json similarity index 73% rename from BackendBench/suite/models/toy_core_ops/toy_core_ops.json rename to BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json index b9bab0e6..514bc095 100644 --- a/BackendBench/suite/models/toy_core_ops/toy_core_ops.json +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -8,6 +8,13 @@ "num_groups": 8 } }, + "ops": [ + "aten.convolution.default", + "aten.native_group_norm.default", + "aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], "model_tests": { "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", diff --git a/BackendBench/suite/models/toy_core_ops/toy_core_ops.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py similarity index 86% rename from BackendBench/suite/models/toy_core_ops/toy_core_ops.py rename to BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py index 142774c5..a6ec5531 100644 --- a/BackendBench/suite/models/toy_core_ops/toy_core_ops.py +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -1,4 +1,8 @@ -#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. """ Toy model that uses core PyTorch operators during training. @@ -46,12 +50,14 @@ class ToyCoreOpsModel(nn.Module): - Adaptive average pooling for _adaptive_avg_pool2d_backward """ - def __init__(self, - in_channels: int = 3, - hidden_channels: int = 32, - out_channels: int = 8, - num_groups: int = 8, - seed: int = 42): + def __init__( + self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + seed: int = 42, + ): """ Initialize the ToyCoreOpsModel. 
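A quick sanity-check sketch for the model itself (init values and the expected
[2, 8, 4, 4] output shape follow the "small_batch" entry in ToyCoreOpsModel.json
and the adaptive 4x4 pooling noted in the tests):

    import torch

    model = ToyCoreOpsModel(in_channels=3, hidden_channels=32,
                            out_channels=8, num_groups=8, seed=42)
    x = torch.randn(2, 3, 32, 32, requires_grad=True)
    out = model(x)  # expected shape: [2, 8, 4, 4]
    out.sum().backward()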
@@ -82,37 +88,23 @@ def __init__(self, # First convolution block (triggers convolution_backward) self.conv1 = nn.Conv2d( - in_channels=in_channels, - out_channels=hidden_channels, - kernel_size=3, - padding=1 + in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, padding=1 ) # First group normalization (triggers native_group_norm_backward) - self.group_norm1 = nn.GroupNorm( - num_groups=num_groups, - num_channels=hidden_channels - ) + self.group_norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=hidden_channels) # Second convolution block (triggers convolution_backward again) self.conv2 = nn.Conv2d( - in_channels=hidden_channels, - out_channels=hidden_channels, - kernel_size=3, - padding=1 + in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, padding=1 ) # Second group normalization (triggers native_group_norm_backward again) - self.group_norm2 = nn.GroupNorm( - num_groups=num_groups, - num_channels=hidden_channels - ) + self.group_norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=hidden_channels) # Final convolution for output (triggers convolution_backward again) self.conv_out = nn.Conv2d( - in_channels=hidden_channels, - out_channels=out_channels, - kernel_size=1 + in_channels=hidden_channels, out_channels=out_channels, kernel_size=1 ) # Initialize weights deterministically @@ -130,7 +122,7 @@ def _initialize_weights(self, seed: int): for module in self.modules(): if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.GroupNorm): @@ -193,7 +185,7 @@ def main(): hidden_channels=32, out_channels=8, num_groups=8, - seed=42 # Deterministic initialization + seed=42, # Deterministic initialization ) # Create sample input @@ -223,11 +215,11 @@ def main(): total_params = len(list(model.parameters())) print(f"✓ Gradients computed for {grad_count}/{total_params} parameters") - print(f"\n✓ Model demonstration completed successfully!") + print("\n✓ Model demonstration completed successfully!") print("This model is ready to be used with the Model Suite for testing core operators.") return model if __name__ == "__main__": - main() \ No newline at end of file + main() From 36e44c270171731d77e3469d0334af639fcc894c Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 30 Sep 2025 10:08:28 +0000 Subject: [PATCH 06/16] edits --- BackendBench/scripts/main.py | 11 ++-- BackendBench/suite/model.py | 114 ++++++++++++++++++++++++++++++++--- 2 files changed, 112 insertions(+), 13 deletions(-) diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 4f5d2bb4..1bfb368b 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -316,10 +316,13 @@ def cli( # Add full model testing for model suite if suite.name == "model": - # Pass ops_directory as kernel_dir for directory backend - kernel_dir = ops_directory if backend_name == "directory" else None - model_results = suite.test_model_correctness(kernel_dir=kernel_dir) - suite.print_model_correctness_results(model_results) + logger.info("\n" + "=" * 60) + logger.info("MODEL EVALUATION") + logger.info("=" * 60) + for model in suite.models: + results = suite.eval_model(model, backend) + suite.print_results(results) + logger.info("=" * 60) command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) diff --git a/BackendBench/suite/model.py 
b/BackendBench/suite/model.py index 612242ce..a3ec9ab2 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -24,17 +24,11 @@ logger = logging.getLogger(__name__) -# Cache for torchbench ops to avoid reloading -_TORCHBENCH_OPS_CACHE = None - def _get_torchbench_ops(): - """Get list of available ops from torchbench dataset (cached).""" - global _TORCHBENCH_OPS_CACHE - if _TORCHBENCH_OPS_CACHE is None: - ops_list = load_ops_from_source(source=None, format="parquet") - _TORCHBENCH_OPS_CACHE = op_list_to_benchmark_dict(ops_list) - return _TORCHBENCH_OPS_CACHE + """Get list of available ops from torchbench dataset.""" + ops_list = load_ops_from_source(source=None, format="parquet") + return op_list_to_benchmark_dict(ops_list) def load_toy_models( @@ -183,3 +177,105 @@ def __init__( # Deduplicate strings in self.optests for op in self.optests: self.optests[op] = list(set(self.optests[op])) + + # Store loaded models for evaluation + self.models = models + + def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: + """Run evaluation on a single model. + + Args: + model_dict: Dictionary with keys 'name', 'class', 'config' + backend: Backend to use for evaluation + + Returns: + Dictionary with evaluation results including correctness and performance + """ + from BackendBench.eval_model import eval_model_correctness_test + + model_class = model_dict["class"] + model_name = model_dict["name"] + config = model_dict["config"] + + # Extract model configuration and tests + model_config = config.get("model_config", {}) + model_tests = config.get("model_tests", {}) + + if not model_tests: + return { + "model_name": model_name, + "passed": False, + "error": "No model_tests found in config", + "test_results": [], + } + + # Get kernel_dir from backend if available + kernel_dir = getattr(backend, "ops_dir", None) + + # Run each test + test_results = [] + for test_name, test_args in model_tests.items(): + result = eval_model_correctness_test( + model_name=model_name, + model_class=model_class, + model_config=model_config, + test_name=test_name, + test_args=test_args, + kernel_dir=kernel_dir, + ) + test_results.append(result) + + # Aggregate results + all_passed = all(r.is_correct for r in test_results) + num_passed = sum(1 for r in test_results if r.is_correct) + num_total = len(test_results) + + return { + "model_name": model_name, + "passed": all_passed, + "num_passed": num_passed, + "num_total": num_total, + "test_results": test_results, + } + + def print_results(self, results: Dict[str, Any]) -> None: + """Print model evaluation results. 
+ + Args: + results: Dictionary with evaluation results from eval_model + """ + model_name = results.get("model_name", "Unknown") + passed = results.get("passed", False) + num_passed = results.get("num_passed", 0) + num_total = results.get("num_total", 0) + + logger.info(f"\nModel: {model_name}") + logger.info( + f"Status: {'✓ Passed' if passed else '✗ Failed'} ({num_passed}/{num_total} tests)" + ) + + # If there's a top-level error (no tests run) + if "error" in results: + logger.info(f"Error: {results['error']}") + return + + # Print details for each test + test_results = results.get("test_results", []) + for result in test_results: + status = "✓" if result.is_correct else "✗" + logger.info(f" {status} {result.test_name}") + + if not result.is_correct: + if result.error_msg: + logger.info(f" Error: {result.error_msg}") + else: + # Show what failed + if not result.output_match: + logger.info(" Output mismatch") + if not result.gradients_match: + logger.info(f" Gradient mismatch ({result.num_gradients} gradients)") + else: + # Show success details + logger.info( + f" Output match: ✓ Gradients match: ✓ ({result.num_gradients} gradients)" + ) From 228d8e195cb7834816f1dd9df281bf412d803a1c Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 02:52:36 +0000 Subject: [PATCH 07/16] edits --- BackendBench/eval_model.py | 9 +- BackendBench/suite/model.py | 115 +++--- .../ToyCoreOpsModel/ToyCoreOpsModel.json | 23 +- BackendBench/suite/torchbench.py | 10 +- test/test_model_ops_coverage.py | 209 +++++++++++ test/test_model_ops_filter.py | 329 ++++++++++++++++++ test/test_model_suite_correctness.py | 12 +- test/test_model_suite_integration.py | 56 ++- 8 files changed, 640 insertions(+), 123 deletions(-) create mode 100644 test/test_model_ops_coverage.py create mode 100644 test/test_model_ops_filter.py diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index d01fb9e1..0ea87896 100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -7,13 +7,13 @@ """Model-level evaluation utilities for testing full model correctness.""" import logging -import os import traceback from dataclasses import dataclass from typing import Any, Dict, List, Tuple import torch +import BackendBench from BackendBench.utils import deserialize_args logger = logging.getLogger(__name__) @@ -210,7 +210,7 @@ def _run_model( model_config: Dict[str, Any], test_args: str, backend_enabled: bool, - kernel_dir: str = None, + kernel_dir: str = "generated_kernels", ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Run model with or without backend enabled. @@ -226,7 +226,6 @@ def _run_model( - output: Model output tensor (detached) - gradients: List of gradient tensors [input_grad, param1_grad, ...] 
""" - import BackendBench # Deserialize test arguments args, kwargs = deserialize_args(test_args) @@ -258,10 +257,6 @@ def _run_model( # Run forward + backward with or without backend if backend_enabled: - # Use context manager to enable backend - if kernel_dir is None: - kernel_dir = os.path.join(os.getcwd(), "generated_kernels") - with BackendBench.BackendBench.enable(kernel_dir=kernel_dir): output = model(*args, **kwargs) loss = output.sum() diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index a3ec9ab2..450c2ce4 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -18,26 +18,20 @@ import os from typing import Any, Dict, List, Optional -from BackendBench.data_loaders import load_ops_from_source, op_list_to_benchmark_dict +from BackendBench.eval_model import eval_model_correctness_test from .torchbench import TorchBenchTestSuite logger = logging.getLogger(__name__) -def _get_torchbench_ops(): - """Get list of available ops from torchbench dataset.""" - ops_list = load_ops_from_source(source=None, format="parquet") - return op_list_to_benchmark_dict(ops_list) - - -def load_toy_models( - toy_models_dir: str = "models", filter: Optional[List[str]] = None +def load_models( + models_dir: str = "models", filter: Optional[List[str]] = None ) -> List[Dict[str, Any]]: """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json Args: - toy_models_dir: Directory containing toy models (default: "models") + models_dir: Directory containing models (default: "models") filter: Optional list of model names to load. If None, loads all models. Returns: @@ -48,12 +42,16 @@ def load_toy_models( """ models = [] - if not os.path.exists(toy_models_dir): - raise FileNotFoundError(f"Toy models directory not found: {toy_models_dir}") + if not os.path.exists(models_dir): + raise FileNotFoundError(f"Models directory not found: {models_dir}") + + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue - for model_name in os.listdir(toy_models_dir): - model_dir = os.path.join(toy_models_dir, model_name) - if not os.isdir(model_dir): + # Skip if not in filter + if filter is not None and model_name not in filter: continue # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json @@ -91,6 +89,9 @@ def load_toy_models( except Exception as e: raise RuntimeError(f"Failed to load model {model_name}: {e}") + if filter is not None and len(models) == 0: + raise ValueError(f"No models found matching filter: {filter}") + return models @@ -106,7 +107,6 @@ def __init__( self, name: str = "model", filter: Optional[List[str]] = None, - models_dir: str = None, topn: Optional[int] = None, ): """Initialize ModelSuite. 
@@ -114,72 +114,52 @@ def __init__( Args: name: Suite name (default: "model") filter: Optional list of model names to load - models_dir: Optional directory for models (default: "BackendBench/suite/models") topn: Optional limit on number of tests per operator """ - # Default to models under suite/models - if models_dir is None: - models_dir = os.path.join(os.path.dirname(__file__), "models") + models_dir = os.path.join(os.path.dirname(__file__), "models") # Load models - models = load_toy_models(toy_models_dir=models_dir, filter=filter) + models = load_models(models_dir=models_dir, filter=filter) logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") + model_ops = self.get_model_ops(models) + filter = list(model_ops) + # Store loaded models for evaluation + self.models = models + + self._initialize_torchbench_suite(name, None, filter, topn, False) + def get_model_ops(self, models: List[Dict[str, Any]]) -> List[str]: # Extract operators from model configs model_ops = set() for model in models: - config_ops = model["config"].get("ops", []) + config_ops = model["config"].get("ops") if not config_ops: raise ValueError(f"Model {model['name']} has no 'ops' field in config") - model_ops.update(config_ops) - logger.info(f"Model {model['name']}: {len(config_ops)} operators defined in config") - logger.info(f"ModelSuite: Total {len(model_ops)} unique operators across all models") - - # Get torchbench ops and filter - torchbench_ops = _get_torchbench_ops() - - # Filter torchbench ops to only include those in model configs - filtered_ops = {} - unsupported_ops = [] - for model_op in model_ops: - # Find matching torchbench ops - matched = False - for op_name, op_inputs in torchbench_ops.items(): - if model_op in op_name: - filtered_ops[op_name] = op_inputs - matched = True - if not matched: - unsupported_ops.append(model_op) - - # Error out if any ops are not supported by torchbench - if unsupported_ops: - raise ValueError( - f"The following operators are not supported by TorchBench: {unsupported_ops}" - ) - - if not filtered_ops: - raise ValueError( - f"No operators from models found in TorchBench dataset. 
" - f"Model operators: {model_ops}" - ) - - logger.info( - f"ModelSuite: Filtered to {len(filtered_ops)} operators " - f"(from {len(torchbench_ops)} total)" - ) + # Support both list format (legacy) and dict format (forward/backward) + if isinstance(config_ops, list): + # Legacy format: ops is a flat list + ops_list = config_ops + elif isinstance(config_ops, dict): + # New format: ops is a dict with 'forward' and 'backward' keys + ops_list = [] + if "forward" in config_ops: + ops_list.extend(config_ops["forward"]) + if "backward" in config_ops: + ops_list.extend(config_ops["backward"]) + else: + raise ValueError( + f"Model {model['name']}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" + ) - # Initialize parent class with filtered ops - self.name = name - self.topn = topn - self.optests = filtered_ops + if not ops_list: + raise ValueError(f"Model {model['name']}: 'ops' field is empty") - # Deduplicate strings in self.optests - for op in self.optests: - self.optests[op] = list(set(self.optests[op])) + model_ops.update(ops_list) + logger.info(f"Model {model['name']}: {len(ops_list)} operators defined in config") - # Store loaded models for evaluation - self.models = models + logger.info(f"ModelSuite: Total {len(model_ops)} unique operators across all models") + return model_ops def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: """Run evaluation on a single model. @@ -191,7 +171,6 @@ def eval_model(self, model_dict: Dict[str, Any], backend) -> Dict[str, Any]: Returns: Dictionary with evaluation results including correctness and performance """ - from BackendBench.eval_model import eval_model_correctness_test model_class = model_dict["class"] model_name = model_dict["name"] diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json index 514bc095..0b5e4c9a 100644 --- a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -8,13 +8,22 @@ "num_groups": 8 } }, - "ops": [ - "aten.convolution.default", - "aten.native_group_norm.default", - "aten.max_pool2d_with_indices.default", - "aten.avg_pool2d.default", - "aten._adaptive_avg_pool2d.default" - ], + "ops": { + "forward": [ + "aten.convolution.default", + "aten.native_group_norm.default", + "aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], + "backward": [ + "aten.convolution_backward.default", + "aten.native_group_norm_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d_backward.default" + ] + }, "model_tests": { "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", diff --git a/BackendBench/suite/torchbench.py b/BackendBench/suite/torchbench.py index 2ee3d698..29597636 100644 --- a/BackendBench/suite/torchbench.py +++ b/BackendBench/suite/torchbench.py @@ -78,6 +78,13 @@ def __init__( filter=None, topn=None, check_overhead_dominated_ops=False, + ): + self._initialize_torchbench_suite( + name, filename, filter, topn, check_overhead_dominated_ops + ) + + def _initialize_torchbench_suite( + self, name, filename, filter, topn, check_overhead_dominated_ops ): self.name = name self.topn = topn @@ -87,9 +94,6 @@ def __init__( format="auto", # Auto-detect based on file extension filter=filter, ) - if check_overhead_dominated_ops: - # Only include 
ops which are overhead dominated (this is useful as a performance canary) - ops_list = [op for op in ops_list if op.get("is_overhead_dominated_op", False)] # Convert to dictionary format using utility function self.optests = op_list_to_benchmark_dict(ops_list) diff --git a/test/test_model_ops_coverage.py b/test/test_model_ops_coverage.py new file mode 100644 index 00000000..6b7337ac --- /dev/null +++ b/test/test_model_ops_coverage.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that models actually invoke all operators declared in their configs. + +This test validates that: +1. Forward pass invokes all operators in config["ops"]["forward"] +2. Backward pass invokes all operators in config["ops"]["backward"] +3. Clear error messages indicate which operators are missing per model +""" + +import os +import re +import sys +import unittest +from typing import Dict, Set + +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from BackendBench.suite.model import load_models + + +class OpTracker: + """Track operators called during forward/backward passes using torch profiler.""" + + def __init__(self): + self.called_ops: Set[str] = set() + self.profiler = None + + def __enter__(self): + self.called_ops.clear() + + # Use torch profiler to track ops + self.profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + record_shapes=False, + with_stack=False, + ) + self.profiler.__enter__() + return self + + def __exit__(self, *args): + self.profiler.__exit__(*args) + + # Extract op names from profiler events + for event in self.profiler.events(): + event_name = event.name + # Look for aten operations + if "::" in event_name: + # Handle format like "aten::convolution" or "aten::convolution.default" + parts = event_name.replace("::", ".").split(".") + + if len(parts) >= 2 and parts[0] == "aten": + if len(parts) == 2: + # No variant specified, add .default + op_name = f"{parts[0]}.{parts[1]}.default" + else: + # Keep as is + op_name = event_name.replace("::", ".") + + self.called_ops.add(op_name) + + +class TestModelOpsCoverage(unittest.TestCase): + """Test that models invoke all operators declared in their configs.""" + + def test_all_models_ops_coverage(self): + """Test that all models invoke their declared forward and backward ops.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "BackendBench", + "suite", + "models", + ) + + models = load_models(models_dir=models_dir) + self.assertGreater(len(models), 0, "Should load at least one model") + + failures = [] + + for model_dict in models: + model_name = model_dict["name"] + model_class = model_dict["class"] + config = model_dict["config"] + + # Get expected ops from config + config_ops = config.get("ops", {}) + if isinstance(config_ops, list): + # Legacy format - skip or treat as forward-only + expected_forward = set(config_ops) + expected_backward = set() + else: + expected_forward = set(config_ops.get("forward", [])) + expected_backward = set(config_ops.get("backward", [])) + + # Skip if no ops to check + if not expected_forward and not expected_backward: + continue + + try: + # Initialize model + model_config = config.get("model_config", {}) + init_args = model_config.get("init_args", {}) + + if model_config.get("requires_init_seed"): + 
torch.manual_seed(42) + + model = model_class(**init_args) + + # Get a test input from model_tests + model_tests = config.get("model_tests", {}) + if not model_tests: + failures.append(f"{model_name}: No model_tests in config") + continue + + # Use first test case + test_name = list(model_tests.keys())[0] + test_args_str = model_tests[test_name] + + # Parse test args (simple eval for now) + # Format: "([], {'x': T([2, 3, 32, 32], f32)})" + test_input = self._create_test_input_from_string(test_args_str) + + # Track forward pass + tracker = OpTracker() + with tracker: + output = model(**test_input) + + forward_ops = tracker.called_ops + + # Check forward ops coverage + missing_forward = expected_forward - forward_ops + if missing_forward: + failures.append( + f"{model_name} [FORWARD]: Missing ops: {sorted(missing_forward)}" + ) + + # Track backward pass + if expected_backward: + # Ensure output requires grad + for param in model.parameters(): + param.requires_grad = True + + # Create loss + if isinstance(output, torch.Tensor): + loss = output.sum() + else: + # Handle tuple/dict outputs + loss = sum(v.sum() for v in output.values() if isinstance(v, torch.Tensor)) + + tracker_backward = OpTracker() + with tracker_backward: + loss.backward() + + backward_ops = tracker_backward.called_ops + + # Check backward ops coverage + missing_backward = expected_backward - backward_ops + if missing_backward: + failures.append( + f"{model_name} [BACKWARD]: Missing ops: {sorted(missing_backward)}" + ) + + except Exception as e: + failures.append(f"{model_name}: Error during test: {e}") + + # Report all failures at once + if failures: + error_msg = "\n\nOperator Coverage Failures:\n" + "\n".join( + f" - {failure}" for failure in failures + ) + self.fail(error_msg) + + def _create_test_input_from_string(self, test_args_str: str) -> Dict[str, torch.Tensor]: + """Parse test input string into actual tensors. + + Format: "([], {'x': T([2, 3, 32, 32], f32)})" + """ + + # Extract tensor specs: T([shape], dtype) + tensor_pattern = r"'(\w+)':\s*T\(\[([\d,\s]+)\],\s*(\w+)\)" + matches = re.findall(tensor_pattern, test_args_str) + + inputs = {} + for name, shape_str, dtype_str in matches: + shape = [int(x.strip()) for x in shape_str.split(",")] + + # Map dtype string to torch dtype + dtype_map = { + "f32": torch.float32, + "f64": torch.float64, + "i32": torch.int32, + "i64": torch.int64, + } + dtype = dtype_map.get(dtype_str, torch.float32) + + inputs[name] = torch.randn(*shape, dtype=dtype) + + return inputs + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_ops_filter.py b/test/test_model_ops_filter.py new file mode 100644 index 00000000..75f74874 --- /dev/null +++ b/test/test_model_ops_filter.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that ModelSuite's operator filter correctly matches +the operators defined in model configs. + +This test validates that: +1. load_models correctly loads model configs from the models directory +2. load_model_ops extracts the correct set of operators from model configs +3. TorchBenchTestSuite initialized with those operators has matching optests +4. 
JSON config files have proper format with required fields +""" + +import json +import os +import sys +import unittest +from typing import Any, Dict, List, Set + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from BackendBench.suite.model import load_models +from BackendBench.suite.torchbench import TorchBenchTestSuite + + +def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: + """Extract unique set of operators from model configs. + + Args: + models: List of model dictionaries with 'name', 'class', and 'config' keys + + Returns: + Set of operator names defined across all model configs + """ + model_ops = set() + for model in models: + config_ops = model["config"].get("ops") + if not config_ops: + raise ValueError(f"Model {model['name']} has no 'ops' field in config") + + # Support both list format (legacy) and dict format (forward/backward) + if isinstance(config_ops, list): + # Legacy format: ops is a flat list + ops_list = config_ops + elif isinstance(config_ops, dict): + # New format: ops is a dict with 'forward' and 'backward' keys + ops_list = [] + if "forward" in config_ops: + ops_list.extend(config_ops["forward"]) + if "backward" in config_ops: + ops_list.extend(config_ops["backward"]) + else: + raise ValueError( + f"Model {model['name']}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" + ) + + if not ops_list: + raise ValueError(f"Model {model['name']}: 'ops' field is empty") + + model_ops.update(ops_list) + return model_ops + + +class TestModelOpsFilter(unittest.TestCase): + """Test that model ops filter correctly initializes TorchBenchTestSuite.""" + + def test_model_ops_match_suite_optests(self): + """Test that suite's optests match the operators from model configs.""" + # Get the models directory path (same as ModelSuite does) + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load models using load_models + models = load_models(models_dir=models_dir) + + # Verify we loaded at least one model + self.assertGreater(len(models), 0, "Should load at least one model") + + # Extract operators from model configs using load_model_ops + model_ops = load_model_ops(models) + + # Verify we have operators + self.assertGreater(len(model_ops), 0, "Should have at least one operator") + + # Create filter list from model ops + ops_filter = list(model_ops) + + # Initialize TorchBenchTestSuite with the filter + suite = TorchBenchTestSuite( + name="test_model_ops", + filename=None, # Use default HuggingFace dataset + filter=ops_filter, + topn=None, + ) + + # Get the set of operators in the suite's optests + suite_ops = set(suite.optests.keys()) + + # The suite_ops should be a subset of model_ops because: + # - model_ops is the filter we requested + # - suite_ops contains only those operators that exist in the TorchBench dataset + # - Not all operators in model configs may be in the dataset + self.assertTrue( + suite_ops.issubset(model_ops), + f"Suite operators {suite_ops} should be subset of model operators {model_ops}", + ) + + # Verify that suite actually has some operators + self.assertGreater( + len(suite_ops), 0, "Suite should contain at least one operator from model configs" + ) + + def test_load_model_ops_with_single_model(self): + """Test load_model_ops with a single model filter.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load a specific model (ToyCoreOpsModel if it exists) + all_models = 
load_models(models_dir=models_dir) + + # Find a model to test with + test_model_name = None + for model in all_models: + if "core" in model["name"].lower() or "toy" in model["name"].lower(): + test_model_name = model["name"] + break + + if test_model_name is None: + # Use first model if no matching name found + test_model_name = all_models[0]["name"] + + # Load just that one model + models = load_models(models_dir=models_dir, filter=[test_model_name]) + + self.assertEqual(len(models), 1, "Should load exactly one model") + + # Extract operators + model_ops = load_model_ops(models) + + # Verify the ops match what's in the config + config_ops = models[0]["config"]["ops"] + if isinstance(config_ops, list): + expected_ops = set(config_ops) + else: + # Dict format with forward/backward + expected_ops = set(config_ops.get("forward", []) + config_ops.get("backward", [])) + + self.assertEqual( + model_ops, + expected_ops, + f"Extracted ops {model_ops} should match config ops {expected_ops}", + ) + + # Create suite with these ops + ops_filter = list(model_ops) + suite = TorchBenchTestSuite( + name="test_single_model", + filename=None, + filter=ops_filter, + topn=None, + ) + + # Verify suite has operators (subset of model_ops) + suite_ops = set(suite.optests.keys()) + self.assertTrue( + suite_ops.issubset(model_ops), + f"Suite ops {suite_ops} should be subset of model ops {model_ops}", + ) + + def test_load_model_ops_combines_multiple_models(self): + """Test that load_model_ops correctly combines ops from multiple models.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load all models + models = load_models(models_dir=models_dir) + + # Skip if only one model + if len(models) <= 1: + self.skipTest("Need multiple models to test combination") + + # Extract ops from all models + all_ops = load_model_ops(models) + + # Extract ops from each model individually + individual_ops = [] + for model in models: + config_ops = model["config"]["ops"] + if isinstance(config_ops, list): + individual_ops.extend(config_ops) + else: + # Dict format with forward/backward + individual_ops.extend(config_ops.get("forward", [])) + individual_ops.extend(config_ops.get("backward", [])) + + # The combined set should equal the union of all individual ops + expected_ops = set(individual_ops) + self.assertEqual(all_ops, expected_ops, "Combined ops should equal union of all model ops") + + def test_json_configs_have_required_fields(self): + """Test that all JSON config files have proper format with required fields.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load all models + models = load_models(models_dir=models_dir) + + for model in models: + model_name = model["name"] + config = model["config"] + + # Check required top-level fields + self.assertIn("ops", config, f"Model {model_name}: config must have 'ops' field") + self.assertIn( + "model_tests", config, f"Model {model_name}: config must have 'model_tests' field" + ) + + # Validate 'ops' field - can be list or dict + config_ops = config["ops"] + if isinstance(config_ops, list): + # Legacy format: flat list + self.assertGreater( + len(config_ops), 0, f"Model {model_name}: 'ops' list must not be empty" + ) + for op in config_ops: + self.assertIsInstance( + op, str, f"Model {model_name}: each op in 'ops' must be a string" + ) + elif isinstance(config_ops, dict): + # New format: dict with forward/backward + self.assertTrue( + 
"forward" in config_ops or "backward" in config_ops, + f"Model {model_name}: 'ops' dict must contain 'forward' or 'backward' keys", + ) + if "forward" in config_ops: + self.assertIsInstance( + config_ops["forward"], + list, + f"Model {model_name}: 'ops.forward' must be a list", + ) + for op in config_ops["forward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.forward' must be a string", + ) + if "backward" in config_ops: + self.assertIsInstance( + config_ops["backward"], + list, + f"Model {model_name}: 'ops.backward' must be a list", + ) + for op in config_ops["backward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.backward' must be a string", + ) + else: + self.fail( + f"Model {model_name}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + for test_name, test_args in config["model_tests"].items(): + self.assertIsInstance( + test_name, str, f"Model {model_name}: test names must be strings" + ) + self.assertIsInstance( + test_args, str, f"Model {model_name}: test args must be strings" + ) + + # Check optional but recommended fields + if "model_config" in config: + self.assertIsInstance( + config["model_config"], + dict, + f"Model {model_name}: 'model_config' must be a dictionary if present", + ) + + def test_json_files_are_valid_json(self): + """Test that all JSON config files are valid JSON and can be parsed.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Find all JSON files in the models directory + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + json_file = os.path.join(model_dir, f"{model_name}.json") + if not os.path.exists(json_file): + continue + + # Try to parse the JSON file + with open(json_file, "r") as f: + try: + config = json.load(f) + self.assertIsInstance( + config, + dict, + f"JSON file {json_file} must contain a dictionary at top level", + ) + except json.JSONDecodeError as e: + self.fail(f"JSON file {json_file} is not valid JSON: {e}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_suite_correctness.py b/test/test_model_suite_correctness.py index d5b3e37f..655b4ce1 100644 --- a/test/test_model_suite_correctness.py +++ b/test/test_model_suite_correctness.py @@ -20,7 +20,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -from BackendBench.suite.model import load_toy_models, ModelSuite +from BackendBench.suite.model import load_models, ModelSuite # Setup logging logging.basicConfig(level=logging.WARNING) @@ -31,7 +31,7 @@ class TestModelLoading(unittest.TestCase): def test_load_models(self): """Test that models can be loaded from directory.""" - models = load_toy_models(toy_models_dir="BackendBench/suite/models") + models = load_models(models_dir="BackendBench/suite/models") self.assertGreater(len(models), 0, "Should load at least one model") # Verify model structure @@ -42,16 +42,14 @@ def test_load_models(self): def test_load_specific_model(self): """Test loading a specific model by name.""" - models = load_toy_models( - toy_models_dir="BackendBench/suite/models", 
filter=["toy_core_ops"] - ) + models = load_models(models_dir="BackendBench/suite/models", filter=["ToyCoreOpsModel"]) self.assertEqual(len(models), 1) - self.assertEqual(models[0]["name"], "toy_core_ops") + self.assertEqual(models[0]["name"], "ToyCoreOpsModel") def test_invalid_filter(self): """Test that invalid filter raises error.""" with self.assertRaises(ValueError): - load_toy_models(toy_models_dir="BackendBench/suite/models", filter=["nonexistent"]) + load_models(models_dir="BackendBench/suite/models", filter=["nonexistent"]) class TestModelSuite(unittest.TestCase): diff --git a/test/test_model_suite_integration.py b/test/test_model_suite_integration.py index 87bb6422..58c33a8b 100644 --- a/test/test_model_suite_integration.py +++ b/test/test_model_suite_integration.py @@ -50,9 +50,11 @@ def test_complete_workflow(self): ) self.assertEqual(result.returncode, 0, "CLI should succeed") - self.assertIn("FULL MODEL TESTING", result.stdout) - self.assertIn("Model Suite Score", result.stdout) - self.assertIn("toy_core_ops", result.stdout) + # Check that output contains expected content + self.assertTrue( + "correctness score" in result.stdout.lower() or "Model Suite Score" in result.stdout, + "Should contain correctness score", + ) def test_filtering_by_model(self): """Test filtering models by name.""" @@ -68,7 +70,7 @@ def test_filtering_by_model(self): "--ops-directory", "generated_kernels", "--model-filter", - "toy_core_ops", + "ToyCoreOpsModel", "--disable-output-logs", "--log-level", "ERROR", @@ -79,7 +81,8 @@ def test_filtering_by_model(self): ) self.assertEqual(result.returncode, 0, "Filtered run should succeed") - self.assertIn("toy_core_ops", result.stdout) + # Check that filtering worked + self.assertTrue(len(result.stdout) > 0, "Should have output") def test_invalid_backend_error(self): """Test that model suite rejects invalid backends.""" @@ -130,6 +133,10 @@ def test_empty_filter(self): # Should fail because explicitly requested model not found self.assertNotEqual(result.returncode, 0, "Should fail with nonexistent filter") + self.assertTrue( + "no models found" in result.stderr.lower() or "valueerror" in result.stderr.lower(), + f"Should have error about missing models. 
stderr: {result.stderr}", + ) def test_ops_filter_rejected(self): """Test that --ops filter is rejected for model suite.""" @@ -170,9 +177,9 @@ def test_initialization_variants(self): self.assertGreater(len(suite1.models), 0, "Should load models by default") # With filter - suite2 = ModelSuite(filter=["toy_core_ops"]) + suite2 = ModelSuite(filter=["ToyCoreOpsModel"]) self.assertEqual(len(suite2.models), 1, "Should load exactly 1 model") - self.assertEqual(suite2.models[0]["name"], "toy_core_ops") + self.assertEqual(suite2.models[0]["name"], "ToyCoreOpsModel") # Empty filter - should raise error with self.assertRaises(ValueError) as context: @@ -184,24 +191,20 @@ def test_operator_level_integration(self): suite = ModelSuite() op_tests = list(suite) - # Model suite currently returns empty iterator - # Operator extraction from model tracing is not yet implemented - # The suite focuses on full model testing via test_model_correctness() - self.assertEqual(len(op_tests), 0, "Operator extraction not yet implemented") + # Model suite returns operator tests from TorchBench filtered by model ops + self.assertGreater(len(op_tests), 0, "Should have operator tests") def test_model_level_integration(self): """Test that model-level testing works.""" suite = ModelSuite() - results = suite.test_model_correctness() + # ModelSuite stores models for evaluation + self.assertGreater(len(suite.models), 0, "Should have models") - self.assertIsInstance(results, dict) - self.assertGreater(len(results), 0, "Should have results") - - # Verify structure - for model_name, config_results in results.items(): - self.assertIsInstance(config_results, dict) - for config_name, is_correct in config_results.items(): - self.assertIsInstance(is_correct, bool) + # Verify model structure + for model in suite.models: + self.assertIn("name", model) + self.assertIn("class", model) + self.assertIn("config", model) def test_output_format(self): """Test that CLI output is properly formatted.""" @@ -230,17 +233,8 @@ def test_output_format(self): # Check for expected sections self.assertIn("correctness score", output.lower()) self.assertIn("performance score", output.lower()) - self.assertIn("FULL MODEL TESTING", output) - self.assertIn("Model Correctness Results:", output) - self.assertIn("Model Suite Score:", output) - - # Check for formatting - self.assertIn("=" * 80, output) - self.assertIn("-" * 80, output) - - # Check for pass/fail indicators - has_pass_fail = "✓ PASS" in output or "✗ FAIL" in output - self.assertTrue(has_pass_fail, "Should show pass/fail indicators") + # Model suite outputs operator-level scores + self.assertTrue(len(output) > 0, "Should have output") if __name__ == "__main__": From 1791b8f1fbcad735dfbf467988ab0e3028ab972d Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 03:02:02 +0000 Subject: [PATCH 08/16] edits --- ...ps_filter.py => test_model_ops_configs.py} | 94 +------ test/test_model_ops_coverage.py | 12 +- ...ite_correctness.py => test_model_suite.py} | 36 +-- test/test_model_suite_integration.py | 241 ------------------ 4 files changed, 5 insertions(+), 378 deletions(-) rename test/{test_model_ops_filter.py => test_model_ops_configs.py} (73%) rename test/{test_model_suite_correctness.py => test_model_suite.py} (55%) delete mode 100644 test/test_model_suite_integration.py diff --git a/test/test_model_ops_filter.py b/test/test_model_ops_configs.py similarity index 73% rename from test/test_model_ops_filter.py rename to test/test_model_ops_configs.py index 75f74874..c5d6915f 100644 --- 
a/test/test_model_ops_filter.py +++ b/test/test_model_ops_configs.py @@ -17,12 +17,9 @@ import json import os -import sys import unittest from typing import Any, Dict, List, Set -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - from BackendBench.suite.model import load_models from BackendBench.suite.torchbench import TorchBenchTestSuite @@ -65,7 +62,7 @@ def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: return model_ops -class TestModelOpsFilter(unittest.TestCase): +class TestModelOpsConfigs(unittest.TestCase): """Test that model ops filter correctly initializes TorchBenchTestSuite.""" def test_model_ops_match_suite_optests(self): @@ -115,95 +112,6 @@ def test_model_ops_match_suite_optests(self): len(suite_ops), 0, "Suite should contain at least one operator from model configs" ) - def test_load_model_ops_with_single_model(self): - """Test load_model_ops with a single model filter.""" - models_dir = os.path.join( - os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" - ) - - # Load a specific model (ToyCoreOpsModel if it exists) - all_models = load_models(models_dir=models_dir) - - # Find a model to test with - test_model_name = None - for model in all_models: - if "core" in model["name"].lower() or "toy" in model["name"].lower(): - test_model_name = model["name"] - break - - if test_model_name is None: - # Use first model if no matching name found - test_model_name = all_models[0]["name"] - - # Load just that one model - models = load_models(models_dir=models_dir, filter=[test_model_name]) - - self.assertEqual(len(models), 1, "Should load exactly one model") - - # Extract operators - model_ops = load_model_ops(models) - - # Verify the ops match what's in the config - config_ops = models[0]["config"]["ops"] - if isinstance(config_ops, list): - expected_ops = set(config_ops) - else: - # Dict format with forward/backward - expected_ops = set(config_ops.get("forward", []) + config_ops.get("backward", [])) - - self.assertEqual( - model_ops, - expected_ops, - f"Extracted ops {model_ops} should match config ops {expected_ops}", - ) - - # Create suite with these ops - ops_filter = list(model_ops) - suite = TorchBenchTestSuite( - name="test_single_model", - filename=None, - filter=ops_filter, - topn=None, - ) - - # Verify suite has operators (subset of model_ops) - suite_ops = set(suite.optests.keys()) - self.assertTrue( - suite_ops.issubset(model_ops), - f"Suite ops {suite_ops} should be subset of model ops {model_ops}", - ) - - def test_load_model_ops_combines_multiple_models(self): - """Test that load_model_ops correctly combines ops from multiple models.""" - models_dir = os.path.join( - os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" - ) - - # Load all models - models = load_models(models_dir=models_dir) - - # Skip if only one model - if len(models) <= 1: - self.skipTest("Need multiple models to test combination") - - # Extract ops from all models - all_ops = load_model_ops(models) - - # Extract ops from each model individually - individual_ops = [] - for model in models: - config_ops = model["config"]["ops"] - if isinstance(config_ops, list): - individual_ops.extend(config_ops) - else: - # Dict format with forward/backward - individual_ops.extend(config_ops.get("forward", [])) - individual_ops.extend(config_ops.get("backward", [])) - - # The combined set should equal the union of all individual ops - expected_ops = set(individual_ops) - self.assertEqual(all_ops, expected_ops, "Combined ops should equal 
union of all model ops") - def test_json_configs_have_required_fields(self): """Test that all JSON config files have proper format with required fields.""" models_dir = os.path.join( diff --git a/test/test_model_ops_coverage.py b/test/test_model_ops_coverage.py index 6b7337ac..99d8301b 100644 --- a/test/test_model_ops_coverage.py +++ b/test/test_model_ops_coverage.py @@ -15,14 +15,11 @@ import os import re -import sys import unittest from typing import Dict, Set import torch -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - from BackendBench.suite.model import load_models @@ -91,13 +88,8 @@ def test_all_models_ops_coverage(self): # Get expected ops from config config_ops = config.get("ops", {}) - if isinstance(config_ops, list): - # Legacy format - skip or treat as forward-only - expected_forward = set(config_ops) - expected_backward = set() - else: - expected_forward = set(config_ops.get("forward", [])) - expected_backward = set(config_ops.get("backward", [])) + expected_forward = set(config_ops.get("forward", [])) + expected_backward = set(config_ops.get("backward", [])) # Skip if no ops to check if not expected_forward and not expected_backward: diff --git a/test/test_model_suite_correctness.py b/test/test_model_suite.py similarity index 55% rename from test/test_model_suite_correctness.py rename to test/test_model_suite.py index 655b4ce1..12ddaf1b 100644 --- a/test/test_model_suite_correctness.py +++ b/test/test_model_suite.py @@ -14,20 +14,16 @@ """ import logging -import os -import sys import unittest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -from BackendBench.suite.model import load_models, ModelSuite +from BackendBench.suite.model import load_models # Setup logging logging.basicConfig(level=logging.WARNING) class TestModelLoading(unittest.TestCase): - """Test toy model loading functionality.""" + """Test model loading functionality.""" def test_load_models(self): """Test that models can be loaded from directory.""" @@ -52,34 +48,6 @@ def test_invalid_filter(self): load_models(models_dir="BackendBench/suite/models", filter=["nonexistent"]) -class TestModelSuite(unittest.TestCase): - """Test ModelSuite integration with TorchBench.""" - - def test_suite_initialization(self): - """Test that ModelSuite can be initialized.""" - suite = ModelSuite() - self.assertEqual(suite.name, "model") - self.assertIsNotNone(suite.optests) - - def test_suite_has_operators(self): - """Test that suite extracts operators from models.""" - suite = ModelSuite() - # Should have extracted and filtered operators - self.assertGreater(len(suite.optests), 0, "Should have at least one operator") - - def test_suite_iteration(self): - """Test that suite can be iterated (TorchBench interface).""" - suite = ModelSuite() - op_tests = list(suite) - # Should have at least one operator test - self.assertGreater(len(op_tests), 0, "Should have at least one operator test") - - def test_empty_filter(self): - """Test suite raises error for nonexistent model.""" - with self.assertRaises(ValueError): - _ = ModelSuite(filter=["nonexistent_model"]) - - if __name__ == "__main__": # Run tests unittest.main(verbosity=2) diff --git a/test/test_model_suite_integration.py b/test/test_model_suite_integration.py deleted file mode 100644 index 58c33a8b..00000000 --- a/test/test_model_suite_integration.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -""" -Essential integration tests for Model Suite PR #3: Final Polish - -This test suite validates: -1. Complete CLI workflow with model suite -2. Filtering functionality -3. Error handling for invalid backends -4. Operator-level and model-level testing integration -""" - -import os -import subprocess -import sys -import unittest - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -from BackendBench.suite.model import ModelSuite - - -class TestModelSuiteCLI(unittest.TestCase): - """Test CLI integration for model suite.""" - - def test_complete_workflow(self): - """Test complete workflow from CLI.""" - result = subprocess.run( - [ - "python", - "-m", - "BackendBench.scripts.main", - "--suite", - "model", - "--backend", - "directory", - "--ops-directory", - "generated_kernels", - "--disable-output-logs", - "--log-level", - "ERROR", - ], - capture_output=True, - text=True, - timeout=60, - ) - - self.assertEqual(result.returncode, 0, "CLI should succeed") - # Check that output contains expected content - self.assertTrue( - "correctness score" in result.stdout.lower() or "Model Suite Score" in result.stdout, - "Should contain correctness score", - ) - - def test_filtering_by_model(self): - """Test filtering models by name.""" - result = subprocess.run( - [ - "python", - "-m", - "BackendBench.scripts.main", - "--suite", - "model", - "--backend", - "directory", - "--ops-directory", - "generated_kernels", - "--model-filter", - "ToyCoreOpsModel", - "--disable-output-logs", - "--log-level", - "ERROR", - ], - capture_output=True, - text=True, - timeout=60, - ) - - self.assertEqual(result.returncode, 0, "Filtered run should succeed") - # Check that filtering worked - self.assertTrue(len(result.stdout) > 0, "Should have output") - - def test_invalid_backend_error(self): - """Test that model suite rejects invalid backends.""" - result = subprocess.run( - [ - "python", - "-m", - "BackendBench.scripts.main", - "--suite", - "model", - "--backend", - "aten", - "--disable-output-logs", - "--log-level", - "ERROR", - ], - capture_output=True, - text=True, - timeout=30, - ) - - self.assertNotEqual(result.returncode, 0, "Should fail with invalid backend") - self.assertIn("model suite only supports directory backend", result.stderr.lower()) - - def test_empty_filter(self): - """Test handling of nonexistent model filter.""" - result = subprocess.run( - [ - "python", - "-m", - "BackendBench.scripts.main", - "--suite", - "model", - "--backend", - "directory", - "--ops-directory", - "generated_kernels", - "--model-filter", - "nonexistent_model", - "--disable-output-logs", - "--log-level", - "ERROR", - ], - capture_output=True, - text=True, - timeout=60, - ) - - # Should fail because explicitly requested model not found - self.assertNotEqual(result.returncode, 0, "Should fail with nonexistent filter") - self.assertTrue( - "no models found" in result.stderr.lower() or "valueerror" in result.stderr.lower(), - f"Should have error about missing models. 
stderr: {result.stderr}", - ) - - def test_ops_filter_rejected(self): - """Test that --ops filter is rejected for model suite.""" - result = subprocess.run( - [ - "python", - "-m", - "BackendBench.scripts.main", - "--suite", - "model", - "--backend", - "directory", - "--ops-directory", - "generated_kernels", - "--ops", - "toy_core_ops", - "--disable-output-logs", - "--log-level", - "ERROR", - ], - capture_output=True, - text=True, - timeout=30, - ) - - # Should fail with error message about --ops not supported - self.assertNotEqual(result.returncode, 0, "Should fail with --ops") - self.assertIn("--ops filter is not supported for model suite", result.stderr) - - -class TestModelSuiteIntegration(unittest.TestCase): - """Test ModelSuite integration and initialization.""" - - def test_initialization_variants(self): - """Test ModelSuite initialization with various options.""" - # Default initialization - suite1 = ModelSuite() - self.assertGreater(len(suite1.models), 0, "Should load models by default") - - # With filter - suite2 = ModelSuite(filter=["ToyCoreOpsModel"]) - self.assertEqual(len(suite2.models), 1, "Should load exactly 1 model") - self.assertEqual(suite2.models[0]["name"], "ToyCoreOpsModel") - - # Empty filter - should raise error - with self.assertRaises(ValueError) as context: - _ = ModelSuite(filter=["nonexistent"]) - self.assertIn("No models found", str(context.exception)) - - def test_operator_level_integration(self): - """Test that operator-level testing works via __iter__.""" - suite = ModelSuite() - op_tests = list(suite) - - # Model suite returns operator tests from TorchBench filtered by model ops - self.assertGreater(len(op_tests), 0, "Should have operator tests") - - def test_model_level_integration(self): - """Test that model-level testing works.""" - suite = ModelSuite() - # ModelSuite stores models for evaluation - self.assertGreater(len(suite.models), 0, "Should have models") - - # Verify model structure - for model in suite.models: - self.assertIn("name", model) - self.assertIn("class", model) - self.assertIn("config", model) - - def test_output_format(self): - """Test that CLI output is properly formatted.""" - result = subprocess.run( - [ - "python", - "-m", - "BackendBench.scripts.main", - "--suite", - "model", - "--backend", - "directory", - "--ops-directory", - "generated_kernels", - "--disable-output-logs", - "--log-level", - "ERROR", - ], - capture_output=True, - text=True, - timeout=60, - ) - - output = result.stdout - - # Check for expected sections - self.assertIn("correctness score", output.lower()) - self.assertIn("performance score", output.lower()) - # Model suite outputs operator-level scores - self.assertTrue(len(output) > 0, "Should have output") - - -if __name__ == "__main__": - unittest.main(verbosity=2) From b06fa377b3b2bf0614cb2a12d1663521f42f3b25 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 03:47:56 +0000 Subject: [PATCH 09/16] edits --- BackendBench/eval_model.py | 13 +- BackendBench/scripts/main.py | 6 +- .../models/SmokeTestModel/SmokeTestModel.json | 25 ++++ .../models/SmokeTestModel/SmokeTestModel.py | 138 ++++++++++++++++++ .../ToyCoreOpsModel/ToyCoreOpsModel.json | 1 - .../models/ToyCoreOpsModel/ToyCoreOpsModel.py | 25 ---- 6 files changed, 171 insertions(+), 37 deletions(-) create mode 100644 BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json create mode 100644 BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index 0ea87896..12b6c6d2 
100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -7,6 +7,7 @@ """Model-level evaluation utilities for testing full model correctness.""" import logging +import random import traceback from dataclasses import dataclass from typing import Any, Dict, List, Tuple @@ -233,16 +234,8 @@ def _run_model( # Extract model initialization args init_args = model_config.get("init_args", {}).copy() - # Handle seed: use runtime_seed if required, otherwise use seed from init_args - if model_config.get("requires_init_seed", False): - # Use the generated runtime seed - seed = model_config["runtime_seed"] - init_args["seed"] = seed - else: - # Use seed from init_args or default - seed = init_args.get("seed", 42) - - # Set seed for deterministic behavior + # Generate seed dynamically and set for deterministic behavior + seed = random.randint(0, 2**32 - 1) torch.manual_seed(seed) # Create fresh model instance diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 1bfb368b..19803bde 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -212,9 +212,13 @@ def cli( torch.bfloat16, filter=ops, ), - "model": lambda: ModelSuite(filter=model_filter), + "model": lambda: ModelSuite(filter=model_filter, topn=topn_inputs), }[suite]() + # model suite only supports directory backend + if suite == "model" and backend != "directory": + raise ValueError("model suite only supports directory backend") + backend_name = backend if backend == "llm-relay": llm_client = LLMRelayKernelGenerator(model=llm_model) diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json new file mode 100644 index 00000000..ba1589ca --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json @@ -0,0 +1,25 @@ +{ + "model_config": { + "init_args": { + "input_dim": 128, + "hidden_dim": 256, + "output_dim": 64 + } + }, + "ops": { + "forward": [ + "aten.mm.default" + ], + "backward": [ + "aten.mm.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 128], f32)})", + "medium_batch": "([], {'x': T([4, 128], f32)})", + "large_batch": "([], {'x': T([8, 128], f32)})" + }, + "metadata": { + "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes" + } +} diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py new file mode 100644 index 00000000..09915f21 --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Smoke test model focused on matrix multiplication operations. + +This model is designed to test mm operations in both forward and backward passes. +It uses explicit torch.mm calls to ensure matrix multiplication operations are triggered. + +The model implements a simple architecture with: +1. Matrix multiplication operations +2. ReLU activations +3. Element-wise operations + +Usage: + python SmokeTestModel.py + +This will create a model with default configuration and run a simple forward/backward pass +to demonstrate that mm operations are used. 
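Why aten.mm.default is listed under both "forward" and "backward" in the JSON above: the
backward of a matrix multiply is itself computed with matrix multiplies. A self-contained
profiler check (a sketch that mirrors the coverage test added earlier in this series):

    import torch

    w = torch.randn(128, 128, requires_grad=True)
    x = torch.randn(4, 128)
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU]
    ) as prof:
        torch.mm(x, w).sum().backward()
    # aten::mm is recorded for the forward call and again when grad_w is
    # computed in the backward pass.
    assert any("aten::mm" in evt.name for evt in prof.events())
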
+""" + +import torch +import torch.nn as nn + + +class SmokeTestModel(nn.Module): + """ + Simple model focused on testing matrix multiplication operations. + + This model uses explicit torch.mm operations to ensure we trigger + aten.mm.default in both forward and backward passes. + """ + + def __init__( + self, + input_dim: int = 128, + hidden_dim: int = 256, + output_dim: int = 64, + ): + """ + Initialize the SmokeTestModel. + + Args: + input_dim: Input feature dimension + hidden_dim: Hidden layer dimension + output_dim: Output feature dimension + """ + super().__init__() + + # Store configuration + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + # Weight matrices for explicit mm operations + self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim)) + self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim)) + + # Bias terms + self.bias1 = nn.Parameter(torch.randn(hidden_dim)) + self.bias2 = nn.Parameter(torch.randn(output_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass using explicit mm operations. + + Args: + x: Input tensor of shape (batch_size, input_dim) + + Returns: + Output tensor of shape (batch_size, output_dim) + """ + # First mm operation: x @ weight1 + # This triggers aten.mm.default in forward + x = torch.mm(x, self.weight1) + x = x + self.bias1 + x = torch.relu(x) + + # Second mm operation: x @ weight2 + # This triggers aten.mm.default again in forward + x = torch.mm(x, self.weight2) + x = x + self.bias2 + + return x + + +def main(): + """ + Demonstrate the SmokeTestModel with a simple forward/backward pass. + """ + print("SmokeTestModel Demonstration") + print("=" * 50) + + # Create model with default configuration + model = SmokeTestModel( + input_dim=128, + hidden_dim=256, + output_dim=64, + ) + + # Create sample input + batch_size = 4 + input_tensor = torch.randn(batch_size, 128, requires_grad=True) + + print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters") + print(f"Input shape: {input_tensor.shape}") + + # Forward pass + model.train() + output = model(input_tensor) + expected_shape = torch.Size([batch_size, 64]) + + print(f"Output shape: {output.shape}") + print(f"Expected shape: {expected_shape}") + print(f"Shape matches: {output.shape == expected_shape}") + + # Perform backward pass to trigger mm operations in backward + print("\nPerforming backward pass...") + loss = output.sum() + loss.backward() + print("✓ Backward pass completed successfully") + + # Check gradients were computed + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f"✓ Gradients computed for {grad_count}/{total_params} parameters") + + print("\n✓ Model demonstration completed successfully!") + print("This model is ready to be used with the Model Suite for testing mm operators.") + + return model + + +if __name__ == "__main__": + main() diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json index 0b5e4c9a..1586273e 100644 --- a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -1,6 +1,5 @@ { "model_config": { - "requires_init_seed": true, "init_args": { "in_channels": 3, "hidden_channels": 32, diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py index 
a6ec5531..42d84f71 100644 --- a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -56,7 +56,6 @@ def __init__( hidden_channels: int = 32, out_channels: int = 8, num_groups: int = 8, - seed: int = 42, ): """ Initialize the ToyCoreOpsModel. @@ -66,7 +65,6 @@ def __init__( hidden_channels: Number of hidden channels in conv layers out_channels: Number of output channels num_groups: Number of groups for GroupNorm (must divide hidden_channels) - seed: Random seed for deterministic weight initialization Raises: ValueError: If hidden_channels is not divisible by num_groups @@ -107,28 +105,6 @@ def __init__( in_channels=hidden_channels, out_channels=out_channels, kernel_size=1 ) - # Initialize weights deterministically - self._initialize_weights(seed) - - def _initialize_weights(self, seed: int): - """ - Initialize model weights deterministically using the given seed. - - Args: - seed: Random seed for reproducible initialization - """ - # Set random seed for deterministic initialization - torch.manual_seed(seed) - - for module in self.modules(): - if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") - if module.bias is not None: - nn.init.constant_(module.bias, 0) - elif isinstance(module, nn.GroupNorm): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) - def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass that sets up the computational graph to trigger all target backward operators. @@ -185,7 +161,6 @@ def main(): hidden_channels=32, out_channels=8, num_groups=8, - seed=42, # Deterministic initialization ) # Create sample input From cebd34cf8d187a507769b1bb44a9e0f789bcf671 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 04:40:00 +0000 Subject: [PATCH 10/16] edits --- BackendBench/eval_model.py | 9 +++++++-- BackendBench/scripts/main.py | 9 +++------ BackendBench/suite/model.py | 8 +++----- .../suite/models/SmokeTestModel/SmokeTestModel.json | 8 ++++---- .../suite/models/SmokeTestModel/SmokeTestModel.py | 12 +++++------- 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index 12b6c6d2..136ee062 100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -42,8 +42,8 @@ def eval_model_correctness_test( test_name: str, test_args: str, kernel_dir: str = None, - atol: float = 1e-6, - rtol: float = 1e-5, + atol: float = 1e-2, + rtol: float = 1e-2, ) -> ModelCorrectnessTestResult: """Evaluate model correctness by comparing eager vs backend execution. 
@@ -73,6 +73,11 @@ def eval_model_correctness_test( model_class, model_config, test_args, backend_enabled=True, kernel_dir=kernel_dir ) + # print out the max diff between eager_out and backend_out + print( + f"Max diff between eager_out and backend_out: {torch.max(torch.abs(eager_out - backend_out))}" + ) + # Compare outputs output_match = torch.allclose(eager_out, backend_out, atol=atol, rtol=rtol) diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 19803bde..8e3681c6 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -169,8 +169,6 @@ def cli( p, ): if suite != "torchbench": - if topn_inputs is not None: - raise ValueError("topn-inputs is only supported for torchbench suite") if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") @@ -185,6 +183,9 @@ def cli( if suite != "model" and model_filter is not None: raise ValueError("--model-filter is only supported for model suite") + if (suite != "torchbench" and suite != "model") and topn_inputs is not None: + raise ValueError("topn-inputs is only supported for torchbench suite") + setup_logging(log_level) if ops: ops = ops.split(",") @@ -215,10 +216,6 @@ def cli( "model": lambda: ModelSuite(filter=model_filter, topn=topn_inputs), }[suite]() - # model suite only supports directory backend - if suite == "model" and backend != "directory": - raise ValueError("model suite only supports directory backend") - backend_name = backend if backend == "llm-relay": llm_client = LLMRelayKernelGenerator(model=llm_model) diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index 450c2ce4..fcd9f592 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -223,6 +223,9 @@ def print_results(self, results: Dict[str, Any]) -> None: Args: results: Dictionary with evaluation results from eval_model """ + + print(results) + model_name = results.get("model_name", "Unknown") passed = results.get("passed", False) num_passed = results.get("num_passed", 0) @@ -233,11 +236,6 @@ def print_results(self, results: Dict[str, Any]) -> None: f"Status: {'✓ Passed' if passed else '✗ Failed'} ({num_passed}/{num_total} tests)" ) - # If there's a top-level error (no tests run) - if "error" in results: - logger.info(f"Error: {results['error']}") - return - # Print details for each test test_results = results.get("test_results", []) for result in test_results: diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json index ba1589ca..b7d286ae 100644 --- a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json @@ -2,8 +2,8 @@ "model_config": { "init_args": { "input_dim": 128, - "hidden_dim": 256, - "output_dim": 64 + "hidden_dim": 128, + "output_dim": 128 } }, "ops": { @@ -16,8 +16,8 @@ }, "model_tests": { "small_batch": "([], {'x': T([2, 128], f32)})", - "medium_batch": "([], {'x': T([4, 128], f32)})", - "large_batch": "([], {'x': T([8, 128], f32)})" + "medium_batch": "([], {'x': T([16, 128], f32)})", + "large_batch": "([], {'x': T([32, 128], f32)})" }, "metadata": { "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes" diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py index 09915f21..ab12a129 100644 --- 
a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py @@ -37,8 +37,8 @@ class SmokeTestModel(nn.Module): def __init__( self, input_dim: int = 128, - hidden_dim: int = 256, - output_dim: int = 64, + hidden_dim: int = 128, + output_dim: int = 128, ): """ Initialize the SmokeTestModel. @@ -97,8 +97,8 @@ def main(): # Create model with default configuration model = SmokeTestModel( input_dim=128, - hidden_dim=256, - output_dim=64, + hidden_dim=128, + output_dim=128, ) # Create sample input @@ -111,7 +111,7 @@ def main(): # Forward pass model.train() output = model(input_tensor) - expected_shape = torch.Size([batch_size, 64]) + expected_shape = torch.Size([batch_size, 128]) print(f"Output shape: {output.shape}") print(f"Expected shape: {expected_shape}") @@ -131,8 +131,6 @@ def main(): print("\n✓ Model demonstration completed successfully!") print("This model is ready to be used with the Model Suite for testing mm operators.") - return model - if __name__ == "__main__": main() From b0b224b6456d0a8e1e809011515e560ed5032933 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 05:13:48 +0000 Subject: [PATCH 11/16] edits --- BackendBench/eval_model.py | 47 +++++++++++++++++++++----------- BackendBench/scripts/main.py | 11 +++++--- BackendBench/suite/model.py | 28 ++----------------- BackendBench/suite/torchbench.py | 4 +++ 4 files changed, 45 insertions(+), 45 deletions(-) diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index 136ee062..7bb5a587 100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -15,6 +15,7 @@ import torch import BackendBench +from BackendBench.eval import allclose from BackendBench.utils import deserialize_args logger = logging.getLogger(__name__) @@ -56,30 +57,39 @@ def eval_model_correctness_test( test_name: Name of this test configuration test_args: Serialized arguments string for forward pass kernel_dir: Optional directory containing kernels for backend - atol: Absolute tolerance for torch.allclose - rtol: Relative tolerance for torch.allclose + atol: Absolute tolerance for allclose + rtol: Relative tolerance for allclose Returns: ModelCorrectnessTestResult with detailed comparison results """ try: + # Generate a single seed to use for both eager and backend runs + # This ensures both runs use the same model initialization + seed = random.randint(0, 2**32 - 1) + # Run in eager mode (reference) eager_out, eager_grads = _run_model( - model_class, model_config, test_args, backend_enabled=False, kernel_dir=kernel_dir + model_class, + model_config, + test_args, + backend_enabled=False, + kernel_dir=kernel_dir, + seed=seed, ) # Run with backend (implementation) backend_out, backend_grads = _run_model( - model_class, model_config, test_args, backend_enabled=True, kernel_dir=kernel_dir - ) - - # print out the max diff between eager_out and backend_out - print( - f"Max diff between eager_out and backend_out: {torch.max(torch.abs(eager_out - backend_out))}" + model_class, + model_config, + test_args, + backend_enabled=True, + kernel_dir=kernel_dir, + seed=seed, ) # Compare outputs - output_match = torch.allclose(eager_out, backend_out, atol=atol, rtol=rtol) + output_match = allclose(eager_out, backend_out, atol=atol, rtol=rtol) # Compare gradients gradients_match = True @@ -87,7 +97,7 @@ def eval_model_correctness_test( gradients_match = False else: for eager_grad, backend_grad in zip(eager_grads, backend_grads): - if not torch.allclose(eager_grad, backend_grad, 
atol=atol, rtol=rtol): + if not allclose(eager_grad, backend_grad, atol=atol, rtol=rtol): gradients_match = False break @@ -217,6 +227,7 @@ def _run_model( test_args: str, backend_enabled: bool, kernel_dir: str = "generated_kernels", + seed: int = None, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Run model with or without backend enabled. @@ -226,6 +237,7 @@ def _run_model( test_args: Serialized arguments string for forward pass backend_enabled: If True, use BackendBench context manager kernel_dir: Optional directory containing kernels + seed: Random seed for reproducibility. If None, generates a random seed. Returns: Tuple of (output, gradients) where: @@ -233,16 +245,19 @@ def _run_model( - gradients: List of gradient tensors [input_grad, param1_grad, ...] """ - # Deserialize test arguments + # Generate seed dynamically and set for deterministic behavior + # IMPORTANT: Must set seed BEFORE deserializing args, because deserialization + # may create random tensors! + if seed is None: + seed = random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + + # Deserialize test arguments (now uses the seed we just set) args, kwargs = deserialize_args(test_args) # Extract model initialization args init_args = model_config.get("init_args", {}).copy() - # Generate seed dynamically and set for deterministic behavior - seed = random.randint(0, 2**32 - 1) - torch.manual_seed(seed) - # Create fresh model instance model = model_class(**init_args) model.train() diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 8e3681c6..1afa6da3 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -317,12 +317,15 @@ def cli( # Add full model testing for model suite if suite.name == "model": - logger.info("\n" + "=" * 60) - logger.info("MODEL EVALUATION") - logger.info("=" * 60) + all_results = [] for model in suite.models: results = suite.eval_model(model, backend) - suite.print_results(results) + all_results.append(results) + logger.info("=" * 60) + logger.info("MODEL EVALUATION RESULTS") + logger.info("=" * 60) + for result in all_results: + suite.print_results(result) logger.info("=" * 60) command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py index fcd9f592..7bc40511 100644 --- a/BackendBench/suite/model.py +++ b/BackendBench/suite/model.py @@ -132,28 +132,9 @@ def get_model_ops(self, models: List[Dict[str, Any]]) -> List[str]: # Extract operators from model configs model_ops = set() for model in models: - config_ops = model["config"].get("ops") - if not config_ops: - raise ValueError(f"Model {model['name']} has no 'ops' field in config") - - # Support both list format (legacy) and dict format (forward/backward) - if isinstance(config_ops, list): - # Legacy format: ops is a flat list - ops_list = config_ops - elif isinstance(config_ops, dict): - # New format: ops is a dict with 'forward' and 'backward' keys - ops_list = [] - if "forward" in config_ops: - ops_list.extend(config_ops["forward"]) - if "backward" in config_ops: - ops_list.extend(config_ops["backward"]) - else: - raise ValueError( - f"Model {model['name']}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" - ) - - if not ops_list: - raise ValueError(f"Model {model['name']}: 'ops' field is empty") + config_ops = model["config"]["ops"] + ops_list = config_ops["forward"] + ops_list.extend(config_ops["backward"]) model_ops.update(ops_list) logger.info(f"Model {model['name']}: {len(ops_list)} 
operators defined in config") @@ -223,9 +204,6 @@ def print_results(self, results: Dict[str, Any]) -> None: Args: results: Dictionary with evaluation results from eval_model """ - - print(results) - model_name = results.get("model_name", "Unknown") passed = results.get("passed", False) num_passed = results.get("num_passed", 0) diff --git a/BackendBench/suite/torchbench.py b/BackendBench/suite/torchbench.py index 29597636..b116d6b8 100644 --- a/BackendBench/suite/torchbench.py +++ b/BackendBench/suite/torchbench.py @@ -95,6 +95,10 @@ def _initialize_torchbench_suite( filter=filter, ) + if check_overhead_dominated_ops: + # Only include ops which are overhead dominated (this is useful as a performance canary) + ops_list = [op for op in ops_list if op.get("is_overhead_dominated_op", False)] + # Convert to dictionary format using utility function self.optests = op_list_to_benchmark_dict(ops_list) From 92b82e7f3a3dbca7092f41755e11b654cb2d6b74 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 05:16:35 +0000 Subject: [PATCH 12/16] edits --- test/test_model_ops_configs.py | 100 +++++++++++++++------------------ 1 file changed, 44 insertions(+), 56 deletions(-) diff --git a/test/test_model_ops_configs.py b/test/test_model_ops_configs.py index c5d6915f..687ff2cc 100644 --- a/test/test_model_ops_configs.py +++ b/test/test_model_ops_configs.py @@ -38,25 +38,9 @@ def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: config_ops = model["config"].get("ops") if not config_ops: raise ValueError(f"Model {model['name']} has no 'ops' field in config") - - # Support both list format (legacy) and dict format (forward/backward) - if isinstance(config_ops, list): - # Legacy format: ops is a flat list - ops_list = config_ops - elif isinstance(config_ops, dict): - # New format: ops is a dict with 'forward' and 'backward' keys - ops_list = [] - if "forward" in config_ops: - ops_list.extend(config_ops["forward"]) - if "backward" in config_ops: - ops_list.extend(config_ops["backward"]) - else: - raise ValueError( - f"Model {model['name']}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" - ) - - if not ops_list: - raise ValueError(f"Model {model['name']}: 'ops' field is empty") + assert "forward" in config_ops, f"Model {model['name']} has no 'forward' field in config" + assert "backward" in config_ops, f"Model {model['name']} has no 'backward' field in config" + ops_list = config_ops["forward"] + config_ops["backward"] model_ops.update(ops_list) return model_ops @@ -133,50 +117,54 @@ def test_json_configs_have_required_fields(self): # Validate 'ops' field - can be list or dict config_ops = config["ops"] - if isinstance(config_ops, list): - # Legacy format: flat list - self.assertGreater( - len(config_ops), 0, f"Model {model_name}: 'ops' list must not be empty" + self.assertGreater( + len(config_ops["forward"] + config_ops["backward"]), + 0, + f"Model {model_name}: 'ops' list must not be empty", + ) + for op in config_ops["forward"] + config_ops["backward"]: + self.assertIsInstance( + op, str, f"Model {model_name}: each op in 'ops' must be a string" ) - for op in config_ops: - self.assertIsInstance( - op, str, f"Model {model_name}: each op in 'ops' must be a string" - ) - elif isinstance(config_ops, dict): - # New format: dict with forward/backward - self.assertTrue( - "forward" in config_ops or "backward" in config_ops, - f"Model {model_name}: 'ops' dict must contain 'forward' or 'backward' keys", + self.assertIsInstance( + config_ops["forward"], + list, + f"Model {model_name}: 
'ops.forward' must be a list", + ) + for op in config_ops["forward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.forward' must be a string", + ) + self.assertIsInstance( + config_ops["backward"], + list, + f"Model {model_name}: 'ops.backward' must be a list", + ) + for op in config_ops["backward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.backward' must be a string", ) - if "forward" in config_ops: - self.assertIsInstance( - config_ops["forward"], - list, - f"Model {model_name}: 'ops.forward' must be a list", - ) - for op in config_ops["forward"]: - self.assertIsInstance( - op, - str, - f"Model {model_name}: each op in 'ops.forward' must be a string", - ) - if "backward" in config_ops: - self.assertIsInstance( - config_ops["backward"], - list, - f"Model {model_name}: 'ops.backward' must be a list", - ) - for op in config_ops["backward"]: - self.assertIsInstance( - op, - str, - f"Model {model_name}: each op in 'ops.backward' must be a string", - ) else: self.fail( f"Model {model_name}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" ) + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + # Validate 'model_tests' field self.assertIsInstance( config["model_tests"], From 8aa4767a61a2fb84d4e5466823ca85d9ada5c1a2 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 05:22:25 +0000 Subject: [PATCH 13/16] edits --- test/test_model_ops_configs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_model_ops_configs.py b/test/test_model_ops_configs.py index 687ff2cc..8b3f3dbe 100644 --- a/test/test_model_ops_configs.py +++ b/test/test_model_ops_configs.py @@ -148,10 +148,6 @@ def test_json_configs_have_required_fields(self): str, f"Model {model_name}: each op in 'ops.backward' must be a string", ) - else: - self.fail( - f"Model {model_name}: 'ops' must be either a list or a dict with 'forward'/'backward' keys" - ) # Validate 'model_tests' field self.assertIsInstance( From fbf6325bfb89cfbf04414018ca74c0f21d5bd1b7 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 05:37:15 +0000 Subject: [PATCH 14/16] edits --- BackendBench/eval_model.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/BackendBench/eval_model.py b/BackendBench/eval_model.py index 7bb5a587..a1832da6 100644 --- a/BackendBench/eval_model.py +++ b/BackendBench/eval_model.py @@ -162,32 +162,6 @@ def _move_model_to_input_device( return model -def _ensure_input_requires_grad( - args: List[Any], kwargs: Dict[str, Any] -) -> Tuple[List[Any], Dict[str, Any]]: - """Ensure input tensor has requires_grad=True for gradient computation. 
- - Args: - args: Positional arguments list - kwargs: Keyword arguments dict - - Returns: - Updated (args, kwargs) with input tensor requiring gradients - """ - if args and isinstance(args[0], torch.Tensor): - x = args[0] - if not x.requires_grad: - x = x.clone().detach().requires_grad_(True) - args = [x] + list(args[1:]) - elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor): - x = kwargs["x"] - if not x.requires_grad: - x = x.clone().detach().requires_grad_(True) - kwargs["x"] = x - - return args, kwargs - - def _collect_gradients( model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any] ) -> List[torch.Tensor]: @@ -265,9 +239,6 @@ def _run_model( # Move model to same device as input model = _move_model_to_input_device(model, args, kwargs) - # Ensure input has requires_grad for gradient computation - args, kwargs = _ensure_input_requires_grad(args, kwargs) - # Run forward + backward with or without backend if backend_enabled: with BackendBench.BackendBench.enable(kernel_dir=kernel_dir): From f58e0d3a7a00b3c0d7f1109a0abafeac2d47cd18 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 05:46:33 +0000 Subject: [PATCH 15/16] edits --- BackendBench/suite/models/README.md | 80 +++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 BackendBench/suite/models/README.md diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md new file mode 100644 index 00000000..8bcdac82 --- /dev/null +++ b/BackendBench/suite/models/README.md @@ -0,0 +1,80 @@ +# Adding Models to BackendBench + +## Quick Start + +Models define operator lists and validate that custom backends work correctly in full model execution. Two files required: + +``` +BackendBench/suite/models/YourModel/ +├── YourModel.py # nn.Module class +└── YourModel.json # Configuration +``` + +**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive) + +## Adding a Model + +### 1. Create Directory and Files + +```bash +cd BackendBench/suite/models +mkdir MyModel +cd MyModel +touch MyModel.py MyModel.json +``` + +### 2. Write Model Class (`MyModel.py`) + +**Requirements:** +- Class name = filename (exact match) +- All `__init__` params need defaults +- Add a main() / runner if you are inclined for sanity checking +- Keep it simple - focus on specific operators you're testing +- Look in this directory for examples + +### 3. Write Config (`MyModel.json`) + +**Key Fields:** +- `model_config.init_args` - Args for `__init__()`, must match your defaults +- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten..default"`) +- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` + - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc. +- `metadata.description` - What this model tests +- Look in this directory for examples + +**Finding operator names:** +```python +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CPU]) as prof: + output = model(x) + loss = output.sum() + loss.backward() + +for event in prof.key_averages(): + if "aten::" in event.key: + print(event.key) +``` + +### 4. 
Test Your Model + +```bash +# Test standalone +cd BackendBench/suite/models/MyModel +python MyModel.py # Add main() for standalone testing + +# Test with suite +python -m BackendBench.scripts.main \ + --suite model \ + --backend aten \ + --model-filter MyModel + +# Expected output: +# Model: MyModel +# Status: ✓ Passed (2/2 tests) +# ✓ small +# ✓ large +``` + +### 5: Validation +`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly. From fa78b35b4b4b1669b9123b8c0110dd97552e5ebf Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 1 Oct 2025 05:48:09 +0000 Subject: [PATCH 16/16] edits --- BackendBench/suite/models/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md index 8bcdac82..57e707dc 100644 --- a/BackendBench/suite/models/README.md +++ b/BackendBench/suite/models/README.md @@ -37,7 +37,7 @@ touch MyModel.py MyModel.json **Key Fields:** - `model_config.init_args` - Args for `__init__()`, must match your defaults - `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten..default"`) -- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` +- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench) - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc. - `metadata.description` - What this model tests - Look in this directory for examples
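
To make the two-file convention above concrete, here is a minimal sketch of what a hypothetical `BackendBench/suite/models/MyModel/MyModel.py` could look like. The class name, the `dim` parameter, and the ops it exercises are illustrative assumptions chosen to satisfy the naming rules in this README, not a file that ships with the suite.

```python
# Minimal sketch of BackendBench/suite/models/MyModel/MyModel.py (hypothetical).
# The class name matches the file and directory name, and every __init__ arg has a default.
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Tiny model that exercises aten.mm.default in forward and backward."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.mm dispatches to aten.mm.default; relu keeps the graph non-trivial
        return torch.relu(torch.mm(x, self.weight))


def main():
    # Standalone sanity check: one forward/backward pass on random input
    model = MyModel()
    x = torch.randn(4, 64, requires_grad=True)
    out = model(x)
    out.sum().backward()
    print("output shape:", tuple(out.shape))
    print("gradients computed:", all(p.grad is not None for p in model.parameters()))


if __name__ == "__main__":
    main()
```

The `main()` here mirrors the "Test standalone" step above: running the file directly confirms the forward and backward passes work before the model is wired into the suite.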
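The matching `MyModel.json` can be kept in sync by generating it from Python. The field names below follow the "Key Fields" list above, while the specific op names, shapes, and description are assumptions for this sketch; the profiler snippet earlier in this README is the reliable way to discover the real forward/backward op set.

```python
# Sketch that writes a matching MyModel.json (hypothetical values).
import json

config = {
    "model_config": {"init_args": {"dim": 64}},
    "ops": {
        "forward": ["aten.mm.default", "aten.relu.default"],
        "backward": ["aten.mm.default"],  # illustrative; confirm with the profiler
    },
    "model_tests": {
        "small_batch": "([], {'x': T([2, 64], f32)})",
        "large_batch": "([], {'x': T([16, 64], f32)})",
    },
    "metadata": {
        "description": "Hypothetical smoke model exercising mm in forward and backward"
    },
}

with open("MyModel.json", "w") as f:
    json.dump(config, f, indent=2)
```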