diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 479e5805..2240e2d1 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -19,6 +19,7 @@ from BackendBench.output import save_results from BackendBench.suite import ( FactoTestSuite, + ModelSuite, OpInfoTestSuite, SmokeTestSuite, TorchBenchTestSuite, @@ -50,7 +51,7 @@ def setup_logging(log_level): @click.option( "--suite", default="smoke", - type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]), + type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]), help="Which suite to run", ) @click.option( @@ -63,7 +64,13 @@ def setup_logging(log_level): "--ops", default=None, type=str, - help="Comma-separated list of ops to run", + help="Comma-separated list of ops to run (not supported for model suite)", +) +@click.option( + "--model-filter", + default=None, + type=str, + help="Comma-separated list of models to run (only for model suite)", ) @click.option( "--topn-inputs", @@ -147,6 +154,7 @@ def cli( suite, backend, ops, + model_filter, topn_inputs, llm_attempts, llm_model, @@ -166,9 +174,22 @@ def cli( if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") + if suite == "model": + if ops is not None: + raise ValueError( + "--ops filter is not supported for model suite. Use --model-filter instead" + ) + # remove this in later PR as model suite is supported + raise NotImplementedError("Model suite is not supported yet") + + if suite != "model" and model_filter is not None: + raise ValueError("--model-filter is only supported for model suite") + setup_logging(log_level) if ops: ops = ops.split(",") + if model_filter: + model_filter = model_filter.split(",") suite = { "smoke": lambda: SmokeTestSuite, @@ -191,6 +212,7 @@ def cli( torch.bfloat16, filter=ops, ), + "model": lambda: ModelSuite(filter=model_filter), }[suite]() backend_name = backend diff --git a/BackendBench/suite/__init__.py b/BackendBench/suite/__init__.py index 410a5d6e..e9c332a9 100644 --- a/BackendBench/suite/__init__.py +++ b/BackendBench/suite/__init__.py @@ -15,6 +15,7 @@ from .base import OpTest, Test, TestSuite from .facto import FactoTestSuite +from .model import ModelSuite from .opinfo import OpInfoTestSuite from .smoke import randn, SmokeTestSuite from .torchbench import TorchBenchOpTest, TorchBenchTestSuite @@ -24,6 +25,7 @@ "OpTest", "TestSuite", "FactoTestSuite", + "ModelSuite", "OpInfoTestSuite", "SmokeTestSuite", "randn", diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py new file mode 100644 index 00000000..12f98256 --- /dev/null +++ b/BackendBench/suite/model.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Model Suite for testing models defined in configs. +""" + +import importlib.util +import json +import logging +import os +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +def load_models( + models_dir: str = "models", filter: Optional[List[str]] = None +) -> List[Dict[str, Any]]: + """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json + + Args: + models_dir: Directory containing models (default: "models") + filter: Optional list of model names to load. If None, loads all models. + + Returns: + List of dictionaries with keys: + - name: Model name (str) + - class: Model class (type) + - config: Configuration dictionary from JSON file + """ + models = [] + + if not os.path.exists(models_dir): + raise FileNotFoundError(f"Models directory not found: {models_dir}") + + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + # Skip if not in filter + if filter is not None and model_name not in filter: + continue + + # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json + model_file = os.path.join(model_dir, f"{model_name}.py") + config_file = os.path.join(model_dir, f"{model_name}.json") + + # Check both files exist + if not os.path.exists(model_file): + raise FileNotFoundError(f"Model file not found: {model_file}") + + if not os.path.exists(config_file): + raise FileNotFoundError(f"Config file not found: {config_file}") + + try: + # Load config + with open(config_file, "r") as f: + config = json.load(f) + + # Load model class dynamically + spec = importlib.util.spec_from_file_location(model_name, model_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find model class (must match model_name exactly) + if not hasattr(module, model_name): + raise RuntimeError(f"Model class '{model_name}' not found in {model_file}") + + model_class = getattr(module, model_name) + if not (isinstance(model_class, type) and hasattr(model_class, "forward")): + raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class") + + models.append({"name": model_name, "class": model_class, "config": config}) + logger.info(f"Loaded model: {model_name}") + + except Exception as e: + raise RuntimeError(f"Failed to load model {model_name}: {e}") + + if filter is not None and len(models) == 0: + raise ValueError(f"No models found matching filter: {filter}") + + return models + + +class ModelSuite: + """Model Suite for end-to-end model testing.""" + + def __init__( + self, + name: str = "model", + filter: Optional[List[str]] = None, + ): + """Initialize ModelSuite. + + Args: + name: Suite name (default: "model") + filter: Optional list of model names to load + """ + models_dir = os.path.join(os.path.dirname(__file__), "models") + + # Load models + models = load_models(models_dir=models_dir, filter=filter) + logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") + + # Store loaded models + self.models = models + self.name = name diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md new file mode 100644 index 00000000..57e707dc --- /dev/null +++ b/BackendBench/suite/models/README.md @@ -0,0 +1,80 @@ +# Adding Models to BackendBench + +## Quick Start + +Models define operator lists and validate that custom backends work correctly in full model execution. Two files required: + +``` +BackendBench/suite/models/YourModel/ +├── YourModel.py # nn.Module class +└── YourModel.json # Configuration +``` + +**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive) + +## Adding a Model + +### 1. Create Directory and Files + +```bash +cd BackendBench/suite/models +mkdir MyModel +cd MyModel +touch MyModel.py MyModel.json +``` + +### 2. Write Model Class (`MyModel.py`) + +**Requirements:** +- Class name = filename (exact match) +- All `__init__` params need defaults +- Add a main() / runner if you are inclined for sanity checking +- Keep it simple - focus on specific operators you're testing +- Look in this directory for examples + +### 3. Write Config (`MyModel.json`) + +**Key Fields:** +- `model_config.init_args` - Args for `__init__()`, must match your defaults +- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten..default"`) +- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench) + - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc. +- `metadata.description` - What this model tests +- Look in this directory for examples + +**Finding operator names:** +```python +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CPU]) as prof: + output = model(x) + loss = output.sum() + loss.backward() + +for event in prof.key_averages(): + if "aten::" in event.key: + print(event.key) +``` + +### 4. Test Your Model + +```bash +# Test standalone +cd BackendBench/suite/models/MyModel +python MyModel.py # Add main() for standalone testing + +# Test with suite +python -m BackendBench.scripts.main \ + --suite model \ + --backend aten \ + --model-filter MyModel + +# Expected output: +# Model: MyModel +# Status: ✓ Passed (2/2 tests) +# ✓ small +# ✓ large +``` + +### 5: Validation +`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly. diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json new file mode 100644 index 00000000..b7d286ae --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json @@ -0,0 +1,25 @@ +{ + "model_config": { + "init_args": { + "input_dim": 128, + "hidden_dim": 128, + "output_dim": 128 + } + }, + "ops": { + "forward": [ + "aten.mm.default" + ], + "backward": [ + "aten.mm.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 128], f32)})", + "medium_batch": "([], {'x': T([16, 128], f32)})", + "large_batch": "([], {'x': T([32, 128], f32)})" + }, + "metadata": { + "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes" + } +} diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py new file mode 100644 index 00000000..3bf627e4 --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple model that tests matrix multiplication operations using explicit +torch.mm calls. +""" + +import torch +import torch.nn as nn + + +class SmokeTestModel(nn.Module): + """ + Model that uses explicit torch.mm operations to test aten.mm.default + in forward/backward. + """ + + def __init__( + self, + input_dim: int = 128, + hidden_dim: int = 128, + output_dim: int = 128, + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim)) + self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim)) + self.bias1 = nn.Parameter(torch.randn(hidden_dim)) + self.bias2 = nn.Parameter(torch.randn(output_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2) + """ + x = torch.mm(x, self.weight1) + self.bias1 + x = torch.relu(x) + x = torch.mm(x, self.weight2) + self.bias2 + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128) + batch_size = 4 + input_tensor = torch.randn(batch_size, 128, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + + +if __name__ == "__main__": + main() diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json new file mode 100644 index 00000000..1586273e --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -0,0 +1,34 @@ +{ + "model_config": { + "init_args": { + "in_channels": 3, + "hidden_channels": 32, + "out_channels": 8, + "num_groups": 8 + } + }, + "ops": { + "forward": [ + "aten.convolution.default", + "aten.native_group_norm.default", + "aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], + "backward": [ + "aten.convolution_backward.default", + "aten.native_group_norm_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d_backward.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", + "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", + "large_input": "([], {'x': T([2, 3, 128, 128], f32)})" + }, + "metadata": { + "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" + } +} \ No newline at end of file diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py new file mode 100644 index 00000000..410e4c4f --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +CNN model that triggers core PyTorch backward operators: +- convolution_backward +- native_group_norm_backward +- max_pool2d_with_indices_backward +- avg_pool2d_backward +- _adaptive_avg_pool2d_backward +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ToyCoreOpsModel(nn.Module): + """CNN that uses conv, group norm, max pool, avg pool, and adaptive avg pool.""" + + def __init__( + self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + ): + super().__init__() + + if hidden_channels % num_groups != 0: + raise ValueError( + f"hidden_channels ({hidden_channels}) must be divisible by " + f"num_groups ({num_groups})" + ) + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_groups = num_groups + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm1 = nn.GroupNorm(num_groups, hidden_channels) + self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm2 = nn.GroupNorm(num_groups, hidden_channels) + self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through: Conv->GroupNorm->ReLU->MaxPool->Conv-> + GroupNorm->ReLU->AvgPool->AdaptiveAvgPool->Conv + Output is always (batch, out_channels, 4, 4) regardless of + input size. + """ + x = F.relu(self.group_norm1(self.conv1(x))) + x, _ = F.max_pool2d(x, kernel_size=2, return_indices=True) + x = F.relu(self.group_norm2(self.conv2(x))) + x = F.avg_pool2d(x, kernel_size=2) + x = F.adaptive_avg_pool2d(x, output_size=(4, 4)) + x = self.conv_out(x) + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = ToyCoreOpsModel(in_channels=3, hidden_channels=32, out_channels=8, num_groups=8) + batch_size = 2 + input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + return model + + +if __name__ == "__main__": + main() diff --git a/test/test_model_ops_configs.py b/test/test_model_ops_configs.py new file mode 100644 index 00000000..8b3f3dbe --- /dev/null +++ b/test/test_model_ops_configs.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that ModelSuite's operator filter correctly matches +the operators defined in model configs. + +This test validates that: +1. load_models correctly loads model configs from the models directory +2. load_model_ops extracts the correct set of operators from model configs +3. TorchBenchTestSuite initialized with those operators has matching optests +4. JSON config files have proper format with required fields +""" + +import json +import os +import unittest +from typing import Any, Dict, List, Set + +from BackendBench.suite.model import load_models +from BackendBench.suite.torchbench import TorchBenchTestSuite + + +def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: + """Extract unique set of operators from model configs. + + Args: + models: List of model dictionaries with 'name', 'class', and 'config' keys + + Returns: + Set of operator names defined across all model configs + """ + model_ops = set() + for model in models: + config_ops = model["config"].get("ops") + if not config_ops: + raise ValueError(f"Model {model['name']} has no 'ops' field in config") + assert "forward" in config_ops, f"Model {model['name']} has no 'forward' field in config" + assert "backward" in config_ops, f"Model {model['name']} has no 'backward' field in config" + ops_list = config_ops["forward"] + config_ops["backward"] + + model_ops.update(ops_list) + return model_ops + + +class TestModelOpsConfigs(unittest.TestCase): + """Test that model ops filter correctly initializes TorchBenchTestSuite.""" + + def test_model_ops_match_suite_optests(self): + """Test that suite's optests match the operators from model configs.""" + # Get the models directory path (same as ModelSuite does) + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load models using load_models + models = load_models(models_dir=models_dir) + + # Verify we loaded at least one model + self.assertGreater(len(models), 0, "Should load at least one model") + + # Extract operators from model configs using load_model_ops + model_ops = load_model_ops(models) + + # Verify we have operators + self.assertGreater(len(model_ops), 0, "Should have at least one operator") + + # Create filter list from model ops + ops_filter = list(model_ops) + + # Initialize TorchBenchTestSuite with the filter + suite = TorchBenchTestSuite( + name="test_model_ops", + filename=None, # Use default HuggingFace dataset + filter=ops_filter, + topn=None, + ) + + # Get the set of operators in the suite's optests + suite_ops = set(suite.optests.keys()) + + # The suite_ops should be a subset of model_ops because: + # - model_ops is the filter we requested + # - suite_ops contains only those operators that exist in the TorchBench dataset + # - Not all operators in model configs may be in the dataset + self.assertTrue( + suite_ops.issubset(model_ops), + f"Suite operators {suite_ops} should be subset of model operators {model_ops}", + ) + + # Verify that suite actually has some operators + self.assertGreater( + len(suite_ops), 0, "Suite should contain at least one operator from model configs" + ) + + def test_json_configs_have_required_fields(self): + """Test that all JSON config files have proper format with required fields.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load all models + models = load_models(models_dir=models_dir) + + for model in models: + model_name = model["name"] + config = model["config"] + + # Check required top-level fields + self.assertIn("ops", config, f"Model {model_name}: config must have 'ops' field") + self.assertIn( + "model_tests", config, f"Model {model_name}: config must have 'model_tests' field" + ) + + # Validate 'ops' field - can be list or dict + config_ops = config["ops"] + self.assertGreater( + len(config_ops["forward"] + config_ops["backward"]), + 0, + f"Model {model_name}: 'ops' list must not be empty", + ) + for op in config_ops["forward"] + config_ops["backward"]: + self.assertIsInstance( + op, str, f"Model {model_name}: each op in 'ops' must be a string" + ) + self.assertIsInstance( + config_ops["forward"], + list, + f"Model {model_name}: 'ops.forward' must be a list", + ) + for op in config_ops["forward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.forward' must be a string", + ) + self.assertIsInstance( + config_ops["backward"], + list, + f"Model {model_name}: 'ops.backward' must be a list", + ) + for op in config_ops["backward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.backward' must be a string", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + for test_name, test_args in config["model_tests"].items(): + self.assertIsInstance( + test_name, str, f"Model {model_name}: test names must be strings" + ) + self.assertIsInstance( + test_args, str, f"Model {model_name}: test args must be strings" + ) + + # Check optional but recommended fields + if "model_config" in config: + self.assertIsInstance( + config["model_config"], + dict, + f"Model {model_name}: 'model_config' must be a dictionary if present", + ) + + def test_json_files_are_valid_json(self): + """Test that all JSON config files are valid JSON and can be parsed.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Find all JSON files in the models directory + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + json_file = os.path.join(model_dir, f"{model_name}.json") + if not os.path.exists(json_file): + continue + + # Try to parse the JSON file + with open(json_file, "r") as f: + try: + config = json.load(f) + self.assertIsInstance( + config, + dict, + f"JSON file {json_file} must contain a dictionary at top level", + ) + except json.JSONDecodeError as e: + self.fail(f"JSON file {json_file} is not valid JSON: {e}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_suite.py b/test/test_model_suite.py new file mode 100644 index 00000000..12ddaf1b --- /dev/null +++ b/test/test_model_suite.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for Model Suite: Filtered TorchBench operators from model tracing + +This test suite validates: +1. Model loading from toy_models directory +2. Operator extraction via model tracing +3. ModelSuite creates filtered TorchBench suite +""" + +import logging +import unittest + +from BackendBench.suite.model import load_models + +# Setup logging +logging.basicConfig(level=logging.WARNING) + + +class TestModelLoading(unittest.TestCase): + """Test model loading functionality.""" + + def test_load_models(self): + """Test that models can be loaded from directory.""" + models = load_models(models_dir="BackendBench/suite/models") + self.assertGreater(len(models), 0, "Should load at least one model") + + # Verify model structure + for model in models: + self.assertIn("name", model) + self.assertIn("class", model) + self.assertIn("config", model) + + def test_load_specific_model(self): + """Test loading a specific model by name.""" + models = load_models(models_dir="BackendBench/suite/models", filter=["ToyCoreOpsModel"]) + self.assertEqual(len(models), 1) + self.assertEqual(models[0]["name"], "ToyCoreOpsModel") + + def test_invalid_filter(self): + """Test that invalid filter raises error.""" + with self.assertRaises(ValueError): + load_models(models_dir="BackendBench/suite/models", filter=["nonexistent"]) + + +if __name__ == "__main__": + # Run tests + unittest.main(verbosity=2)