2929 replace_function_definitions_in_module ,
3030)
3131from codeflash .code_utils .code_utils import (
32+ choose_weights ,
3233 cleanup_paths ,
3334 create_rank_dictionary_compact ,
35+ create_score_dictionary_from_metrics ,
3436 diff_length ,
3537 extract_unique_errors ,
3638 file_name_from_test_module_name ,
3739 get_run_tmp_file ,
3840 module_name_from_file_path ,
41+ normalize ,
3942 restore_conftest ,
4043 unified_diff_strings ,
4144)
4548 N_CANDIDATES_EFFECTIVE ,
4649 N_CANDIDATES_LP_EFFECTIVE ,
4750 N_TESTS_TO_GENERATE_EFFECTIVE ,
51+ REFINE_ALL_THRESHOLD ,
52+ REFINED_CANDIDATE_RANKING_WEIGHTS ,
4853 REPEAT_OPTIMIZATION_PROBABILITY ,
54+ TOP_N_REFINEMENTS ,
4955 TOTAL_LOOPING_TIME_EFFECTIVE ,
5056)
5157from codeflash .code_utils .deduplicate_code import normalize_code
def __init__(
    self,
    initial_candidates: list,
    future_line_profile_results: concurrent.futures.Future,
    all_refinements_data: list[AIServiceRefinerRequest],
    ai_service_client: AiServiceClient,
    executor: concurrent.futures.ThreadPoolExecutor,
) -> None:
    """Set up the candidate queue and the async bookkeeping for processing.

    Args:
        initial_candidates: Optimization candidates used to seed the queue.
        future_line_profile_results: In-flight future whose results are folded
            into the queue once line profiling completes.
        all_refinements_data: Shared list of AIServiceRefinerRequest entries;
            evaluation code appends to it and refinement consumes it later.
        ai_service_client: Client used to submit refinement requests.
        executor: Thread pool on which refinement calls are submitted.
    """
    self.candidate_queue = queue.Queue()
    # Flags tracking which async phases have already been drained.
    self.line_profiler_done = False
    self.refinement_done = False
    self.candidate_len = len(initial_candidates)
    self.ai_service_client = ai_service_client
    self.executor = executor

    # Initialize queue with initial candidates
    for candidate in initial_candidates:
        self.candidate_queue.put(candidate)

    self.future_line_profile_results = future_line_profile_results
    self.all_refinements_data = all_refinements_data
141151 def get_next_candidate (self ) -> OptimizedCandidate | None :
142152 """Get the next candidate from the queue, handling async results as needed."""
@@ -168,15 +178,45 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:
168178
169179 return self .get_next_candidate ()
170180
def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
    """Submit a refinement request to the AI service on the shared executor.

    Returns a Future; callers iterate its result as a sequence of refined
    candidates (see _process_refinement_results).
    """
    return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)
183+
171184 def _process_refinement_results (self ) -> OptimizedCandidate | None :
172- """Process refinement results and add to queue."""
173- if self .future_all_refinements :
185+ """Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
186+ future_refinements : list [concurrent .futures .Future ] = []
187+
188+ if len (self .all_refinements_data ) <= REFINE_ALL_THRESHOLD :
189+ for data in self .all_refinements_data :
190+ future_refinements .append (self .refine_optimizations ([data ])) # noqa: PERF401
191+ else :
192+ diff_lens_list = []
193+ runtimes_list = []
194+ for c in self .all_refinements_data :
195+ diff_lens_list .append (diff_length (c .original_source_code , c .optimized_source_code ))
196+ runtimes_list .append (c .optimized_code_runtime )
197+
198+ runtime_w , diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
199+ weights = choose_weights (runtime = runtime_w , diff = diff_w )
200+
201+ runtime_norm = normalize (runtimes_list )
202+ diffs_norm = normalize (diff_lens_list )
203+ # the lower the better
204+ score_dict = create_score_dictionary_from_metrics (weights , runtime_norm , diffs_norm )
205+ top_n_candidates = int ((TOP_N_REFINEMENTS * len (runtimes_list )) + 0.5 )
206+ top_indecies = sorted (score_dict , key = score_dict .get )[:top_n_candidates ]
207+
208+ for idx in top_indecies :
209+ data = self .all_refinements_data [idx ]
210+ future_refinements .append (self .refine_optimizations ([data ]))
211+
212+ if future_refinements :
174213 logger .info ("loading|Refining generated code for improved quality and performance..." )
175- concurrent .futures .wait (self .future_all_refinements )
214+
215+ concurrent .futures .wait (future_refinements )
176216 refinement_response = []
177217
178- for future_refinement in self . future_all_refinements :
179- possible_refinement = future_refinement .result ()
218+ for f in future_refinements :
219+ possible_refinement = f .result ()
180220 if len (possible_refinement ) > 0 :
181221 refinement_response .append (possible_refinement [0 ])
182222
@@ -684,15 +724,14 @@ def process_single_candidate(
684724 original_helper_code : dict [Path , str ],
685725 file_path_to_helper_classes : dict [Path , set [str ]],
686726 eval_ctx : CandidateEvaluationContext ,
687- future_all_refinements : list [concurrent .futures .Future ],
688- ai_service_client : AiServiceClient ,
727+ all_refinements_data : list [AIServiceRefinerRequest ],
689728 exp_type : str ,
690729 function_references : str ,
691730 ) -> BestOptimization | None :
692731 """Process a single optimization candidate.
693732
694733 Returns the BestOptimization if the candidate is successful, None otherwise.
695- Updates eval_ctx with results and may append to future_all_refinements .
734+ Updates eval_ctx with results and may append to all_refinements_data .
696735 """
697736 # Cleanup temp files
698737 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
@@ -787,14 +826,19 @@ def process_single_candidate(
787826
788827 # Queue refinement for non-refined candidates
789828 if not candidate .optimization_id .endswith ("refi" ):
790- future_all_refinements .append (
791- self .refine_optimizations (
792- valid_optimizations = [best_optimization ],
793- original_code_baseline = original_code_baseline ,
794- code_context = code_context ,
829+ all_refinements_data .append (
830+ AIServiceRefinerRequest (
831+ optimization_id = best_optimization .candidate .optimization_id ,
832+ original_source_code = code_context .read_writable_code .markdown ,
833+ read_only_dependency_code = code_context .read_only_context_code ,
834+ original_code_runtime = original_code_baseline .runtime ,
835+ optimized_source_code = best_optimization .candidate .source_code .markdown ,
836+ optimized_explanation = best_optimization .candidate .explanation ,
837+ optimized_code_runtime = best_optimization .runtime ,
838+ speedup = f"{ int (performance_gain (original_runtime_ns = original_code_baseline .runtime , optimized_runtime_ns = best_optimization .runtime ) * 100 )} %" ,
795839 trace_id = self .get_trace_id (exp_type ),
796- ai_service_client = ai_service_client ,
797- executor = self . executor ,
840+ original_line_profiler_results = original_code_baseline . line_profile_results [ "str_out" ] ,
841+ optimized_line_profiler_results = best_optimization . line_profiler_test_results [ "str_out" ] ,
798842 function_references = function_references ,
799843 )
800844 )
@@ -830,7 +874,7 @@ def determine_best_candidate(
830874
831875 # Initialize evaluation context and async tasks
832876 eval_ctx = CandidateEvaluationContext ()
833- future_all_refinements : list [concurrent . futures . Future ] = []
877+ all_refinements_data : list [AIServiceRefinerRequest ] = []
834878 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
835879 assert ai_service_client is not None , "AI service client must be set for optimization"
836880
@@ -848,7 +892,9 @@ def determine_best_candidate(
848892 else None ,
849893 )
850894
851- processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
# Use the experiment-specific client selected above (EXP0 -> aiservice_client,
# otherwise local_aiservice_client). Passing self.aiservice_client directly
# would silently ignore the exp_type selection that was just asserted non-None.
processor = CandidateProcessor(
    candidates, future_line_profile_results, all_refinements_data, ai_service_client, self.executor
)
852898 candidate_index = 0
853899
854900 # Process candidates using queue-based approach
@@ -869,8 +915,7 @@ def determine_best_candidate(
869915 original_helper_code = original_helper_code ,
870916 file_path_to_helper_classes = file_path_to_helper_classes ,
871917 eval_ctx = eval_ctx ,
872- future_all_refinements = future_all_refinements ,
873- ai_service_client = ai_service_client ,
918+ all_refinements_data = all_refinements_data ,
874919 exp_type = exp_type ,
875920 function_references = function_references ,
876921 )
@@ -903,35 +948,6 @@ def determine_best_candidate(
903948
904949 return best_optimization
905950
906- def refine_optimizations (
907- self ,
908- valid_optimizations : list [BestOptimization ],
909- original_code_baseline : OriginalCodeBaseline ,
910- code_context : CodeOptimizationContext ,
911- trace_id : str ,
912- ai_service_client : AiServiceClient ,
913- executor : concurrent .futures .ThreadPoolExecutor ,
914- function_references : str | None = None ,
915- ) -> concurrent .futures .Future :
916- request = [
917- AIServiceRefinerRequest (
918- optimization_id = opt .candidate .optimization_id ,
919- original_source_code = code_context .read_writable_code .markdown ,
920- read_only_dependency_code = code_context .read_only_context_code ,
921- original_code_runtime = humanize_runtime (original_code_baseline .runtime ),
922- optimized_source_code = opt .candidate .source_code .markdown ,
923- optimized_explanation = opt .candidate .explanation ,
924- optimized_code_runtime = humanize_runtime (opt .runtime ),
925- speedup = f"{ int (performance_gain (original_runtime_ns = original_code_baseline .runtime , optimized_runtime_ns = opt .runtime ) * 100 )} %" ,
926- trace_id = trace_id ,
927- original_line_profiler_results = original_code_baseline .line_profile_results ["str_out" ],
928- optimized_line_profiler_results = opt .line_profiler_test_results ["str_out" ],
929- function_references = function_references ,
930- )
931- for opt in valid_optimizations
932- ]
933- return executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
934-
935951 def log_successful_optimization (
936952 self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str
937953 ) -> None :
0 commit comments