Skip to content

Commit 92870b2

Browse files
authored
Merge pull request #945 from codeflash-ai/feat/feedback-loop-for-unmatched-test-results
[FEAT] Code-repair for candidates with unmatched test results
2 parents 09fa96c + 194ded5 commit 92870b2

14 files changed

+911
-88
lines changed

codeflash/api/aiservice.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from codeflash.code_utils.time_utils import humanize_runtime
1919
from codeflash.lsp.helpers import is_LSP_enabled
2020
from codeflash.models.ExperimentMetadata import ExperimentMetadata
21-
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
21+
from codeflash.models.models import (
22+
AIServiceRefinerRequest,
23+
CodeStringsMarkdown,
24+
OptimizedCandidate,
25+
OptimizedCandidateSource,
26+
)
2227
from codeflash.telemetry.posthog_cf import ph
2328
from codeflash.version import __version__ as codeflash_version
2429

@@ -27,7 +32,7 @@
2732

2833
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2934
from codeflash.models.ExperimentMetadata import ExperimentMetadata
30-
from codeflash.models.models import AIServiceRefinerRequest
35+
from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
3136
from codeflash.result.explanation import Explanation
3237

3338

@@ -86,15 +91,21 @@ def make_ai_service_request(
8691
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8792
return response
8893

89-
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
94+
def _get_valid_candidates(
95+
self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource
96+
) -> list[OptimizedCandidate]:
9097
candidates: list[OptimizedCandidate] = []
9198
for opt in optimizations_json:
9299
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
93100
if not code.code_strings:
94101
continue
95102
candidates.append(
96103
OptimizedCandidate(
97-
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
104+
source_code=code,
105+
explanation=opt["explanation"],
106+
optimization_id=opt["optimization_id"],
107+
source=source,
108+
parent_id=opt.get("parent_id", None),
98109
)
99110
)
100111
return candidates
@@ -156,7 +167,7 @@ def optimize_python_code( # noqa: D417
156167
console.rule()
157168
end_time = time.perf_counter()
158169
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
159-
return self._get_valid_candidates(optimizations_json)
170+
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
160171
try:
161172
error = response.json()["error"]
162173
except Exception:
@@ -221,7 +232,7 @@ def optimize_python_code_line_profiler( # noqa: D417
221232
f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information."
222233
)
223234
console.rule()
224-
return self._get_valid_candidates(optimizations_json)
235+
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
225236
try:
226237
error = response.json()["error"]
227238
except Exception:
@@ -270,15 +281,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
270281
if response.status_code == 200:
271282
refined_optimizations = response.json()["refinements"]
272283

273-
refinements = self._get_valid_candidates(refined_optimizations)
274-
return [
275-
OptimizedCandidate(
276-
source_code=c.source_code,
277-
explanation=c.explanation,
278-
optimization_id=c.optimization_id[:-4] + "refi",
279-
)
280-
for c in refinements
281-
]
284+
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
282285

283286
try:
284287
error = response.json()["error"]
@@ -289,6 +292,52 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
289292
console.rule()
290293
return []
291294

295+
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
296+
"""Repair the optimization candidate that is not matching the test result of the original code.
297+
298+
Args:
299+
request: candidate details for repair
300+
301+
Returns:
302+
-------
303+
- OptimizedCandidate: new fixed candidate.
304+
305+
"""
306+
console.rule()
307+
try:
308+
payload = {
309+
"optimization_id": request.optimization_id,
310+
"original_source_code": request.original_source_code,
311+
"modified_source_code": request.modified_source_code,
312+
"trace_id": request.trace_id,
313+
"test_diffs": request.test_diffs,
314+
}
315+
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
316+
except (requests.exceptions.RequestException, TypeError) as e:
317+
logger.exception(f"Error generating optimization repair: {e}")
318+
ph("cli-optimize-error-caught", {"error": str(e)})
319+
return None
320+
321+
if response.status_code == 200:
322+
fixed_optimization = response.json()
323+
console.rule()
324+
325+
valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR)
326+
if not valid_candidates:
327+
logger.error("Code repair failed to generate a valid candidate.")
328+
return None
329+
330+
return valid_candidates[0]
331+
332+
try:
333+
error = response.json()["error"]
334+
except Exception:
335+
error = response.text
336+
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
337+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
338+
console.rule()
339+
return None
340+
292341
def get_new_explanation( # noqa: D417
293342
self,
294343
source_code: str,

codeflash/code_utils/code_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
423423
sys.exit(1 if error_on_exit else 0)
424424

425425

426+
def shorten_pytest_error(pytest_error_string: str) -> str:
427+
return "\n".join(re.findall(r"^[E>] +(.*)$", pytest_error_string, re.MULTILINE))
428+
429+
426430
def extract_unique_errors(pytest_output: str) -> set[str]:
427431
unique_errors = set()
428432

codeflash/code_utils/config_consts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
TOTAL_LOOPING_TIME_LSP = 10.0 # Kept same timing for LSP mode to avoid in increase in performance reporting
2626
N_CANDIDATES_LP_LSP = 3
2727

28+
# Code repair
29+
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted)
30+
MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function
31+
2832
MAX_N_CANDIDATES = 5
2933
MAX_N_CANDIDATES_LP = 6
3034

codeflash/models/models.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import Counter, defaultdict
44
from typing import TYPE_CHECKING
55

6+
import libcst as cst
67
from rich.tree import Tree
78

89
from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
@@ -47,6 +48,34 @@ class AIServiceRefinerRequest:
4748
function_references: str | None = None
4849

4950

51+
class TestDiffScope(str, Enum):
52+
RETURN_VALUE = "return_value"
53+
STDOUT = "stdout"
54+
DID_PASS = "did_pass" # noqa: S105
55+
56+
57+
@dataclass
58+
class TestDiff:
59+
scope: TestDiffScope
60+
original_pass: bool
61+
candidate_pass: bool
62+
63+
original_value: str | None = None
64+
candidate_value: str | None = None
65+
test_src_code: Optional[str] = None
66+
candidate_pytest_error: Optional[str] = None
67+
original_pytest_error: Optional[str] = None
68+
69+
70+
@dataclass(frozen=True)
71+
class AIServiceCodeRepairRequest:
72+
optimization_id: str
73+
original_source_code: str
74+
modified_source_code: str
75+
trace_id: str
76+
test_diffs: list[TestDiff]
77+
78+
5079
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
5180
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
5281
# of the module is foo.eggs.
@@ -243,12 +272,12 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
243272
244273
"""
245274
matches = markdown_pattern.findall(markdown_code)
246-
results = CodeStringsMarkdown()
275+
code_string_list = []
247276
try:
248277
for file_path, code in matches:
249278
path = file_path.strip()
250-
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
251-
return results # noqa: TRY300
279+
code_string_list.append(CodeString(code=code, file_path=Path(path)))
280+
return CodeStringsMarkdown(code_strings=code_string_list)
252281
except ValidationError:
253282
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
254283
return CodeStringsMarkdown()
@@ -421,11 +450,20 @@ class TestsInFile:
421450
test_type: TestType
422451

423452

453+
class OptimizedCandidateSource(str, Enum):
454+
OPTIMIZE = "OPTIMIZE"
455+
OPTIMIZE_LP = "OPTIMIZE_LP"
456+
REFINE = "REFINE"
457+
REPAIR = "REPAIR"
458+
459+
424460
@dataclass(frozen=True)
425461
class OptimizedCandidate:
426462
source_code: CodeStringsMarkdown
427463
explanation: str
428464
optimization_id: str
465+
source: OptimizedCandidateSource
466+
parent_id: str | None = None
429467

430468

431469
@dataclass(frozen=True)
@@ -572,6 +610,42 @@ def id(self) -> str:
572610
f"{self.function_getting_tested}:{self.iteration_id}"
573611
)
574612

613+
# TestSuiteClass.test_function_name
614+
def test_fn_qualified_name(self) -> str:
615+
# Use f-string with inline conditional to reduce string concatenation operations
616+
return (
617+
f"{self.test_class_name}.{self.test_function_name}"
618+
if self.test_class_name
619+
else str(self.test_function_name)
620+
)
621+
622+
def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
623+
for stmt in class_node.body.body:
624+
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
625+
return stmt
626+
return None
627+
628+
def get_src_code(self, test_path: Path) -> Optional[str]:
629+
if not test_path.exists():
630+
return None
631+
test_src = test_path.read_text(encoding="utf-8")
632+
module_node = cst.parse_module(test_src)
633+
634+
if self.test_class_name:
635+
for stmt in module_node.body:
636+
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
637+
func_node = self.find_func_in_class(stmt, self.test_function_name)
638+
if func_node:
639+
return module_node.code_for_node(func_node).strip()
640+
# class not found
641+
return None
642+
643+
# Otherwise, look for a top level function
644+
for stmt in module_node.body:
645+
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name:
646+
return module_node.code_for_node(stmt).strip()
647+
return None
648+
575649
@staticmethod
576650
def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
577651
components = string_id.split(":")
@@ -616,7 +690,10 @@ class TestResults(BaseModel): # noqa: PLW1641
616690
# also we don't support deletion of test results elements - caution is advised
617691
test_results: list[FunctionTestInvocation] = []
618692
test_result_idx: dict[str, int] = {}
693+
619694
perf_stdout: Optional[str] = None
695+
# mapping between test function name and stdout failure message
696+
test_failures: Optional[dict[str, str]] = None
620697

621698
def add(self, function_test_invocation: FunctionTestInvocation) -> None:
622699
unique_id = function_test_invocation.unique_invocation_loop_id

0 commit comments

Comments
 (0)