Add model backend, add tests #174
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Model-level evaluation utilities for testing full model correctness."""

import logging
import random
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch

import BackendBench
from BackendBench.eval import allclose
from BackendBench.utils import deserialize_args

logger = logging.getLogger(__name__)


@dataclass
class ModelCorrectnessTestResult:
    """Result from testing a model configuration."""

    model_name: str
    test_name: str
    is_correct: bool = False
    error_msg: str = ""
    error_type: str = ""
    traceback: str = ""
    output_match: bool = False
    gradients_match: bool = False
    num_gradients: int = 0


def eval_model_correctness_test(
    model_name: str,
    model_class: type,
    model_config: Dict[str, Any],
    test_name: str,
    test_args: str,
    kernel_dir: str = None,
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> ModelCorrectnessTestResult:
    """Evaluate model correctness by comparing eager vs backend execution.

    Similar to eval_correctness_test in eval.py, but for full models instead of individual ops.

    Args:
        model_name: Name of the model being tested
        model_class: Model class to instantiate
        model_config: Model configuration dict with init_args
        test_name: Name of this test configuration
        test_args: Serialized arguments string for the forward pass
        kernel_dir: Optional directory containing kernels for the backend
        atol: Absolute tolerance for allclose
        rtol: Relative tolerance for allclose

    Returns:
        ModelCorrectnessTestResult with detailed comparison results
    """
    try:
        # Generate a single seed to use for both eager and backend runs.
        # This ensures both runs use the same model initialization.
        seed = random.randint(0, 2**32 - 1)

        # Run in eager mode (reference)
        eager_out, eager_grads = _run_model(
            model_class,
            model_config,
            test_args,
            backend_enabled=False,
            kernel_dir=kernel_dir,
            seed=seed,
        )

        # Run with backend (implementation)
        backend_out, backend_grads = _run_model(
            model_class,
            model_config,
            test_args,
            backend_enabled=True,
            kernel_dir=kernel_dir,
            seed=seed,
        )

        # Compare outputs
        output_match = allclose(eager_out, backend_out, atol=atol, rtol=rtol)

        # Compare gradients
        gradients_match = True
        if len(eager_grads) != len(backend_grads):
            gradients_match = False
        else:
            for eager_grad, backend_grad in zip(eager_grads, backend_grads):
                if not allclose(eager_grad, backend_grad, atol=atol, rtol=rtol):
                    gradients_match = False
                    break

        is_correct = output_match and gradients_match

        return ModelCorrectnessTestResult(
            model_name=model_name,
            test_name=test_name,
            is_correct=is_correct,
            output_match=output_match,
            gradients_match=gradients_match,
            num_gradients=len(eager_grads),
        )

    except Exception as e:
        error_msg = f"Model {model_name}::{test_name} failed: {e}"
        logger.error(error_msg)
        return ModelCorrectnessTestResult(
            model_name=model_name,
            test_name=test_name,
            is_correct=False,
            error_msg=error_msg,
            error_type=str(type(e)),
            traceback=traceback.format_exc(),
        )


def _get_input_tensor(args: List[Any], kwargs: Dict[str, Any]) -> torch.Tensor:
    """Extract the input tensor from args or kwargs.

    Args:
        args: Positional arguments list
        kwargs: Keyword arguments dict

    Returns:
        Input tensor if found, None otherwise
    """
    if args and isinstance(args[0], torch.Tensor):
        return args[0]
    elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor):
        return kwargs["x"]
    return None
def _move_model_to_input_device(
    model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
) -> torch.nn.Module:
    """Move the model to the same device as the input tensor.

    Args:
        model: Model to move
        args: Positional arguments list
        kwargs: Keyword arguments dict

    Returns:
        Model on the input device (or the original model if no input tensor is found)
    """
    input_tensor = _get_input_tensor(args, kwargs)
    if input_tensor is not None:
        device = input_tensor.device
        model = model.to(device)
    return model


def _collect_gradients(
    model: torch.nn.Module, args: List[Any], kwargs: Dict[str, Any]
) -> List[torch.Tensor]:
    """Collect gradients from the input and model parameters.

    Args:
        model: Model with computed gradients
        args: Positional arguments list
        kwargs: Keyword arguments dict

    Returns:
        List of gradient tensors [input_grad, param1_grad, ...]
    """
    grads = []

    # Input gradient - check both args and kwargs
    input_grad = None
    if args and isinstance(args[0], torch.Tensor) and args[0].grad is not None:
        input_grad = args[0].grad
    elif "x" in kwargs and isinstance(kwargs["x"], torch.Tensor) and kwargs["x"].grad is not None:
        input_grad = kwargs["x"].grad

    if input_grad is not None:
        grads.append(input_grad.clone())

    # Parameter gradients
    for param in model.parameters():
        if param.grad is not None:
            grads.append(param.grad.clone())

    return grads


def _run_model(
    model_class: type,
    model_config: Dict[str, Any],
    test_args: str,
    backend_enabled: bool,
    kernel_dir: str = "generated_kernels",
    seed: int = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Run the model with or without the backend enabled.

    Args:
        model_class: Model class to instantiate
        model_config: Model configuration dict with init_args
        test_args: Serialized arguments string for the forward pass
        backend_enabled: If True, use the BackendBench context manager
        kernel_dir: Optional directory containing kernels
        seed: Random seed for reproducibility. If None, generates a random seed.

    Returns:
        Tuple of (output, gradients) where:
            - output: Model output tensor (detached)
            - gradients: List of gradient tensors [input_grad, param1_grad, ...]
    """
    # Generate the seed dynamically and set it for deterministic behavior.
    # IMPORTANT: Must set the seed BEFORE deserializing args, because deserialization
    # may create random tensors!
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)

    # Deserialize test arguments (uses the seed we just set)
    args, kwargs = deserialize_args(test_args)

    # Extract model initialization args
    init_args = model_config.get("init_args", {}).copy()

    # Create a fresh model instance
    model = model_class(**init_args)
    model.train()

    # Move the model to the same device as the input
    model = _move_model_to_input_device(model, args, kwargs)

    # Run forward + backward with or without the backend
    if backend_enabled:
        with BackendBench.BackendBench.enable(kernel_dir=kernel_dir):
            output = model(*args, **kwargs)
            loss = output.sum()
Contributor: In my experience this might be troublesome if the add or sum kernels are incorrect or do not consider all input dtypes.

Author: I think that should be fine at the moment, as that's how directorybench is structured. If we change directorybench to work with overloads, this should still work, since ops are defined by their overload in the configs.
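As a minimal, self-contained illustration of the concern above (the toy model, names, and shapes are hypothetical, not from this PR): seeding `backward()` with an explicit grad_output is mathematically equivalent to `output.sum().backward()`, but the reduction is never dispatched, so a buggy backend `sum` kernel cannot mask an output or gradient mismatch.

```python
import torch

# Toy setup; purely illustrative, not part of this PR.
model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

output = model(x)

# Equivalent to output.sum().backward(), but no sum op is dispatched.
output.backward(torch.ones_like(output))

print(x.grad.shape, model.weight.grad.shape)
```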
            loss.backward()
    else:
        # Run in eager mode (no backend)
        output = model(*args, **kwargs)
        loss = output.sum()
        loss.backward()

    # Collect gradients
    grads = _collect_gradients(model, args, kwargs)

    return output.detach(), grads
Contributor: I can see that this function tests the forward pass (output) and the backward pass (gradients). What's missing is the parameter update (optim.step). Is there a reason not to test that? For example, is it out of scope for our kernel registration?

Author: Thinking about this a bit, I think it's out of scope for kernel registration: assuming the gradients and output are correct, the optimizer should behave the same, and when we add numerical correctness over training it should be accounted for.
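If parameter updates are ever brought into scope, one possible shape for that check is sketched below: take a single optimizer step in both the eager and backend runs and compare the resulting parameters with the same `allclose` tolerances used for outputs and gradients. The helper name, the SGD choice, and the learning rate are all hypothetical, not part of this PR.

```python
from typing import List

import torch


def step_and_snapshot(model: torch.nn.Module, lr: float = 0.1) -> List[torch.Tensor]:
    """Apply one SGD step to a model whose gradients are already populated
    and return detached copies of its parameters for later comparison."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    optimizer.step()
    return [p.detach().clone() for p in model.parameters()]
```

Each run would call this right after `loss.backward()`, and the two parameter snapshots would then be compared the same way the outputs and gradients are compared above.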
Contributor: Do we need kernel_dir here, given that backend_enabled is False?

Author: Fair, it's a bit bug-prone.
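One way to make that path less bug-prone, sketched under the assumption that it would replace the if/else at the end of `_run_model` and reuse its locals (`backend_enabled`, `kernel_dir`, `model`, `args`, `kwargs`): enter a null context when the backend is disabled, so `kernel_dir` is only consulted on the backend-enabled path. This is a sketch, not the code in this PR.

```python
import contextlib

ctx = (
    BackendBench.BackendBench.enable(kernel_dir=kernel_dir)
    if backend_enabled
    else contextlib.nullcontext()  # eager reference run never touches kernel_dir
)
with ctx:
    output = model(*args, **kwargs)
    loss = output.sum()
    loss.backward()
```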