Skip to content

Commit daebfc1

Browse files
committed
[ModelSuite] Add Toy Models
Summary: Here we introduce model suite (model.py). The idea here to start and codify the ideas from jiannanWang/BackendBenchExamples. Specifically this PR adds some example models / configs which are to be loaded + a Readme. (It may be useful to look at the PR above this as well since it's the model loading logic). This PR adds two toy models to model suite SmokeTestModel - This is simple model that uses aten.ops.mm as we can implement a correct version of this op ToyCoreOpsModel - This is a model which explicitly calls the backwards passes which are both in torchbench + core. Test Plan: the test infra is in the pr above, so tests passing on the PR above should be sufficient here ### Future work with Model Suite #181
1 parent 65b7c1a commit daebfc1

File tree

5 files changed

+294
-0
lines changed

5 files changed

+294
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Adding Models to BackendBench
2+
3+
## Quick Start
4+
5+
Models define operator lists and validate that custom backends work correctly in full model execution. Two files required:
6+
7+
```
8+
BackendBench/suite/models/YourModel/
9+
├── YourModel.py # nn.Module class
10+
└── YourModel.json # Configuration
11+
```
12+
13+
**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive)
14+
15+
## Adding a Model
16+
17+
### 1. Create Directory and Files
18+
19+
```bash
20+
cd BackendBench/suite/models
21+
mkdir MyModel
22+
cd MyModel
23+
touch MyModel.py MyModel.json
24+
```
25+
26+
### 2. Write Model Class (`MyModel.py`)
27+
28+
**Requirements:**
29+
- Class name = filename (exact match)
30+
- All `__init__` params need defaults
31+
- Add a main() / runner if you are inclined for sanity checking
32+
- Keep it simple - focus on specific operators you're testing
33+
- Look in this directory for examples
34+
35+
### 3. Write Config (`MyModel.json`)
36+
37+
**Key Fields:**
38+
- `model_config.init_args` - Args for `__init__()`, must match your defaults
39+
- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten.<op>.default"`)
40+
- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench)
41+
- Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc.
42+
- `metadata.description` - What this model tests
43+
- Look in this directory for examples
44+
45+
**Finding operator names:**
46+
```python
47+
from torch.profiler import profile, ProfilerActivity
48+
49+
with profile(activities=[ProfilerActivity.CPU]) as prof:
50+
output = model(x)
51+
loss = output.sum()
52+
loss.backward()
53+
54+
for event in prof.key_averages():
55+
if "aten::" in event.key:
56+
print(event.key)
57+
```
58+
59+
### 4. Test Your Model
60+
61+
```bash
62+
# Test standalone
63+
cd BackendBench/suite/models/MyModel
64+
python MyModel.py # Add main() for standalone testing
65+
66+
# Test with suite
67+
python -m BackendBench.scripts.main \
68+
--suite model \
69+
--backend aten \
70+
--model-filter MyModel
71+
72+
# Expected output:
73+
# Model: MyModel
74+
# Status: ✓ Passed (2/2 tests)
75+
# ✓ small
76+
# ✓ large
77+
```
78+
79+
### 5: Validation
80+
`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"model_config": {
3+
"init_args": {
4+
"input_dim": 128,
5+
"hidden_dim": 128,
6+
"output_dim": 128
7+
}
8+
},
9+
"ops": {
10+
"forward": [
11+
"aten.mm.default"
12+
],
13+
"backward": [
14+
"aten.mm.default"
15+
]
16+
},
17+
"model_tests": {
18+
"small_batch": "([], {'x': T([2, 128], f32)})",
19+
"medium_batch": "([], {'x': T([16, 128], f32)})",
20+
"large_batch": "([], {'x': T([32, 128], f32)})"
21+
},
22+
"metadata": {
23+
"description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes"
24+
}
25+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Simple model that tests matrix multiplication operations using explicit
9+
torch.mm calls.
10+
"""
11+
12+
import torch
13+
import torch.nn as nn
14+
15+
16+
class SmokeTestModel(nn.Module):
17+
"""
18+
Model that uses explicit torch.mm operations to test aten.mm.default
19+
in forward/backward.
20+
"""
21+
22+
def __init__(
23+
self,
24+
input_dim: int = 128,
25+
hidden_dim: int = 128,
26+
output_dim: int = 128,
27+
):
28+
super().__init__()
29+
self.input_dim = input_dim
30+
self.hidden_dim = hidden_dim
31+
self.output_dim = output_dim
32+
33+
self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim))
34+
self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim))
35+
self.bias1 = nn.Parameter(torch.randn(hidden_dim))
36+
self.bias2 = nn.Parameter(torch.randn(output_dim))
37+
38+
def forward(self, x: torch.Tensor) -> torch.Tensor:
39+
"""
40+
Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2)
41+
"""
42+
x = torch.mm(x, self.weight1) + self.bias1
43+
x = torch.relu(x)
44+
x = torch.mm(x, self.weight2) + self.bias2
45+
return x
46+
47+
48+
def main():
49+
"""Demonstrate the model with a forward/backward pass."""
50+
model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128)
51+
batch_size = 4
52+
input_tensor = torch.randn(batch_size, 128, requires_grad=True)
53+
54+
model.train()
55+
output = model(input_tensor)
56+
loss = output.sum()
57+
loss.backward()
58+
59+
print("✓ Forward/backward pass completed")
60+
print(f" Parameters: {sum(p.numel() for p in model.parameters())}")
61+
print(f" Input: {input_tensor.shape} -> Output: {output.shape}")
62+
grad_count = sum(1 for p in model.parameters() if p.grad is not None)
63+
total_params = len(list(model.parameters()))
64+
print(f" Gradients computed: {grad_count}/{total_params}")
65+
66+
67+
if __name__ == "__main__":
68+
main()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"model_config": {
3+
"init_args": {
4+
"in_channels": 3,
5+
"hidden_channels": 32,
6+
"out_channels": 8,
7+
"num_groups": 8
8+
}
9+
},
10+
"ops": {
11+
"forward": [
12+
"aten.convolution.default",
13+
"aten.native_group_norm.default",
14+
"aten.max_pool2d_with_indices.default",
15+
"aten.avg_pool2d.default",
16+
"aten._adaptive_avg_pool2d.default"
17+
],
18+
"backward": [
19+
"aten.convolution_backward.default",
20+
"aten.native_group_norm_backward.default",
21+
"aten.max_pool2d_with_indices_backward.default",
22+
"aten.avg_pool2d_backward.default",
23+
"aten._adaptive_avg_pool2d_backward.default"
24+
]
25+
},
26+
"model_tests": {
27+
"small_batch": "([], {'x': T([2, 3, 32, 32], f32)})",
28+
"medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})",
29+
"large_input": "([], {'x': T([2, 3, 128, 128], f32)})"
30+
},
31+
"metadata": {
32+
"description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool"
33+
}
34+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
CNN model that triggers core PyTorch backward operators:
9+
- convolution_backward
10+
- native_group_norm_backward
11+
- max_pool2d_with_indices_backward
12+
- avg_pool2d_backward
13+
- _adaptive_avg_pool2d_backward
14+
"""
15+
16+
import torch
17+
import torch.nn as nn
18+
import torch.nn.functional as F
19+
20+
21+
class ToyCoreOpsModel(nn.Module):
22+
"""CNN that uses conv, group norm, max pool, avg pool, and adaptive avg pool."""
23+
24+
def __init__(
25+
self,
26+
in_channels: int = 3,
27+
hidden_channels: int = 32,
28+
out_channels: int = 8,
29+
num_groups: int = 8,
30+
):
31+
super().__init__()
32+
33+
if hidden_channels % num_groups != 0:
34+
raise ValueError(
35+
f"hidden_channels ({hidden_channels}) must be divisible by "
36+
f"num_groups ({num_groups})"
37+
)
38+
39+
self.in_channels = in_channels
40+
self.hidden_channels = hidden_channels
41+
self.out_channels = out_channels
42+
self.num_groups = num_groups
43+
44+
self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
45+
self.group_norm1 = nn.GroupNorm(num_groups, hidden_channels)
46+
self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
47+
self.group_norm2 = nn.GroupNorm(num_groups, hidden_channels)
48+
self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
49+
50+
def forward(self, x: torch.Tensor) -> torch.Tensor:
51+
"""
52+
Forward pass through: Conv->GroupNorm->ReLU->MaxPool->Conv->
53+
GroupNorm->ReLU->AvgPool->AdaptiveAvgPool->Conv
54+
Output is always (batch, out_channels, 4, 4) regardless of
55+
input size.
56+
"""
57+
x = F.relu(self.group_norm1(self.conv1(x)))
58+
x, _ = F.max_pool2d(x, kernel_size=2, return_indices=True)
59+
x = F.relu(self.group_norm2(self.conv2(x)))
60+
x = F.avg_pool2d(x, kernel_size=2)
61+
x = F.adaptive_avg_pool2d(x, output_size=(4, 4))
62+
x = self.conv_out(x)
63+
return x
64+
65+
66+
def main():
67+
"""Demonstrate the model with a forward/backward pass."""
68+
model = ToyCoreOpsModel(in_channels=3, hidden_channels=32, out_channels=8, num_groups=8)
69+
batch_size = 2
70+
input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True)
71+
72+
model.train()
73+
output = model(input_tensor)
74+
loss = output.sum()
75+
loss.backward()
76+
77+
print("✓ Forward/backward pass completed")
78+
print(f" Parameters: {sum(p.numel() for p in model.parameters())}")
79+
print(f" Input: {input_tensor.shape} -> Output: {output.shape}")
80+
grad_count = sum(1 for p in model.parameters() if p.grad is not None)
81+
total_params = len(list(model.parameters()))
82+
print(f" Gradients computed: {grad_count}/{total_params}")
83+
return model
84+
85+
86+
if __name__ == "__main__":
87+
main()

0 commit comments

Comments
 (0)