From 0dc1da2b1bce3615068f5f23d4b7b04ba2b207ce Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 2 Oct 2025 20:04:26 +0530 Subject: [PATCH] integrated correctness results with logging --- BackendBench/llm_client.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/BackendBench/llm_client.py b/BackendBench/llm_client.py index 35e575c5..01ea42e3 100644 --- a/BackendBench/llm_client.py +++ b/BackendBench/llm_client.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import logging import os from typing import Optional @@ -77,6 +78,7 @@ def generate_kernel( op_description: str, framework: str = "triton", feedback: Optional[str] = None, + correctness_results: Optional[str] = None, ) -> str: if feedback: prompt = self.template_manager.create_refinement_prompt( @@ -104,6 +106,20 @@ def generate_kernel( print(extracted_code) print("=== END DEBUG ===\n") + logger = logging.getLogger(__name__) + if correctness_results is not None: + passed = sum(1 for r in correctness_results if getattr(r, "is_correct", False)) + total = len(correctness_results) + logger.info(f"Correctness results for {op_name}: {passed}/{total} tests passed.") + for i, result in enumerate(correctness_results): + logger.info( + f"Test {i + 1}: {'Passed' if getattr(result, 'is_correct', False) else 'Failed'}" + f" args={getattr(result, 'args', None)}, " + f"max_abs_error={getattr(result, 'max_abs_error', None)}, " + f"max_rel_error={getattr(result, 'max_rel_error', None)}, " + f"error_msg={getattr(result, 'error_msg', None)}" + ) + return extracted_code except requests.exceptions.RequestException as e: