diff --git a/BackendBench/llm_client.py b/BackendBench/llm_client.py
index 35e575c5..e8ccda24 100644
--- a/BackendBench/llm_client.py
+++ b/BackendBench/llm_client.py
@@ -9,8 +9,10 @@
 
 import anthropic
 import requests
+import torch
 from tenacity import retry
 from tenacity.wait import wait_random_exponential
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .kernel_templates import KernelTemplateManager
 
@@ -128,6 +130,96 @@ def _extract_code_from_response(self, response: str) -> str:
         return response[start:end].strip()
 
 
+class HuggingFaceKernelGenerator(LLMKernelGenerator):
+    """
+    LLM Kernel Generator that uses a local HuggingFace model.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
+        device: str = "cuda",
+        max_new_tokens: int = 2048,
+        hf_token: Optional[str] = None,
+    ):
+        self.model_name = model_name
+        self.device = device
+        self.max_new_tokens = max_new_tokens
+        self.template_manager = KernelTemplateManager()
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=hf_token)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name, torch_dtype="auto", token=hf_token
+        ).to(self.device)
+        self.model.eval()
+
+    @property
+    def readme_server_description(self) -> str:
+        return f"Local HuggingFace model ({self.model_name})"
+
+    @property
+    def readme_setup_section(self) -> str:
+        return """## Setup
+This backend uses a local HuggingFace model and requires:
+```bash
+pip install torch torchvision torchaudio transformers
+```
+"""
+
+    def generate_kernel(
+        self,
+        op_name: str,
+        op_signature: str,
+        op_description: str,
+        framework: str = "triton",
+        feedback: Optional[str] = None,
+    ) -> str:
+        if feedback:
+            prompt = self.template_manager.create_refinement_prompt(
+                op_name, op_signature, op_description, framework, feedback
+            )
+        else:
+            prompt = self.template_manager.create_prompt(
+                op_name, op_signature, op_description, framework
+            )
+
+        print("\n=== DEBUG: PROMPT SENT TO LOCAL LLM ===")
+        print(prompt)
+        print("=== END PROMPT ===\n")
+
+        try:
+            content = self.call_llm(prompt)
+            if not content:
+                raise RuntimeError("Empty response from local LLM")
+
+            extracted_code = self._extract_code_from_response(content)
+
+            print("\n=== DEBUG: RAW LOCAL LLM RESPONSE ===")
+            print(content)
+            print("=== DEBUG: EXTRACTED CODE ===")
+            print(extracted_code)
+            print("=== END DEBUG ===\n")
+
+            return extracted_code
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate kernel for {op_name}: {str(e)}")
+
+    @retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2))
+    def call_llm(self, prompt: str) -> str:
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=True,
+                top_p=0.95,
+                temperature=0.7,
+            )
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Extract the part of the response that comes after the prompt
+        return response[len(prompt) :].strip()
+
+
 class LLMRelayKernelGenerator(LLMKernelGenerator):
     """
     LLM Kernel Generator that uses local plugboard server.
diff --git a/pyproject.toml b/pyproject.toml
index ae6bafcc..551bf216 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     "pandas",
     "datasets",
     "tenacity",
+    "transformers>=4.56.2",
 ]
 
 [project.optional-dependencies]
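
A minimal usage sketch of the new generator, for anyone trying the branch locally. The class name, constructor arguments, and `generate_kernel` signature are taken from the diff above; the operator strings and the `HF_TOKEN` environment variable are illustrative placeholders, not real BackendBench op metadata.

```python
import os

from BackendBench.llm_client import HuggingFaceKernelGenerator

# Hypothetical example inputs; real op_name/op_signature/op_description come from
# BackendBench's operator suite, not from this snippet.
generator = HuggingFaceKernelGenerator(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    device="cuda",
    max_new_tokens=2048,
    hf_token=os.environ.get("HF_TOKEN"),  # needed for gated checkpoints such as Llama-2
)

kernel_code = generator.generate_kernel(
    op_name="relu",
    op_signature="relu(Tensor self) -> Tensor",
    op_description="Applies the rectified linear unit function element-wise.",
    framework="triton",
)
print(kernel_code)
```

Note that `call_llm` samples with `do_sample=True` (`top_p=0.95`, `temperature=0.7`), so repeated calls for the same op can return different kernels.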