
Commit 6161729

Enable CuTeDSL kernel generation (#190)
1 parent 4468ee2 commit 6161729

5 files changed (+258, -23 lines)

BackendBench/backends/llm.py

Lines changed: 5 additions & 4 deletions
@@ -461,7 +461,7 @@ def _kernel_feedback_loop(
         op_name: str,
         op_signature: str,
         op_description: str,
-        framework: str = "triton",
+        dsl: str = "triton",
         attempts: int = 5,
     ) -> Tuple[str, int, bool]:
         """
@@ -473,7 +473,7 @@ def _kernel_feedback_loop(
             op_name: Name of the operation for which to generate a kernel.
             op_signature: Function signature of the operation.
             op_description: Detailed description of the operation.
-            framework: Target framework for the kernel (default: "triton").
+            dsl: Target DSL for the kernel (default: "triton").
             attempts: Maximum number of generation attempts (default: 5).

         Returns:
@@ -498,7 +498,7 @@ def _kernel_feedback_loop(

             try:
                 kernel_code = self.llm_client.generate_kernel(
-                    op_name, op_signature, op_description, framework, feedback_str
+                    op_name, op_signature, op_description, dsl, feedback_str
                 )
             except Exception as e:
                 logger.info(f" ✗ Failed to generate kernel: {e}")
@@ -570,7 +570,7 @@ def _kernel_feedback_loop(
             best_kernel_feedback_info.is_correct,
         )

-    def generate_kernels(self, suite, attempts=5):
+    def generate_kernels(self, suite, attempts=5, dsl="triton"):
         """Generate kernels for all operators in the suite with comprehensive feedback."""
         successful_ops = 0
         total_ops = 0
@@ -590,6 +590,7 @@ def generate_kernels(self, suite, attempts=5):
                 op_name=op_name,
                 op_signature=f"def {op_name}(*args, **kwargs) -> torch.Tensor",
                 op_description=f"PyTorch operation: {op_name}",
+                dsl=dsl,
                 attempts=attempts,
             )
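
How the new parameter is used, as a minimal sketch: only the generate_kernels(suite, attempts, dsl) signature comes from this diff; the backend class name and suite loader below are hypothetical stand-ins.

    # Hypothetical setup: LLMBackend and load_suite are assumed names,
    # not part of this commit.
    from BackendBench.backends.llm import LLMBackend
    from BackendBench.suite import load_suite

    backend = LLMBackend()
    suite = load_suite("smoke")

    # New in this commit: dsl is threaded through _kernel_feedback_loop
    # into llm_client.generate_kernel, so one flag selects the target DSL.
    backend.generate_kernels(suite, attempts=5, dsl="cutedsl")

    # The default is unchanged, so existing callers still get Triton kernels.
    backend.generate_kernels(suite, attempts=5)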

BackendBench/kernel_templates.py

Lines changed: 48 additions & 12 deletions
@@ -11,6 +11,9 @@
 from typing import Dict

 from .prompts import (
+    CUTEDSL_EXAMPLE_TEMPLATES,
+    CUTEDSL_KERNEL_PROMPT,
+    CUTEDSL_OPTIMIZATIONS,
     PYTORCH_KERNEL_PROMPT,
     TRITON_EXAMPLE_TEMPLATES,
     TRITON_KERNEL_PROMPT,
@@ -21,9 +24,9 @@
 class KernelTemplate:
     """Base class for kernel templates."""

-    def __init__(self, name: str, framework: str):
+    def __init__(self, name: str, dsl: str):
         self.name = name
-        self.framework = framework
+        self.dsl = dsl

     def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
         """Create a prompt for kernel generation."""
@@ -76,43 +79,76 @@ def create_prompt(self, op_name: str, op_signature: str, op_description: str) ->
         )


+class CuTeDSLKernelTemplate(KernelTemplate):
+    """Template for CuTeDSL kernel generation."""
+
+    def __init__(self):
+        super().__init__("cutedsl", "cutedsl")
+
+    def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """Create a specialized prompt for CuTeDSL kernel generation."""
+
+        # Get operation-specific optimizations
+        optimizations = self._get_optimizations(op_name)
+
+        # Get example template
+        example = self._get_example_template(op_name)
+
+        return CUTEDSL_KERNEL_PROMPT.format(
+            op_name=op_name,
+            op_signature=op_signature,
+            op_description=op_description,
+            optimizations=optimizations,
+            example=example,
+        )
+
+    def _get_optimizations(self, op_name: str) -> str:
+        """Get operation-specific optimization guidelines."""
+        return CUTEDSL_OPTIMIZATIONS.get(op_name, CUTEDSL_OPTIMIZATIONS["default"])
+
+    def _get_example_template(self, op_name: str) -> str:
+        """Get operation-specific code template."""
+        return CUTEDSL_EXAMPLE_TEMPLATES["default"]
+
+
 class KernelTemplateManager:
-    """Manages kernel templates for different frameworks."""
+    """Manages kernel templates for different DSLs."""

     def __init__(self):
         self.templates: Dict[str, KernelTemplate] = {
             "triton": TritonKernelTemplate(),
             "pytorch": PyTorchKernelTemplate(),
+            "cutedsl": CuTeDSLKernelTemplate(),
             # TODO: Add cuda, cutile, whatever we want
         }

-    def get_template(self, framework: str) -> KernelTemplate:
-        """Get template for specified framework."""
-        if framework not in self.templates:
-            raise ValueError(f"Unknown framework: {framework}")
-        return self.templates[framework]
+    def get_template(self, dsl: str) -> KernelTemplate:
+        """Get template for the specified DSL."""
+        if dsl not in self.templates:
+            raise ValueError(f"Unknown DSL: {dsl}")
+        return self.templates[dsl]

     def create_prompt(
         self,
         op_name: str,
         op_signature: str,
         op_description: str,
-        framework: str = "triton",
+        dsl: str = "triton",
     ) -> str:
         """Create a prompt using the specified template."""
-        template = self.get_template(framework)
+        template = self.get_template(dsl)
         return template.create_prompt(op_name, op_signature, op_description)

     def create_refinement_prompt(
         self,
         op_name: str,
         op_signature: str,
         op_description: str,
-        framework: str = "triton",
+        dsl: str = "triton",
         feedback: str = "",
     ) -> str:
         """Create a refinement prompt with feedback from previous attempts."""
-        base_prompt = self.create_prompt(op_name, op_signature, op_description, framework)
+        base_prompt = self.create_prompt(op_name, op_signature, op_description, dsl)

         if feedback and feedback.strip():
             refinement_prompt = f"""{feedback}
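
Usage sketch for the manager changes above; every name here appears in this diff, and the metadata strings mirror the placeholders generate_kernels passes in llm.py.

    from BackendBench.kernel_templates import KernelTemplateManager

    manager = KernelTemplateManager()

    # "cutedsl" now resolves to CuTeDSLKernelTemplate via self.templates.
    prompt = manager.create_prompt(
        op_name="add",
        op_signature="def add(*args, **kwargs) -> torch.Tensor",
        op_description="PyTorch operation: add",
        dsl="cutedsl",
    )

    # Unregistered DSLs still fail fast:
    manager.get_template("cuda")  # raises ValueError: Unknown DSL: cuda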

BackendBench/llm_client.py

Lines changed: 3 additions & 5 deletions
@@ -75,17 +75,15 @@ def generate_kernel(
         op_name: str,
         op_signature: str,
         op_description: str,
-        framework: str = "triton",
+        dsl: str = "triton",
         feedback: Optional[str] = None,
     ) -> str:
         if feedback:
             prompt = self.template_manager.create_refinement_prompt(
-                op_name, op_signature, op_description, framework, feedback
+                op_name, op_signature, op_description, dsl, feedback
             )
         else:
-            prompt = self.template_manager.create_prompt(
-                op_name, op_signature, op_description, framework
-            )
+            prompt = self.template_manager.create_prompt(op_name, op_signature, op_description, dsl)

         print("\n=== DEBUG: PROMPT SENT TO LLM RELAY ===")
         print(prompt)
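
The feedback branch above is the only difference between a first attempt and a retry. A rough sketch (the client class name is an assumption; the generate_kernel signature is from this diff):

    from BackendBench.llm_client import LLMRelayClient  # assumed class name

    client = LLMRelayClient()
    sig = "def add(*args, **kwargs) -> torch.Tensor"
    desc = "PyTorch operation: add"

    # First attempt: feedback is None, so the base prompt is built.
    kernel_code = client.generate_kernel("add", sig, desc, dsl="cutedsl")

    # Retry: a feedback string routes through create_refinement_prompt,
    # which prepends the failure details to the same base prompt.
    kernel_code = client.generate_kernel(
        "add", sig, desc, dsl="cutedsl", feedback="CompilationError: ..."
    )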

BackendBench/prompts.py

Lines changed: 193 additions & 0 deletions
@@ -42,3 +42,196 @@
 }

 TRITON_EXAMPLE_TEMPLATES = {"default": "See main prompt for example structure."}
+
+CUTEDSL_KERNEL_PROMPT = """Generate a CuTeDSL kernel for: {op_name}
+
+Operation: {op_signature}
+{op_description}
+
+Requirements:
+- CuTeDSL kernel function MUST be named: {op_name}_cutedsl_kernel
+- Launcher function MUST be named: {op_name}_kernel_launch
+- Wrapper function MUST be named: {op_name}_kernel_impl
+- Use modern CuTeDSL syntax with proper grid computation
+- Include all necessary imports (torch, cutlass, cutlass.cute as cute)
+
+The {op_name}_kernel_impl wrapper function MUST handle complete device management:
+- Move CPU tensors to GPU if needed (use .cuda() when torch.cuda.is_available())
+- Raise clear errors if CUDA is not available for GPU tensors
+- Call the CuTeDSL kernel with GPU tensors
+- Move results back to the original device of the input tensors
+- Handle both args and kwargs properly
+- Preserve original tensor devices and restore them for outputs
+- Avoid falling back to a PyTorch implementation
+- Avoid using try/except blocks
+
+Generate complete, runnable code only - no framework will add device handling wrapper code.
+
+Example:
+{example}
+"""
+
+CUTEDSL_OPTIMIZATIONS = {
+    "default": "Use efficient memory access patterns and appropriate block sizes."
+}
+
+CUTEDSL_EXAMPLE_TEMPLATES = {
+    "default": """import torch
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.runtime import from_dlpack
+
+@cute.kernel
+def add_tensor_kernel(
+    gA: cute.Tensor,
+    gB: cute.Tensor,
+    gC: cute.Tensor,
+):
+    tidx, _, _ = cute.arch.thread_idx()
+    bidx, _, _ = cute.arch.block_idx()
+    bdim, _, _ = cute.arch.block_dim()
+
+    thread_idx = bidx * bdim + tidx
+
+    # Map thread index to logical index of input tensor
+    total_elements = gA.shape[0]
+
+    # Bounds checking
+    if thread_idx < total_elements:
+        # Map logical index to physical address via tensor layout
+        a_val = gA[thread_idx]
+        b_val = gB[thread_idx]
+
+        # Perform element-wise addition
+        gC[thread_idx] = a_val + b_val
+
+@cute.kernel
+def add_scalar_kernel(
+    gA: cute.Tensor,
+    gC: cute.Tensor,
+    scalar_val,
+):
+    tidx, _, _ = cute.arch.thread_idx()
+    bidx, _, _ = cute.arch.block_idx()
+    bdim, _, _ = cute.arch.block_dim()
+
+    thread_idx = bidx * bdim + tidx
+
+    # Map thread index to logical index of input tensor
+    total_elements = gA.shape[0]
+
+    # Bounds checking
+    if thread_idx < total_elements:
+        # Map logical index to physical address via tensor layout
+        a_val = gA[thread_idx]
+
+        # Perform element-wise addition with scalar
+        gC[thread_idx] = a_val + scalar_val
+
+@cute.jit
+def add_tensor_kernel_launch(
+    mA: cute.Tensor,
+    mB: cute.Tensor,
+    mC: cute.Tensor
+):
+    num_threads_per_block = 1024
+
+    total_elements = mA.shape[0]
+    num_blocks = (total_elements + num_threads_per_block - 1) // num_threads_per_block
+
+    kernel = add_tensor_kernel(mA, mB, mC)
+    kernel.launch(grid=(num_blocks, 1, 1),
+                  block=(num_threads_per_block, 1, 1))
+
+@cute.jit
+def add_scalar_kernel_launch(
+    mA: cute.Tensor,
+    mC: cute.Tensor,
+    scalar_val
+):
+    num_threads_per_block = 1024
+
+    total_elements = mA.shape[0]
+    num_blocks = (total_elements + num_threads_per_block - 1) // num_threads_per_block
+
+    kernel = add_scalar_kernel(mA, mC, scalar_val)
+    kernel.launch(grid=(num_blocks, 1, 1),
+                  block=(num_threads_per_block, 1, 1))
+
+def add_kernel_impl(*args, **kwargs):
+    # Handle both positional and keyword arguments
+    if len(args) >= 2:
+        input_tensor = args[0]
+        other = args[1]
+    elif len(args) == 1 and 'other' in kwargs:
+        input_tensor = args[0]
+        other = kwargs['other']
+    elif 'input' in kwargs and 'other' in kwargs:
+        input_tensor = kwargs['input']
+        other = kwargs['other']
+    else:
+        raise ValueError("add requires 'input' and 'other' arguments")
+
+    if torch.is_tensor(other):
+        input_tensor, other = torch.broadcast_tensors(input_tensor, other)
+
+    if 'alpha' in kwargs:
+        alpha = kwargs['alpha']
+        other = other * alpha
+
+    # Remember original device
+    original_device = input_tensor.device
+
+    # Flatten all tensors and save their shapes
+    original_shape = input_tensor.shape
+    input_tensor = input_tensor.flatten()
+    if torch.is_tensor(other):
+        other = other.flatten()
+
+    # Move to GPU if needed
+    if not input_tensor.is_cuda:
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available")
+        input_tensor = input_tensor.cuda()
+
+    # Check if other is a tensor or scalar
+    if torch.is_tensor(other):
+        # Tensor + Tensor case
+        if not other.is_cuda:
+            if not torch.cuda.is_available():
+                raise RuntimeError("CUDA is not available")
+            other = other.cuda()
+
+        output = torch.empty_like(input_tensor)
+        a_ = from_dlpack(input_tensor)
+        b_ = from_dlpack(other)
+        c_ = from_dlpack(output)
+
+        add_tensor_kernel_launch_ = cute.compile(add_tensor_kernel_launch, a_, b_, c_)
+        add_tensor_kernel_launch_(a_, b_, c_)
+    else:
+        # Tensor + Scalar case
+        # Convert scalar to Python float
+        if hasattr(other, 'item'):
+            scalar_val = other.item()
+        else:
+            scalar_val = other
+
+        output = torch.empty_like(input_tensor)
+        a_ = from_dlpack(input_tensor)
+        c_ = from_dlpack(output)
+
+        add_scalar_kernel_launch_ = cute.compile(add_scalar_kernel_launch, a_, c_, scalar_val)
+        add_scalar_kernel_launch_(a_, c_, scalar_val)
+
+    # Move result back to original device
+    if original_device != output.device:
+        output = output.to(original_device)
+
+    output = output.reshape(original_shape)
+
+    return output"""
+}
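
Taken together, CuTeDSLKernelTemplate.create_prompt assembles these three constants as sketched below. Note that CUTEDSL_KERNEL_PROMPT contains no {optimizations} placeholder, so the optimizations value is silently dropped (str.format ignores unused keyword arguments).

    from BackendBench.prompts import (
        CUTEDSL_EXAMPLE_TEMPLATES,
        CUTEDSL_KERNEL_PROMPT,
        CUTEDSL_OPTIMIZATIONS,
    )

    op_name = "mul"  # only a "default" entry exists so far, so any op works
    prompt = CUTEDSL_KERNEL_PROMPT.format(
        op_name=op_name,
        op_signature=f"def {op_name}(*args, **kwargs) -> torch.Tensor",
        op_description=f"PyTorch operation: {op_name}",
        # Passed by create_prompt but unused by the template string today.
        optimizations=CUTEDSL_OPTIMIZATIONS["default"],
        example=CUTEDSL_EXAMPLE_TEMPLATES["default"],
    )
    # The result asks the LLM for mul_cutedsl_kernel, mul_kernel_launch,
    # and mul_kernel_impl, with the add example above as the reference.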
