83 changes: 83 additions & 0 deletions BackendBench/kernel_templates.py
@@ -32,6 +32,24 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
"""Create a prompt for kernel generation."""
raise NotImplementedError

def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
"""
Create a prompt for backward (gradient) kernel generation.

Default implementation returns a conservative instruction that asks for a
backward kernel implementing gradients for the forward operation. Subclasses
should override to provide DSL-specific guidance and examples.
"""
return (
f"Generate a backward (gradient) kernel implementation for the operation "
f"'{op_name}'.\n\nSignature: {op_signature}\n\nDescription: {op_description}\n\n"
"The backward kernel should accept gradient(s) of the outputs and return "
"gradients w.r.t. each input and any trainable parameters. Be explicit "
"about shapes and dtype handling. If trainable parameters exist, update "
"or accumulate their gradients in-place or follow the standard autograd "
"convention for the target DSL."
)
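
A minimal usage sketch of this default (hypothetical caller code, not part of this diff; it assumes KernelTemplate needs no constructor arguments, which is not shown here):

# Hypothetical illustration only.
prompt = KernelTemplate().create_backward_prompt(
    "add",                                       # op_name
    "add(Tensor self, Tensor other) -> Tensor",  # op_signature
    "Elementwise addition of two tensors.",      # op_description
)
# `prompt` is the generic backward instruction string assembled above.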


class TritonKernelTemplate(KernelTemplate):
"""Template for Triton kernel generation."""
@@ -56,6 +74,29 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
example=example,
)

def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
"""Triton-specific backward kernel prompt using same optimization hints."""
optimizations = self._get_optimizations(op_name)
example = self._get_example_template(op_name)

extra_prompt = (
"\n\n# NOTE: The code above should be adapted to implement gradients. "
"Provide a Triton kernel (or auxiliary kernels) that computes gradients "
"w.r.t. inputs and parameters given gradient(s) of the outputs. Declare "
"the expected gradient shapes and any in-place updates for parameter grads."
)

return (
TRITON_KERNEL_PROMPT.format(
op_name=op_name,
op_signature=op_signature,
op_description=op_description,
optimizations=optimizations,
example=example,
)
+ extra_prompt
)

def _get_optimizations(self, op_name: str) -> str:
"""Get operation-specific optimization guidelines."""
return TRITON_OPTIMIZATIONS.get(op_name, TRITON_OPTIMIZATIONS["default"])
@@ -78,6 +119,21 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
op_name=op_name, op_signature=op_signature, op_description=op_description
)

def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
"""PyTorch-specific backward prompt: ask for autograd-friendly backward code."""
extra_prompt = (
"\n\n# BACKWARD: Provide a backward function (e.g., a Function.backward or "
"a gradient function) that computes gradients w.r.t. inputs and parameters. "
"Prefer returning gradients as Tensors in the same order as inputs."
)

return (
PYTORCH_KERNEL_PROMPT.format(
op_name=op_name, op_signature=op_signature, op_description=op_description
)
+ extra_prompt
)


class CuTeDSLKernelTemplate(KernelTemplate):
"""Template for CuTeDSL kernel generation."""
@@ -102,6 +158,26 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
example=example,
)

def create_backward_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
"""CuTeDSL-specific backward prompt using CuTeDSL optimization hints."""
optimizations = self._get_optimizations(op_name)
example = self._get_example_template(op_name)

extra_prompt = (
"\n\n# BACKWARD: Provide gradient computation for the above forward operator."
)

return (
CUTEDSL_KERNEL_PROMPT.format(
op_name=op_name,
op_signature=op_signature,
op_description=op_description,
optimizations=optimizations,
example=example,
)
+ extra_prompt
)

def _get_optimizations(self, op_name: str) -> str:
"""Get operation-specific optimization guidelines."""
return CUTEDSL_OPTIMIZATIONS.get(op_name, CUTEDSL_OPTIMIZATIONS["default"])
Expand Down Expand Up @@ -139,6 +215,13 @@ def create_prompt(
template = self.get_template(dsl)
return template.create_prompt(op_name, op_signature, op_description)

def create_backward_prompt(
self, op_name: str, op_signature: str, op_description: str, dsl: str = "triton"
) -> str:
"""Create a backward prompt using the specified template."""
template = self.get_template(dsl)
return template.create_backward_prompt(op_name, op_signature, op_description)

def create_refinement_prompt(
self,
op_name: str,
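A usage sketch of the manager-level dispatch added above (hypothetical; the enclosing manager class name and its constructor are not shown in this hunk, so both are assumed):

# Hypothetical illustration only; `KernelTemplateManager` is an assumed name for
# the enclosing manager class, taken to need no constructor arguments.
manager = KernelTemplateManager()
prompt = manager.create_backward_prompt(
    "relu",
    "relu(Tensor self) -> Tensor",
    "Elementwise rectified linear unit.",
    dsl="triton",  # the default; other keys depend on get_template(), not shown here
)
# The call resolves a template via self.get_template(dsl) and returns that
# template's create_backward_prompt() output (forward prompt plus backward note).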
68 changes: 63 additions & 5 deletions BackendBench/opregistry.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable, Dict, Optional

import torch

@@ -30,7 +31,7 @@ def _extract_spec_name_from_op(op_obj):

class OpRegistry:
def __init__(self):
self._registry = {}
self._registry: Dict[str, Any] = {}

def get_operator(self, input_obj):
if isinstance(input_obj, str):
@@ -41,7 +42,11 @@ def get_operator(self, input_obj):
def _get_operator_from_spec_name(self, spec_name):
# Return cached operator if available
if spec_name in self._registry:
return self._registry[spec_name]
entry = self._registry[spec_name]
# If entry is a kernel dict, return forward for compatibility
if isinstance(entry, dict) and "forward" in entry:
return entry["forward"]
return entry

# Parse spec name
op_parts = spec_name.split(".")
@@ -67,7 +72,10 @@ def _get_operator_from_object(self, op_obj):

# Check if we already have this operator registered
if spec_name in self._registry:
return self._registry[spec_name]
entry = self._registry[spec_name]
# If entry is a kernel dict, return forward for compatibility
if isinstance(entry, dict) and "forward" in entry:
return entry["forward"]
            return entry

# Register the provided operator object
self._registry[spec_name] = op_obj
@@ -77,6 +85,39 @@ def register_operator(self, op_obj):
def register_operator(self, op_obj):
return self._get_operator_from_object(op_obj)

def register_kernel(
self,
spec_name: str,
forward: Callable,
*,
backward: Optional[Callable] = None,
param_update: Optional[Callable] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
self._registry[spec_name] = {
"forward": forward,
"backward": backward,
"param_update": param_update,
"metadata": metadata or {},
}

def get_kernel(self, spec_name: str) -> Dict[str, Any]:
if spec_name not in self._registry:
raise KeyError(f"Operator {spec_name} is not registered")
entry = self._registry[spec_name]
if isinstance(entry, dict) and "forward" in entry:
return entry
# legacy operator object present -> wrap as forward-only kernel
return {"forward": entry, "backward": None, "param_update": None, "metadata": {}}

def has_backward(self, spec_name: str) -> bool:
entry = self._registry.get(spec_name)
if not entry:
return False
if isinstance(entry, dict):
return entry.get("backward") is not None
return False

def get_all_registered_ops(self):
return self._registry.copy()
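
A sketch of how the new kernel entries might be exercised on an OpRegistry instance (hypothetical placeholder callables; not part of this diff):

# Hypothetical illustration only.
registry = OpRegistry()
registry.register_kernel(
    "relu.default",
    forward=lambda x: x.clamp_min(0),                 # placeholder forward
    backward=lambda grad_out, x: grad_out * (x > 0),  # placeholder backward
    metadata={"source": "example"},
)
kernel = registry.get_kernel("relu.default")  # dict with forward/backward/param_update/metadata
registry.has_backward("relu.default")         # True
registry.get_operator("relu.default")         # returns only the forward callable (compat path)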

@@ -106,5 +147,22 @@ def register_operator(op_obj):
return _op_registry.register_operator(op_obj)


def get_registry():
return _op_registry
def register_kernel(
spec_name: str,
forward: Callable,
*,
backward: Optional[Callable] = None,
param_update: Optional[Callable] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
return _op_registry.register_kernel(
spec_name, forward, backward=backward, param_update=param_update, metadata=metadata
)


def get_kernel(spec_name: str) -> Dict[str, Any]:
return _op_registry.get_kernel(spec_name)


def has_backward(spec_name: str) -> bool:
return _op_registry.has_backward(spec_name)
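
And a matching sketch for the module-level helpers, including how a legacy registration surfaces through get_kernel (hypothetical; the import path, the aten ops, and the extracted spec names are assumptions):

# Hypothetical illustration only; assumes the package imports as BackendBench.opregistry
# and that _extract_spec_name_from_op maps the aten op below to "relu.default".
import torch
from BackendBench import opregistry

# Legacy path: register a raw operator object, then read it back as a kernel dict.
opregistry.register_operator(torch.ops.aten.relu.default)
legacy = opregistry.get_kernel("relu.default")
assert legacy["backward"] is None  # legacy entries are wrapped as forward-only kernels

# New path: register forward and backward callables under an explicit spec name.
opregistry.register_kernel(
    "mul.Tensor",
    forward=torch.ops.aten.mul.Tensor,
    backward=lambda grad_out, a, b: (grad_out * b, grad_out * a),  # placeholder backward
)
assert opregistry.has_backward("mul.Tensor")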