Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions BackendBench/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Optional

Expand Down Expand Up @@ -77,6 +78,7 @@ def generate_kernel(
op_description: str,
framework: str = "triton",
feedback: Optional[str] = None,
correctness_results: Optional[str] = None,
) -> str:
if feedback:
prompt = self.template_manager.create_refinement_prompt(
Expand Down Expand Up @@ -104,6 +106,20 @@ def generate_kernel(
print(extracted_code)
print("=== END DEBUG ===\n")

logger = logging.getLogger(__name__)
if correctness_results is not None:
passed = sum(1 for r in correctness_results if getattr(r, "is_correct", False))
total = len(correctness_results)
logger.info(f"Correctness results for {op_name}: {passed}/{total} tests passed.")
for i, result in enumerate(correctness_results):
logger.info(
f"Test {i + 1}: {'Passed' if getattr(result, 'is_correct', False) else 'Failed'}"
f" args={getattr(result, 'args', None)}, "
f"max_abs_error={getattr(result, 'max_abs_error', None)}, "
f"max_rel_error={getattr(result, 'max_rel_error', None)}, "
f"error_msg={getattr(result, 'error_msg', None)}"
)

return extracted_code

except requests.exceptions.RequestException as e:
Expand Down