2929 replace_function_definitions_in_module ,
3030)
3131from codeflash .code_utils .code_utils import (
32+ choose_weights ,
3233 cleanup_paths ,
3334 create_rank_dictionary_compact ,
35+ create_score_dictionary_from_metrics ,
3436 diff_length ,
3537 extract_unique_errors ,
3638 file_name_from_test_module_name ,
3739 get_run_tmp_file ,
3840 module_name_from_file_path ,
41+ normalize_by_max ,
3942 restore_conftest ,
4043 unified_diff_strings ,
4144)
4548 N_CANDIDATES_EFFECTIVE ,
4649 N_CANDIDATES_LP_EFFECTIVE ,
4750 N_TESTS_TO_GENERATE_EFFECTIVE ,
51+ REFINE_ALL_THRESHOLD ,
52+ REFINED_CANDIDATE_RANKING_WEIGHTS ,
4853 REPEAT_OPTIMIZATION_PROBABILITY ,
54+ TOP_N_REFINEMENTS ,
4955 TOTAL_LOOPING_TIME_EFFECTIVE ,
5056)
5157from codeflash .code_utils .deduplicate_code import normalize_code
def __init__(
    self,
    initial_candidates: list,
    future_line_profile_results: concurrent.futures.Future,
    all_refinements_data: "list[AIServiceRefinerRequest]",
    ai_service_client: "AiServiceClient",
    executor: concurrent.futures.ThreadPoolExecutor,
) -> None:
    """Set up the candidate-processing state and seed the work queue.

    Args:
        initial_candidates: Candidates to enqueue for immediate processing.
        future_line_profile_results: Pending line-profiler results to fold in later.
        all_refinements_data: Shared list that accumulates refinement requests.
        ai_service_client: Client used to submit refinement requests.
        executor: Thread pool used for asynchronous refinement submissions.
    """
    self.candidate_queue = queue.Queue()
    self.line_profiler_done = False
    self.refinement_done = False
    self.candidate_len = len(initial_candidates)
    self.ai_service_client = ai_service_client
    self.executor = executor
    self.future_line_profile_results = future_line_profile_results
    self.all_refinements_data = all_refinements_data
    # Seed the queue with the initial batch of candidates.
    for candidate in initial_candidates:
        self.candidate_queue.put(candidate)
141151 def get_next_candidate (self ) -> OptimizedCandidate | None :
142152 """Get the next candidate from the queue, handling async results as needed."""
@@ -168,15 +178,45 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:
168178
169179 return self .get_next_candidate ()
170180
def refine_optimizations(self, request: "list[AIServiceRefinerRequest]") -> concurrent.futures.Future:
    """Submit a batch of refinement requests to the AI service asynchronously.

    The call is dispatched on the shared executor so the caller can keep
    processing candidates while the service responds.

    Args:
        request: Refinement requests to send to the AI service.

    Returns:
        A Future resolving to the AI service's refinement response.
    """
    refine = self.ai_service_client.optimize_python_code_refinement
    return self.executor.submit(refine, request=request)
183+
171184 def _process_refinement_results (self ) -> OptimizedCandidate | None :
172- """Process refinement results and add to queue."""
173- if self .future_all_refinements :
185+ """Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
186+ future_refinements : list [concurrent .futures .Future ] = []
187+
188+ if len (self .all_refinements_data ) <= REFINE_ALL_THRESHOLD :
189+ for data in self .all_refinements_data :
190+ future_refinements .append (self .refine_optimizations ([data ])) # noqa: PERF401
191+ else :
192+ diff_lens_list = []
193+ runtimes_list = []
194+ for c in self .all_refinements_data :
195+ diff_lens_list .append (diff_length (c .original_source_code , c .optimized_source_code ))
196+ runtimes_list .append (c .optimized_code_runtime )
197+
198+ runtime_w , diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
199+ weights = choose_weights (runtime = runtime_w , diff = diff_w )
200+
201+ runtime_norm = normalize_by_max (runtimes_list )
202+ diffs_norm = normalize_by_max (diff_lens_list )
203+ # the lower the better
204+ score_dict = create_score_dictionary_from_metrics (weights , runtime_norm , diffs_norm )
205+ top_n_candidates = int ((TOP_N_REFINEMENTS * len (runtimes_list )) + 0.5 )
206+ top_indecies = sorted (score_dict , key = score_dict .get )[:top_n_candidates ]
207+
208+ for idx in top_indecies :
209+ data = self .all_refinements_data [idx ]
210+ future_refinements .append (self .refine_optimizations ([data ]))
211+
212+ if future_refinements :
174213 logger .info ("loading|Refining generated code for improved quality and performance..." )
175- concurrent .futures .wait (self .future_all_refinements )
214+
215+ concurrent .futures .wait (future_refinements )
176216 refinement_response = []
177217
178- for future_refinement in self . future_all_refinements :
179- possible_refinement = future_refinement .result ()
218+ for f in future_refinements :
219+ possible_refinement = f .result ()
180220 if len (possible_refinement ) > 0 :
181221 refinement_response .append (possible_refinement [0 ])
182222
@@ -686,15 +726,14 @@ def process_single_candidate(
686726 original_helper_code : dict [Path , str ],
687727 file_path_to_helper_classes : dict [Path , set [str ]],
688728 eval_ctx : CandidateEvaluationContext ,
689- future_all_refinements : list [concurrent .futures .Future ],
690- ai_service_client : AiServiceClient ,
729+ all_refinements_data : list [AIServiceRefinerRequest ],
691730 exp_type : str ,
692731 function_references : str ,
693732 ) -> BestOptimization | None :
694733 """Process a single optimization candidate.
695734
696735 Returns the BestOptimization if the candidate is successful, None otherwise.
697- Updates eval_ctx with results and may append to future_all_refinements .
736+ Updates eval_ctx with results and may append to all_refinements_data .
698737 """
699738 # Cleanup temp files
700739 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
@@ -789,14 +828,19 @@ def process_single_candidate(
789828
790829 # Queue refinement for non-refined candidates
791830 if not candidate .optimization_id .endswith ("refi" ):
792- future_all_refinements .append (
793- self .refine_optimizations (
794- valid_optimizations = [best_optimization ],
795- original_code_baseline = original_code_baseline ,
796- code_context = code_context ,
831+ all_refinements_data .append (
832+ AIServiceRefinerRequest (
833+ optimization_id = best_optimization .candidate .optimization_id ,
834+ original_source_code = code_context .read_writable_code .markdown ,
835+ read_only_dependency_code = code_context .read_only_context_code ,
836+ original_code_runtime = original_code_baseline .runtime ,
837+ optimized_source_code = best_optimization .candidate .source_code .markdown ,
838+ optimized_explanation = best_optimization .candidate .explanation ,
839+ optimized_code_runtime = best_optimization .runtime ,
840+ speedup = f"{ int (performance_gain (original_runtime_ns = original_code_baseline .runtime , optimized_runtime_ns = best_optimization .runtime ) * 100 )} %" ,
797841 trace_id = self .get_trace_id (exp_type ),
798- ai_service_client = ai_service_client ,
799- executor = self . executor ,
842+ original_line_profiler_results = original_code_baseline . line_profile_results [ "str_out" ] ,
843+ optimized_line_profiler_results = best_optimization . line_profiler_test_results [ "str_out" ] ,
800844 function_references = function_references ,
801845 )
802846 )
@@ -832,7 +876,7 @@ def determine_best_candidate(
832876
833877 # Initialize evaluation context and async tasks
834878 eval_ctx = CandidateEvaluationContext ()
835- future_all_refinements : list [concurrent . futures . Future ] = []
879+ all_refinements_data : list [AIServiceRefinerRequest ] = []
836880 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
837881 assert ai_service_client is not None , "AI service client must be set for optimization"
838882
@@ -850,7 +894,9 @@ def determine_best_candidate(
850894 else None ,
851895 )
852896
853- processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
897+ processor = CandidateProcessor (
898+ candidates , future_line_profile_results , all_refinements_data , self .aiservice_client , self .executor
899+ )
854900 candidate_index = 0
855901
856902 # Process candidates using queue-based approach
@@ -871,8 +917,7 @@ def determine_best_candidate(
871917 original_helper_code = original_helper_code ,
872918 file_path_to_helper_classes = file_path_to_helper_classes ,
873919 eval_ctx = eval_ctx ,
874- future_all_refinements = future_all_refinements ,
875- ai_service_client = ai_service_client ,
920+ all_refinements_data = all_refinements_data ,
876921 exp_type = exp_type ,
877922 function_references = function_references ,
878923 )
@@ -905,35 +950,6 @@ def determine_best_candidate(
905950
906951 return best_optimization
907952
908- def refine_optimizations (
909- self ,
910- valid_optimizations : list [BestOptimization ],
911- original_code_baseline : OriginalCodeBaseline ,
912- code_context : CodeOptimizationContext ,
913- trace_id : str ,
914- ai_service_client : AiServiceClient ,
915- executor : concurrent .futures .ThreadPoolExecutor ,
916- function_references : str | None = None ,
917- ) -> concurrent .futures .Future :
918- request = [
919- AIServiceRefinerRequest (
920- optimization_id = opt .candidate .optimization_id ,
921- original_source_code = code_context .read_writable_code .markdown ,
922- read_only_dependency_code = code_context .read_only_context_code ,
923- original_code_runtime = humanize_runtime (original_code_baseline .runtime ),
924- optimized_source_code = opt .candidate .source_code .markdown ,
925- optimized_explanation = opt .candidate .explanation ,
926- optimized_code_runtime = humanize_runtime (opt .runtime ),
927- speedup = f"{ int (performance_gain (original_runtime_ns = original_code_baseline .runtime , optimized_runtime_ns = opt .runtime ) * 100 )} %" ,
928- trace_id = trace_id ,
929- original_line_profiler_results = original_code_baseline .line_profile_results ["str_out" ],
930- optimized_line_profiler_results = opt .line_profiler_test_results ["str_out" ],
931- function_references = function_references ,
932- )
933- for opt in valid_optimizations
934- ]
935- return executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
936-
937953 def log_successful_optimization (
938954 self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str
939955 ) -> None :
0 commit comments