diff --git a/BackendBench/kernel_templates.py b/BackendBench/kernel_templates.py
index b41cc13..07a39d1 100644
--- a/BackendBench/kernel_templates.py
+++ b/BackendBench/kernel_templates.py
@@ -32,6 +32,24 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
         """Create a prompt for kernel generation."""
         raise NotImplementedError
 
+    def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """
+        Create a prompt for backward (gradient) kernel generation.
+
+        Default implementation returns a conservative instruction that asks for a
+        backward kernel implementing gradients for the forward operation. Subclasses
+        should override to provide DSL-specific guidance and examples.
+        """
+        return (
+            f"Generate a backward (gradient) kernel implementation for the operation "
+            f"'{op_name}'.\n\nSignature: {op_signature}\n\nDescription: {op_description}\n\n"
+            "The backward kernel should accept gradient(s) of the outputs and return "
+            "gradients w.r.t. each input and any trainable parameters. Be explicit "
+            "about shapes and dtype handling. If trainable parameters exist, update "
+            "or accumulate their gradients in-place or follow the standard autograd "
+            "convention for the target DSL."
+        )
+
 
 class TritonKernelTemplate(KernelTemplate):
     """Template for Triton kernel generation."""
@@ -56,6 +74,29 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
             example=example,
         )
 
+    def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """Triton-specific backward kernel prompt using the same optimization hints."""
+        optimizations = self._get_optimizations(op_name)
+        example = self._get_example_template(op_name)
+
+        extra_prompt = (
+            "\n\n# NOTE: The code above should be adapted to implement gradients. "
+            "Provide a Triton kernel (or auxiliary kernels) that computes gradients "
+            "w.r.t. inputs and parameters given gradient(s) of the outputs. Declare "
+            "the expected gradient shapes and any in-place updates for parameter grads."
+        )
+
+        return (
+            TRITON_KERNEL_PROMPT.format(
+                op_name=op_name,
+                op_signature=op_signature,
+                op_description=op_description,
+                optimizations=optimizations,
+                example=example,
+            )
+            + extra_prompt
+        )
+
     def _get_optimizations(self, op_name: str) -> str:
         """Get operation-specific optimization guidelines."""
         return TRITON_OPTIMIZATIONS.get(op_name, TRITON_OPTIMIZATIONS["default"])
@@ -78,6 +119,21 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
             op_name=op_name, op_signature=op_signature, op_description=op_description
         )
 
+    def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """PyTorch-specific backward prompt: ask for autograd-friendly backward code."""
+        extra_prompt = (
+            "\n\n# BACKWARD: Provide a backward function (e.g., a Function.backward or "
+            "a gradient function) that computes gradients w.r.t. inputs and parameters. "
+            "Prefer returning gradients as Tensors in the same order as inputs."
+        )
+
+        return (
+            PYTORCH_KERNEL_PROMPT.format(
+                op_name=op_name, op_signature=op_signature, op_description=op_description
+            )
+            + extra_prompt
+        )
+
 
 class CuTeDSLKernelTemplate(KernelTemplate):
     """Template for CuTeDSL kernel generation."""
@@ -102,6 +158,26 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
             example=example,
         )
 
+    def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """CuTeDSL-specific backward prompt using CuTeDSL optimization hints."""
+        optimizations = self._get_optimizations(op_name)
+        example = self._get_example_template(op_name)
+
+        extra_prompt = (
+            "\n\n# BACKWARD: Provide gradient computation for the above forward operator."
+        )
+
+        return (
+            CUTEDSL_KERNEL_PROMPT.format(
+                op_name=op_name,
+                op_signature=op_signature,
+                op_description=op_description,
+                optimizations=optimizations,
+                example=example,
+            )
+            + extra_prompt
+        )
+
     def _get_optimizations(self, op_name: str) -> str:
         """Get operation-specific optimization guidelines."""
         return CUTEDSL_OPTIMIZATIONS.get(op_name, CUTEDSL_OPTIMIZATIONS["default"])
@@ -139,6 +215,13 @@ def create_prompt(
         template = self.get_template(dsl)
         return template.create_prompt(op_name, op_signature, op_description)
 
+    def create_backward_prompt(
+        self, op_name: str, op_signature: str, op_description: str, dsl: str = "triton"
+    ) -> str:
+        """Create a backward prompt using the specified template."""
+        template = self.get_template(dsl)
+        return template.create_backward_prompt(op_name, op_signature, op_description)
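+
+    # Illustrative usage of the new backward-prompt entry point. The enclosing manager
+    # class name is assumed here to be `KernelTemplateManager`, and the op name,
+    # signature, and description strings below are hypothetical placeholders:
+    #
+    #     manager = KernelTemplateManager()
+    #     prompt = manager.create_backward_prompt(
+    #         "relu", "relu(Tensor self) -> Tensor", "Applies ReLU elementwise", dsl="triton"
+    #     )
+    #     # `prompt` is then sent to the model that generates the backward kernel.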
+
     def create_refinement_prompt(
         self,
         op_name: str,
diff --git a/BackendBench/opregistry.py b/BackendBench/opregistry.py
index 7de3a06..35d2938 100644
--- a/BackendBench/opregistry.py
+++ b/BackendBench/opregistry.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+from typing import Any, Callable, Dict, Optional
 
 import torch
 
@@ -30,7 +31,7 @@ def _extract_spec_name_from_op(op_obj):
 
 class OpRegistry:
     def __init__(self):
-        self._registry = {}
+        self._registry: Dict[str, Any] = {}
 
     def get_operator(self, input_obj):
         if isinstance(input_obj, str):
@@ -41,7 +42,11 @@
     def _get_operator_from_spec_name(self, spec_name):
         # Return cached operator if available
         if spec_name in self._registry:
-            return self._registry[spec_name]
+            entry = self._registry[spec_name]
+            # If the entry is a kernel dict, return the forward callable for compatibility
+            if isinstance(entry, dict) and "forward" in entry:
+                return entry["forward"]
+            return entry
 
         # Parse spec name
         op_parts = spec_name.split(".")
@@ -67,7 +72,10 @@
         # Check if we already have this operator registered
         if spec_name in self._registry:
-            return self._registry[spec_name]
+            entry = self._registry[spec_name]
+            # If the entry is a kernel dict, return the forward callable for compatibility;
+            # otherwise return the cached legacy operator object as before.
+            if isinstance(entry, dict) and "forward" in entry:
+                return entry["forward"]
+            return entry
 
         # Register the provided operator object
         self._registry[spec_name] = op_obj
@@ -77,6 +85,39 @@
     def register_operator(self, op_obj):
         return self._get_operator_from_object(op_obj)
 
+    def register_kernel(
+        self,
+        spec_name: str,
+        forward: Callable,
+        *,
+        backward: Optional[Callable] = None,
+        param_update: Optional[Callable] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self._registry[spec_name] = {
+            "forward": forward,
+            "backward": backward,
+            "param_update": param_update,
+            "metadata": metadata or {},
+        }
+
+    def get_kernel(self, spec_name: str) -> Dict[str, Any]:
+        if spec_name not in self._registry:
+            raise KeyError(f"Operator {spec_name} is not registered")
+        entry = self._registry[spec_name]
+        if isinstance(entry, dict) and "forward" in entry:
+            return entry
+        # Legacy operator object -> wrap it as a forward-only kernel
+        return {"forward": entry, "backward": None, "param_update": None, "metadata": {}}
+
+    def has_backward(self, spec_name: str) -> bool:
+        entry = self._registry.get(spec_name)
+        if not entry:
+            return False
+        if isinstance(entry, dict):
+            return entry.get("backward") is not None
+        return False
+
     def get_all_registered_ops(self):
         return self._registry.copy()
 
@@ -106,5 +147,22 @@
 def register_operator(op_obj):
     return _op_registry.register_operator(op_obj)
 
-def get_registry():
-    return _op_registry
+def register_kernel(
+    spec_name: str,
+    forward: Callable,
+    *,
+    backward: Optional[Callable] = None,
+    param_update: Optional[Callable] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> None:
+    return _op_registry.register_kernel(
+        spec_name, forward, backward=backward, param_update=param_update, metadata=metadata
+    )
+
+
+def get_kernel(spec_name: str) -> Dict[str, Any]:
+    return _op_registry.get_kernel(spec_name)
+
+
+def has_backward(spec_name: str) -> bool:
+    return _op_registry.has_backward(spec_name)
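+
+
+# Illustrative usage of the kernel registry (the spec name and the relu_forward /
+# relu_backward callables are hypothetical; real entries come from generated kernels):
+#
+#     register_kernel("relu.default", relu_forward, backward=relu_backward)
+#     get_kernel("relu.default")["backward"]   # -> relu_backward
+#     has_backward("relu.default")             # -> True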
diff --git a/BackendBench/train.py b/BackendBench/train.py
new file mode 100644
index 0000000..e8b3226
--- /dev/null
+++ b/BackendBench/train.py
@@ -0,0 +1,232 @@
+import time
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Tuple, List
+
+import torch
+
+from BackendBench.opregistry import get_operator
+
+
+@dataclass
+class TrainingTestCase:
+    """Simple container for a single training test case."""
+
+    inputs: Tuple[Any, ...] = ()
+    target: Optional[torch.Tensor] = None
+    params: Optional[List[torch.Tensor]] = None  # parameters to update (if any)
+    loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
+
+
+@dataclass
+class TrainingTestSuite:
+    """Collection of training test cases for an operator."""
+
+    op: Any
+    training_tests: List[TrainingTestCase]
+
+
+def _mse_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    return torch.mean((output - target) ** 2)
+
+
+def _compute_numerical_grads(
+    func: Callable,
+    inputs: Tuple[torch.Tensor, ...],
+    target: torch.Tensor,
+    loss_fn: Callable,
+    eps: float = 1e-3,
+) -> List[Optional[torch.Tensor]]:
+    """Estimate per-input gradients of loss_fn(func(*inputs), target) via central differences."""
+    grads: List[Optional[torch.Tensor]] = []
+    for idx, inp in enumerate(inputs):
+        if not torch.is_tensor(inp) or inp.numel() == 0:
+            grads.append(None)
+            continue
+
+        inp = inp.detach()
+        base = inp.clone().reshape(-1)
+        grad_flat = torch.zeros_like(base)
+
+        for i in range(base.numel()):
+            orig = base[i].item()
+
+            base[i] = orig + eps
+            inp_plus = base.reshape(inp.shape).to(inp.device)
+            # Substitute the perturbed tensor by position; identity checks against `inp`
+            # are unreliable once it has been rebound by .detach().
+            inputs_plus = [
+                inp_plus if j == idx else (v.clone().detach() if torch.is_tensor(v) else v)
+                for j, v in enumerate(inputs)
+            ]
+            with torch.no_grad():
+                out_plus = func(*inputs_plus)
+            loss_plus = loss_fn(out_plus, target).item()
+
+            base[i] = orig - eps
+            inp_minus = base.reshape(inp.shape).to(inp.device)
+            inputs_minus = [
+                inp_minus if j == idx else (v.clone().detach() if torch.is_tensor(v) else v)
+                for j, v in enumerate(inputs)
+            ]
+            with torch.no_grad():
+                out_minus = func(*inputs_minus)
+            loss_minus = loss_fn(out_minus, target).item()
+
+            grad_flat[i] = (loss_plus - loss_minus) / (2 * eps)
+            base[i] = orig  # restore
+
+        grads.append(grad_flat.reshape(inp.shape))
+    return grads
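+
+
+# Sanity check for the central-difference estimate above,
+# dL/dx_i ~= (L(x + eps*e_i) - L(x - eps*e_i)) / (2 * eps).
+# For an identity "kernel" and the default MSE loss the analytic gradient is
+# 2 * (x - target) / n, so (illustrative only, not part of the module API):
+#
+#     x = torch.tensor([1.0, 2.0])
+#     t = torch.tensor([0.0, 0.0])
+#     _compute_numerical_grads(lambda v: v, (x,), t, _mse_loss)
+#     # -> approximately [tensor([1.0, 2.0])]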
+
+
+def train_one_op(
+    op: Any,
+    kernel_impl: Callable,
+    training_case: TrainingTestCase,
+    *,
+    lr: float = 1e-3,
+    num_steps: int = 1,
+    use_kernel_backward: bool = True,
+    reference_op: Optional[Any] = None,
+) -> Dict[str, Any]:
+    """
+    Run a small training loop for one op / kernel.
+
+    - op: operator descriptor (used for logging and reference-op lookup)
+    - kernel_impl: callable implementing the forward pass (and possibly backward)
+    - training_case: TrainingTestCase with inputs/target/params
+    - lr: learning rate for the plain SGD step applied to the cloned params/inputs
+    - num_steps: number of training steps to run (default 1)
+    - use_kernel_backward: whether to try the kernel's backward/autograd path first
+    - reference_op: optional reference operator (callable) used to compute reference
+      gradients via autograd
+
+    Returns a metrics dict:
+        {
+            "grad_correct": bool,
+            "grad_rel_error": float,
+            "step_time_ms": float,
+            "final_loss": float,
+        }
+    """
+    inputs = list(training_case.inputs)
+    target = training_case.target
+    params = training_case.params if training_case.params is not None else []
+    loss_fn = training_case.loss_fn if training_case.loss_fn is not None else _mse_loss
+
+    device = None
+    for t in inputs + params:
+        if torch.is_tensor(t):
+            device = t.device
+            break
+    if device is None:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    # Ensure tensors are detached copies on the chosen device that require grad
+    for i, v in enumerate(inputs):
+        if torch.is_tensor(v):
+            inputs[i] = v.detach().to(device).clone().requires_grad_(True)
+    for i, p in enumerate(params):
+        if torch.is_tensor(p):
+            params[i] = p.detach().to(device).clone().requires_grad_(True)
+
+    # Reference operator resolution
+    ref_op = reference_op
+    if ref_op is None:
+        try:
+            ref_op = get_operator(op)
+        except Exception:
+            ref_op = None
+
+    # Run one or more steps and measure time
+    t0 = time.time()
+    last_loss = None
+    grad_rel_error = 0.0
+    grad_correct = False
+
+    for _step in range(num_steps):
+        # Zero any existing grads
+        for t in inputs + params:
+            if torch.is_tensor(t) and t.grad is not None:
+                t.grad.zero_()
+
+        # Forward
+        outputs = kernel_impl(*inputs)
+        output = outputs[0] if isinstance(outputs, tuple) else outputs
+
+        if target is None:
+            # Without an explicit target, synthesize one from the reference op
+            if ref_op is None:
+                raise ValueError(
+                    "No target provided and no reference op available to synthesize a target."
+                )
+            with torch.no_grad():
+                ref_out = ref_op(*[v.detach() if torch.is_tensor(v) else v for v in inputs])
+            target = ref_out.detach()
+
+        loss = loss_fn(output, target)
+        last_loss = loss.item()
+
+        # Attempt the kernel's backward / autograd path first
+        grad_targets = [t for t in inputs + params if torch.is_tensor(t)]
+        kernel_produced_grads = None
+        if use_kernel_backward:
+            try:
+                kernel_produced_grads = torch.autograd.grad(
+                    loss, grad_targets, retain_graph=False, allow_unused=True
+                )
+                # torch.autograd.grad does not populate .grad, so copy the results back
+                # onto the tensors for the in-place SGD update below.
+                for t, g in zip(grad_targets, kernel_produced_grads):
+                    if g is not None:
+                        t.grad = g.detach().clone()
+            except Exception:
+                kernel_produced_grads = None
+
+        # Compute reference gradients (prefer autograd through the reference op)
+        ref_grads = None
+        if ref_op is not None:
+            try:
+                # Rebuild inputs/params with requires_grad for the reference pass
+                ref_inputs = [
+                    v.detach().clone().requires_grad_(True) if torch.is_tensor(v) else v
+                    for v in inputs
+                ]
+                ref_params = [
+                    p.detach().clone().requires_grad_(True) if torch.is_tensor(p) else p
+                    for p in params
+                ]
+
+                ref_out = ref_op(*ref_inputs)
+                if isinstance(ref_out, tuple):
+                    ref_out = ref_out[0]
+                ref_loss = loss_fn(ref_out, target.detach().to(ref_out.device))
+                ref_grads = torch.autograd.grad(
+                    ref_loss,
+                    [t for t in ref_inputs + ref_params if torch.is_tensor(t)],
+                    allow_unused=True,
+                )
+            except Exception:
+                ref_grads = None
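+
+        # Gradient-checking strategy: prefer autograd through the kernel; if that is
+        # unavailable, fall back to central finite differences on the kernel, then
+        # compare against the reference gradients elementwise using
+        # max |g_kernel - g_ref| / max(|g_ref|, 1e-6) with a 1e-2 tolerance.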
+
+        # If kernel gradients aren't available, try numerical finite differences on the kernel
+        if kernel_produced_grads is None:
+            try:
+                kernel_numerical = _compute_numerical_grads(
+                    kernel_impl,
+                    tuple(v.detach() if torch.is_tensor(v) else v for v in inputs),
+                    target.detach(),
+                    loss_fn,
+                )
+                kernel_produced_grads = tuple(kernel_numerical)
+            except Exception:
+                kernel_produced_grads = None
+
+        # Compare gradients when both kernel and reference gradients are available
+        if kernel_produced_grads is not None and ref_grads is not None:
+            # Align the two lists: keep only actual gradient tensors
+            klist = [g for g in kernel_produced_grads if g is not None]
+            rlist = [g for g in ref_grads if g is not None]
+            if len(klist) == len(rlist) and len(klist) > 0:
+                rel_errors = []
+                for kg, rg in zip(klist, rlist):
+                    # Ensure both gradients live on the same device
+                    rg = rg.detach().to(kg.device)
+                    denom = torch.max(rg.abs(), torch.tensor(1e-6, device=rg.device))
+                    rel_errors.append(torch.max((kg.detach() - rg).abs() / denom).item())
+                grad_rel_error = max(rel_errors) if rel_errors else float("inf")
+                grad_correct = grad_rel_error < 1e-2  # tolerance
+            else:
+                # Couldn't align the gradients -> mark as incorrect
+                grad_rel_error = float("inf")
+                grad_correct = False
+        else:
+            grad_rel_error = float("inf")
+            grad_correct = False
+
+        # Plain SGD step: update params when present, otherwise the inputs themselves
+        if params:
+            for p in params:
+                if torch.is_tensor(p) and p.grad is not None:
+                    with torch.no_grad():
+                        p -= lr * p.grad
+        else:
+            for i, v in enumerate(inputs):
+                if torch.is_tensor(v) and v.grad is not None:
+                    with torch.no_grad():
+                        inputs[i] = (v - lr * v.grad).detach().requires_grad_(True)
+
+    step_time_ms = (time.time() - t0) * 1000.0 / max(1, num_steps)
+
+    return {
+        "grad_correct": bool(grad_correct),
+        "grad_rel_error": float(grad_rel_error),
+        "step_time_ms": float(step_time_ms),
+        "final_loss": float(last_loss) if last_loss is not None else None,
+    }
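+
+
+if __name__ == "__main__":
+    # Minimal illustrative smoke run (not part of the unit tests): train a toy
+    # elementwise kernel against an explicit reference op on CPU tensors. The op
+    # name "toy.mul3" is a hypothetical spec name used only for logging/lookup.
+    def _toy_kernel(x):
+        return x * 3.0
+
+    case = TrainingTestCase(
+        inputs=(torch.tensor([1.0, 2.0]),),
+        target=torch.tensor([3.0, 6.0]),
+    )
+    print(train_one_op("toy.mul3", _toy_kernel, case, reference_op=_toy_kernel))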
diff --git a/test/test_train.py b/test/test_train.py
new file mode 100644
index 0000000..ee162d0
--- /dev/null
+++ b/test/test_train.py
@@ -0,0 +1,120 @@
+import torch
+
+from BackendBench import opregistry as _opregistry_module
+from BackendBench.opregistry import register_kernel, get_kernel, has_backward
+from BackendBench.train import train_one_op, TrainingTestCase
+
+
+def test_register_and_get_kernel_and_has_backward():
+    def forward(x):
+        return x + 1
+
+    def backward(x):
+        return x - 1
+
+    # Register a kernel and validate retrieval
+    register_kernel("myop.default", forward, backward=backward)
+    kernel = get_kernel("myop.default")
+    assert kernel["forward"] is forward
+    assert kernel["backward"] is backward
+    assert has_backward("myop.default") is True
+
+    # Clean up the shared registry to avoid cross-test interaction
+    _opregistry_module._op_registry._registry.clear()
+
+
+def test_has_backward_false_when_no_backward():
+    def fwd(x):
+        return x * 2.0
+
+    register_kernel("noback.default", fwd, backward=None)
+    assert has_backward("noback.default") is False
+    _opregistry_module._op_registry._registry.clear()
+
+
+def test_train_one_op_gradients_match_reference():
+    # A simple kernel that is identical to the reference op
+    def kernel_impl(x):
+        return x * 3.0
+
+    def reference_op(x):
+        return x * 3.0
+
+    tc = TrainingTestCase()
+    tc.inputs = (torch.tensor([1.0, 2.0], dtype=torch.float32),)
+    tc.target = torch.tensor([3.0, 6.0], dtype=torch.float32)
+    tc.params = None
+    tc.loss_fn = None  # use the default MSE loss
+
+    res = train_one_op(
+        op="dummy_op",
+        kernel_impl=kernel_impl,
+        training_case=tc,
+        lr=1e-3,
+        num_steps=1,
+        use_kernel_backward=True,
+        reference_op=reference_op,
+    )
+
+    assert res["grad_correct"] is True
+    assert res["grad_rel_error"] < 1e-2
+    assert float(res["final_loss"]) == 0.0
+
+
+def test_train_one_op_gradients_mismatch():
+    # The kernel produces different gradients than the reference -> should be flagged
+    def kernel_impl(x):
+        return x * 2.0
+
+    def reference_op(x):
+        return x * 3.0
+
+    tc = TrainingTestCase()
+    tc.inputs = (torch.tensor([1.0, 2.0], dtype=torch.float32),)
+    tc.target = torch.tensor([3.0, 6.0], dtype=torch.float32)
+    tc.params = None
+    tc.loss_fn = None
+
+    res = train_one_op(
+        op="dummy_op_mismatch",
+        kernel_impl=kernel_impl,
+        training_case=tc,
+        lr=1e-3,
+        num_steps=1,
+        use_kernel_backward=True,
+        reference_op=reference_op,
+    )
+
+    assert res["grad_correct"] is False
+    assert res["grad_rel_error"] > 0.0
+    assert float(res["final_loss"]) > 0.0
+
+
+def test_train_one_op_numerical_gradients_fallback():
+    # kernel_impl detaches its input, so autograd produces no grads and the
+    # numerical finite-difference fallback is exercised
+    def kernel_impl_detached(x):
+        return x.detach() * 3.0
+
+    def reference_op(x):
+        return x * 3.0
+
+    tc = TrainingTestCase()
+    tc.inputs = (torch.tensor([1.0, 2.0], dtype=torch.float32),)
+    tc.target = torch.tensor([3.0, 6.0], dtype=torch.float32)
+    tc.params = None
+    tc.loss_fn = None
+
+    res = train_one_op(
+        op="dummy_op_numerical",
+        kernel_impl=kernel_impl_detached,
+        training_case=tc,
+        lr=1e-3,
+        num_steps=1,
+        use_kernel_backward=True,  # attempts autograd first, then falls back to numerical
+        reference_op=reference_op,
+    )
+
+    # Numerical grads should match the reference autograd grads for this simple function
+    assert res["grad_correct"] is True
+    assert res["grad_rel_error"] < 1e-2
+    assert float(res["final_loss"]) == 0.0
\ No newline at end of file