diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md
new file mode 100644
index 00000000..57e707dc
--- /dev/null
+++ b/BackendBench/suite/models/README.md
@@ -0,0 +1,80 @@
+# Adding Models to BackendBench
+
+## Quick Start
+
+Models define operator lists and validate that custom backends work correctly in full model execution. Two files are required:
+
+```
+BackendBench/suite/models/YourModel/
+├── YourModel.py    # nn.Module class
+└── YourModel.json  # Configuration
+```
+
+**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive)
+
+## Adding a Model
+
+### 1. Create Directory and Files
+
+```bash
+cd BackendBench/suite/models
+mkdir MyModel
+cd MyModel
+touch MyModel.py MyModel.json
+```
+
+### 2. Write Model Class (`MyModel.py`)
+
+**Requirements:**
+- Class name = filename (exact match)
+- All `__init__` params need defaults
+- Optionally add a `main()` runner for standalone sanity checks
+- Keep it simple: focus on the specific operators you're testing
+- Look in this directory for examples
+
+### 3. Write Config (`MyModel.json`)
+
+**Key fields:**
+- `model_config.init_args` - Args for `__init__()`; must match your defaults
+- `ops.forward` / `ops.backward` - ATen operators to test (format: `"aten.<op>.<overload>"`, e.g. `"aten.mm.default"`)
+- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"`. The format is described in more detail [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench)
+  - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc.
+- `metadata.description` - What this model tests
+- Look in this directory for examples
+
+**Finding operator names:**
+```python
+from torch.profiler import profile, ProfilerActivity
+
+with profile(activities=[ProfilerActivity.CPU]) as prof:
+    output = model(x)
+    loss = output.sum()
+    loss.backward()
+
+for event in prof.key_averages():
+    if "aten::" in event.key:
+        print(event.key)
+```
+
+Note that the profiler reports schema names like `aten::mm`; in the config, list the corresponding op overload, e.g. `aten.mm.default`.
+
+### 4. Test Your Model
+
+```bash
+# Test standalone
+cd BackendBench/suite/models/MyModel
+python MyModel.py  # Add a main() for standalone testing
+
+# Test with suite
+python -m BackendBench.scripts.main \
+    --suite model \
+    --backend aten \
+    --model-filter MyModel
+
+# Expected output:
+# Model: MyModel
+# Status: ✓ Passed (2/2 tests)
+#   ✓ small
+#   ✓ large
+```
+
+### 5. Validation
+
+`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` validate that all models are loadable and formatted correctly.
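+
+## Worked Example: Test Inputs
+
+For illustration, here is roughly how a `model_tests` entry maps to a call at run time. This is a simplified sketch, not the harness's actual deserializer; the `T` helper below is a stand-in for the serialized-tensor constructor documented at the link above:
+
+```python
+import torch
+
+def T(shape, dtype=torch.float32):
+    # Simplified stand-in for the serialized-tensor helper
+    return torch.randn(*shape, dtype=dtype)
+
+# The entry "([], {'x': T([2, 128], f32)})" corresponds roughly to:
+args, kwargs = [], {"x": T([2, 128])}
+output = model(*args, **kwargs)  # model: your nn.Module, built from init_args
+```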
diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json
new file mode 100644
index 00000000..b7d286ae
--- /dev/null
+++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json
@@ -0,0 +1,25 @@
+{
+  "model_config": {
+    "init_args": {
+      "input_dim": 128,
+      "hidden_dim": 128,
+      "output_dim": 128
+    }
+  },
+  "ops": {
+    "forward": [
+      "aten.mm.default"
+    ],
+    "backward": [
+      "aten.mm.default"
+    ]
+  },
+  "model_tests": {
+    "small_batch": "([], {'x': T([2, 128], f32)})",
+    "medium_batch": "([], {'x': T([16, 128], f32)})",
+    "large_batch": "([], {'x': T([32, 128], f32)})"
+  },
+  "metadata": {
+    "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes"
+  }
+}
diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py
new file mode 100644
index 00000000..3bf627e4
--- /dev/null
+++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Simple model that tests matrix multiplication operations using explicit
+torch.mm calls.
+"""
+
+import torch
+import torch.nn as nn
+
+
+class SmokeTestModel(nn.Module):
+    """
+    Model that uses explicit torch.mm operations to test aten.mm.default
+    in forward/backward.
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 128,
+        hidden_dim: int = 128,
+        output_dim: int = 128,
+    ):
+        super().__init__()
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.output_dim = output_dim
+
+        self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim))
+        self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim))
+        self.bias1 = nn.Parameter(torch.randn(hidden_dim))
+        self.bias2 = nn.Parameter(torch.randn(output_dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2)
+        """
+        x = torch.mm(x, self.weight1) + self.bias1
+        x = torch.relu(x)
+        x = torch.mm(x, self.weight2) + self.bias2
+        return x
+
+
+def main():
+    """Demonstrate the model with a forward/backward pass."""
+    model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128)
+    batch_size = 4
+    input_tensor = torch.randn(batch_size, 128, requires_grad=True)
+
+    model.train()
+    output = model(input_tensor)
+    loss = output.sum()
+    loss.backward()
+
+    print("✓ Forward/backward pass completed")
+    print(f"  Parameters: {sum(p.numel() for p in model.parameters())}")
+    print(f"  Input: {input_tensor.shape} -> Output: {output.shape}")
+    grad_count = sum(1 for p in model.parameters() if p.grad is not None)
+    total_params = len(list(model.parameters()))
+    print(f"  Gradients computed: {grad_count}/{total_params}")
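+
+
+def print_aten_ops():
+    """Optional, illustrative helper (not required by the harness).
+
+    Uses torch.profiler, as in the README's "Finding operator names"
+    section, to confirm that aten::mm shows up in forward and backward.
+    """
+    from torch.profiler import ProfilerActivity, profile
+
+    model = SmokeTestModel()
+    x = torch.randn(4, 128, requires_grad=True)
+    with profile(activities=[ProfilerActivity.CPU]) as prof:
+        model(x).sum().backward()
+    for event in prof.key_averages():
+        if "aten::mm" in event.key:
+            print(event.key)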
"aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], + "backward": [ + "aten.convolution_backward.default", + "aten.native_group_norm_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d_backward.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", + "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", + "large_input": "([], {'x': T([2, 3, 128, 128], f32)})" + }, + "metadata": { + "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" + } +} \ No newline at end of file diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py new file mode 100644 index 00000000..410e4c4f --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +CNN model that triggers core PyTorch backward operators: +- convolution_backward +- native_group_norm_backward +- max_pool2d_with_indices_backward +- avg_pool2d_backward +- _adaptive_avg_pool2d_backward +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ToyCoreOpsModel(nn.Module): + """CNN that uses conv, group norm, max pool, avg pool, and adaptive avg pool.""" + + def __init__( + self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + ): + super().__init__() + + if hidden_channels % num_groups != 0: + raise ValueError( + f"hidden_channels ({hidden_channels}) must be divisible by " + f"num_groups ({num_groups})" + ) + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_groups = num_groups + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm1 = nn.GroupNorm(num_groups, hidden_channels) + self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm2 = nn.GroupNorm(num_groups, hidden_channels) + self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through: Conv->GroupNorm->ReLU->MaxPool->Conv-> + GroupNorm->ReLU->AvgPool->AdaptiveAvgPool->Conv + Output is always (batch, out_channels, 4, 4) regardless of + input size. 
+ """ + x = F.relu(self.group_norm1(self.conv1(x))) + x, _ = F.max_pool2d(x, kernel_size=2, return_indices=True) + x = F.relu(self.group_norm2(self.conv2(x))) + x = F.avg_pool2d(x, kernel_size=2) + x = F.adaptive_avg_pool2d(x, output_size=(4, 4)) + x = self.conv_out(x) + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = ToyCoreOpsModel(in_channels=3, hidden_channels=32, out_channels=8, num_groups=8) + batch_size = 2 + input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + return model + + +if __name__ == "__main__": + main()