BackendBench/llm_client.py: 92 additions, 0 deletions
@@ -9,8 +9,10 @@

import anthropic
import requests
import torch
from tenacity import retry
from tenacity.wait import wait_random_exponential
from transformers import AutoModelForCausalLM, AutoTokenizer

from .kernel_templates import KernelTemplateManager

@@ -128,6 +130,96 @@ def _extract_code_from_response(self, response: str) -> str:
return response[start:end].strip()


class HuggingFaceKernelGenerator(LLMKernelGenerator):
"""
    LLM Kernel Generator that uses a local HuggingFace model.
"""

def __init__(
self,
model_name: str = "meta-llama/Llama-2-7b-chat-hf",
device: str = "cuda",
max_new_tokens: int = 2048,
hf_token: Optional[str] = None,
):
self.model_name = model_name
self.device = device
self.max_new_tokens = max_new_tokens
self.template_manager = KernelTemplateManager()
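        # Load the tokenizer and weights locally; torch_dtype="auto" uses the dtype
        # recorded in the checkpoint (this assumes the model fits on `device`).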
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=hf_token)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto").to(
self.device
)
self.model.eval()

@property
def readme_server_description(self) -> str:
return f"Local HuggingFace model ({self.model_name})"

@property
def readme_setup_section(self) -> str:
return """## Setup
This backend uses a local HuggingFace model and requires:
```bash
pip install torch torchvision torchaudio transformers
```
"""

def generate_kernel(
self,
op_name: str,
op_signature: str,
op_description: str,
framework: str = "triton",
feedback: Optional[str] = None,
) -> str:
if feedback:
prompt = self.template_manager.create_refinement_prompt(
op_name, op_signature, op_description, framework, feedback
)
else:
prompt = self.template_manager.create_prompt(
op_name, op_signature, op_description, framework
)

print("\n=== DEBUG: PROMPT SENT TO LOCAL LLM ===")
print(prompt)
print("=== END PROMPT ===\n")

try:
content = self.call_llm(prompt)
if not content:
raise RuntimeError("Empty response from local LLM")

extracted_code = self._extract_code_from_response(content)

print("\n=== DEBUG: RAW LOCAL LLM RESPONSE ===")
print(content)
print("=== DEBUG: EXTRACTED CODE ===")
print(extracted_code)
print("=== END DEBUG ===\n")

return extracted_code

        except Exception as e:
            # Chain the original exception so the underlying traceback is preserved.
            raise RuntimeError(f"Failed to generate kernel for {op_name}: {e}") from e

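    # Randomized exponential backoff (capped at 60 s); with no stop condition,
    # tenacity retries on any exception raised during generation.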
@retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2))
def call_llm(self, prompt: str) -> str:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=True,
top_p=0.95,
temperature=0.7,
)
        # Decode only the newly generated tokens; slicing the decoded text by
        # len(prompt) can mis-align when tokenization normalizes the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()


class LLMRelayKernelGenerator(LLMKernelGenerator):
"""
    LLM Kernel Generator that uses a local plugboard server.
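A minimal sketch of how the new generator could be exercised end to end. The op name, signature, and description below are illustrative placeholders rather than values taken from this PR, and the sketch assumes a CUDA device with enough memory for the chosen checkpoint:

```python
from BackendBench.llm_client import HuggingFaceKernelGenerator

# Illustrative example; real op signatures come from the benchmark suite.
generator = HuggingFaceKernelGenerator(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    device="cuda",
    max_new_tokens=2048,
)
kernel_source = generator.generate_kernel(
    op_name="relu",
    op_signature="relu(Tensor self) -> Tensor",
    op_description="Element-wise rectified linear unit.",
    framework="triton",
)
print(kernel_source)
```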
pyproject.toml: 1 addition, 0 deletions
@@ -28,6 +28,7 @@ dependencies = [
"pandas",
"datasets",
"tenacity",
"transformers>=4.56.2",
]

[project.optional-dependencies]