1818from codeflash .code_utils .time_utils import humanize_runtime
1919from codeflash .lsp .helpers import is_LSP_enabled
2020from codeflash .models .ExperimentMetadata import ExperimentMetadata
21- from codeflash .models .models import AIServiceRefinerRequest , CodeStringsMarkdown , OptimizedCandidate
21+ from codeflash .models .models import (
22+ AIServiceRefinerRequest ,
23+ CodeStringsMarkdown ,
24+ OptimizedCandidate ,
25+ OptimizedCandidateSource ,
26+ )
2227from codeflash .telemetry .posthog_cf import ph
2328from codeflash .version import __version__ as codeflash_version
2429
2732
2833 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
2934 from codeflash .models .ExperimentMetadata import ExperimentMetadata
30- from codeflash .models .models import AIServiceRefinerRequest
35+ from codeflash .models .models import AIServiceCodeRepairRequest , AIServiceRefinerRequest
3136 from codeflash .result .explanation import Explanation
3237
3338
@@ -86,15 +91,21 @@ def make_ai_service_request(
8691 # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8792 return response
8893
89- def _get_valid_candidates (self , optimizations_json : list [dict [str , Any ]]) -> list [OptimizedCandidate ]:
94+ def _get_valid_candidates (
95+ self , optimizations_json : list [dict [str , Any ]], source : OptimizedCandidateSource
96+ ) -> list [OptimizedCandidate ]:
9097 candidates : list [OptimizedCandidate ] = []
9198 for opt in optimizations_json :
9299 code = CodeStringsMarkdown .parse_markdown_code (opt ["source_code" ])
93100 if not code .code_strings :
94101 continue
95102 candidates .append (
96103 OptimizedCandidate (
97- source_code = code , explanation = opt ["explanation" ], optimization_id = opt ["optimization_id" ]
104+ source_code = code ,
105+ explanation = opt ["explanation" ],
106+ optimization_id = opt ["optimization_id" ],
107+ source = source ,
108+ parent_id = opt .get ("parent_id" , None ),
98109 )
99110 )
100111 return candidates
@@ -156,7 +167,7 @@ def optimize_python_code( # noqa: D417
156167 console .rule ()
157168 end_time = time .perf_counter ()
158169 logger .debug (f"!lsp|Generating possible optimizations took { end_time - start_time :.2f} seconds." )
159- return self ._get_valid_candidates (optimizations_json )
170+ return self ._get_valid_candidates (optimizations_json , OptimizedCandidateSource . OPTIMIZE )
160171 try :
161172 error = response .json ()["error" ]
162173 except Exception :
@@ -221,7 +232,7 @@ def optimize_python_code_line_profiler( # noqa: D417
221232 f"!lsp|Generated { len (optimizations_json )} candidate optimizations using line profiler information."
222233 )
223234 console .rule ()
224- return self ._get_valid_candidates (optimizations_json )
235+ return self ._get_valid_candidates (optimizations_json , OptimizedCandidateSource . OPTIMIZE_LP )
225236 try :
226237 error = response .json ()["error" ]
227238 except Exception :
@@ -270,15 +281,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
270281 if response .status_code == 200 :
271282 refined_optimizations = response .json ()["refinements" ]
272283
273- refinements = self ._get_valid_candidates (refined_optimizations )
274- return [
275- OptimizedCandidate (
276- source_code = c .source_code ,
277- explanation = c .explanation ,
278- optimization_id = c .optimization_id [:- 4 ] + "refi" ,
279- )
280- for c in refinements
281- ]
284+ return self ._get_valid_candidates (refined_optimizations , OptimizedCandidateSource .REFINE )
282285
283286 try :
284287 error = response .json ()["error" ]
@@ -289,6 +292,52 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
289292 console .rule ()
290293 return []
291294
295+ def code_repair (self , request : AIServiceCodeRepairRequest ) -> OptimizedCandidate | None :
296+ """Repair the optimization candidate that is not matching the test result of the original code.
297+
298+ Args:
299+ request: candidate details for repair
300+
301+ Returns:
302+ -------
303+ - OptimizedCandidate: new fixed candidate.
304+
305+ """
306+ console .rule ()
307+ try :
308+ payload = {
309+ "optimization_id" : request .optimization_id ,
310+ "original_source_code" : request .original_source_code ,
311+ "modified_source_code" : request .modified_source_code ,
312+ "trace_id" : request .trace_id ,
313+ "test_diffs" : request .test_diffs ,
314+ }
315+ response = self .make_ai_service_request ("/code_repair" , payload = payload , timeout = 120 )
316+ except (requests .exceptions .RequestException , TypeError ) as e :
317+ logger .exception (f"Error generating optimization repair: { e } " )
318+ ph ("cli-optimize-error-caught" , {"error" : str (e )})
319+ return None
320+
321+ if response .status_code == 200 :
322+ fixed_optimization = response .json ()
323+ console .rule ()
324+
325+ valid_candidates = self ._get_valid_candidates ([fixed_optimization ], OptimizedCandidateSource .REPAIR )
326+ if not valid_candidates :
327+ logger .error ("Code repair failed to generate a valid candidate." )
328+ return None
329+
330+ return valid_candidates [0 ]
331+
332+ try :
333+ error = response .json ()["error" ]
334+ except Exception :
335+ error = response .text
336+ logger .error (f"Error generating optimized candidates: { response .status_code } - { error } " )
337+ ph ("cli-optimize-error-response" , {"response_status_code" : response .status_code , "error" : error })
338+ console .rule ()
339+ return None
340+
292341 def get_new_explanation ( # noqa: D417
293342 self ,
294343 source_code : str ,
0 commit comments