Skip to content

Commit e895c52

Browse files
limit the refined candidates based on the weighted ranking
1 parent a879f11 commit e895c52

File tree

5 files changed

+132
-58
lines changed

5 files changed

+132
-58
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,18 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
248248
"original_source_code": opt.original_source_code,
249249
"read_only_dependency_code": opt.read_only_dependency_code,
250250
"original_line_profiler_results": opt.original_line_profiler_results,
251-
"original_code_runtime": opt.original_code_runtime,
251+
"original_code_runtime": humanize_runtime(opt.original_code_runtime),
252252
"optimized_source_code": opt.optimized_source_code,
253253
"optimized_explanation": opt.optimized_explanation,
254254
"optimized_line_profiler_results": opt.optimized_line_profiler_results,
255-
"optimized_code_runtime": opt.optimized_code_runtime,
255+
"optimized_code_runtime": humanize_runtime(opt.optimized_code_runtime),
256256
"speedup": opt.speedup,
257257
"trace_id": opt.trace_id,
258258
"function_references": opt.function_references,
259259
"python_version": platform.python_version(),
260260
}
261261
for opt in request
262262
]
263-
logger.debug(f"Refining {len(request)} optimizations…")
264-
console.rule()
265263
try:
266264
response = self.make_ai_service_request("/refinement", payload=payload, timeout=120)
267265
except requests.exceptions.RequestException as e:
@@ -271,8 +269,6 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
271269

272270
if response.status_code == 200:
273271
refined_optimizations = response.json()["refinements"]
274-
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
275-
console.rule()
276272

277273
refinements = self._get_valid_candidates(refined_optimizations)
278274
return [

codeflash/code_utils/code_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,63 @@ def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tof
4141
return "".join(diff)
4242

4343

44+
def choose_weights(**importance: float) -> list[float]:
45+
"""Choose normalized weights from relative importance values.
46+
47+
Example:
48+
choose_weights(runtime=3, diff=1)
49+
-> [0.75, 0.25]
50+
51+
Args:
52+
**importance: keyword args of metric=importance (relative numbers).
53+
54+
Returns:
55+
A list of weights in the same order as the arguments.
56+
57+
"""
58+
total = sum(importance.values())
59+
if total == 0:
60+
raise ValueError("At least one importance value must be > 0")
61+
62+
return [v / total for v in importance.values()]
63+
64+
65+
def normalize(values: list[float]) -> list[float]:
66+
mn, mx = min(values), max(values)
67+
if mx == mn:
68+
return [0.0] * len(values)
69+
return [(v - mn) / (mx - mn) for v in values]
70+
71+
72+
def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, int]:
73+
"""Combine multiple metrics into a single weighted score dictionary.
74+
75+
Each metric is a list of values (smaller = better).
76+
The total score for each index is the weighted sum of its values
77+
across all metrics:
78+
79+
score[index] = Σ (value * weight)
80+
81+
Args:
82+
weights: A list of weights, one per metric. Larger weight = more influence.
83+
*metrics: Lists of values (one list per metric, aligned by index).
84+
85+
Returns:
86+
A dictionary mapping each index to its combined weighted score.
87+
88+
"""
89+
if len(weights) != len(metrics):
90+
raise ValueError("Number of weights must match number of metrics")
91+
92+
combined: dict[int, float] = {}
93+
94+
for weight, metric in zip(weights, metrics):
95+
for idx, value in enumerate(metric):
96+
combined[idx] = combined.get(idx, 0) + value * weight
97+
98+
return combined
99+
100+
44101
def diff_length(a: str, b: str) -> int:
45102
"""Compute the length (in characters) of the unified diff between two strings.
46103

codeflash/code_utils/config_consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1515
N_CANDIDATES_LP = 6
1616

17+
# Refinement
18+
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
19+
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2
20+
TOP_N_REFINEMENTS = 0.45 # top 45% of valid optimizations (based on the weighted score) are refined
21+
1722
# LSP-specific
1823
N_CANDIDATES_LSP = 3
1924
N_TESTS_TO_GENERATE_LSP = 2

codeflash/models/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ class AIServiceRefinerRequest:
3636
optimization_id: str
3737
original_source_code: str
3838
read_only_dependency_code: str
39-
original_code_runtime: str
39+
original_code_runtime: int
4040
optimized_source_code: str
4141
optimized_explanation: str
42-
optimized_code_runtime: str
42+
optimized_code_runtime: int
4343
speedup: str
4444
trace_id: str
4545
original_line_profiler_results: str

codeflash/optimization/function_optimizer.py

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@
2929
replace_function_definitions_in_module,
3030
)
3131
from codeflash.code_utils.code_utils import (
32+
choose_weights,
3233
cleanup_paths,
3334
create_rank_dictionary_compact,
35+
create_score_dictionary_from_metrics,
3436
diff_length,
3537
extract_unique_errors,
3638
file_name_from_test_module_name,
3739
get_run_tmp_file,
3840
module_name_from_file_path,
41+
normalize,
3942
restore_conftest,
4043
unified_diff_strings,
4144
)
@@ -45,7 +48,10 @@
4548
N_CANDIDATES_EFFECTIVE,
4649
N_CANDIDATES_LP_EFFECTIVE,
4750
N_TESTS_TO_GENERATE_EFFECTIVE,
51+
REFINE_ALL_THRESHOLD,
52+
REFINED_CANDIDATE_RANKING_WEIGHTS,
4853
REPEAT_OPTIMIZATION_PROBABILITY,
54+
TOP_N_REFINEMENTS,
4955
TOTAL_LOOPING_TIME_EFFECTIVE,
5056
)
5157
from codeflash.code_utils.deduplicate_code import normalize_code
@@ -124,19 +130,23 @@ def __init__(
124130
self,
125131
initial_candidates: list,
126132
future_line_profile_results: concurrent.futures.Future,
127-
future_all_refinements: list,
133+
all_refinements_data: list[AIServiceRefinerRequest],
134+
ai_service_client: AiServiceClient,
135+
executor: concurrent.futures.ThreadPoolExecutor,
128136
) -> None:
129137
self.candidate_queue = queue.Queue()
130138
self.line_profiler_done = False
131139
self.refinement_done = False
132140
self.candidate_len = len(initial_candidates)
141+
self.ai_service_client = ai_service_client
142+
self.executor = executor
133143

134144
# Initialize queue with initial candidates
135145
for candidate in initial_candidates:
136146
self.candidate_queue.put(candidate)
137147

138148
self.future_line_profile_results = future_line_profile_results
139-
self.future_all_refinements = future_all_refinements
149+
self.all_refinements_data = all_refinements_data
140150

141151
def get_next_candidate(self) -> OptimizedCandidate | None:
142152
"""Get the next candidate from the queue, handling async results as needed."""
@@ -168,15 +178,45 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:
168178

169179
return self.get_next_candidate()
170180

181+
def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
182+
return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)
183+
171184
def _process_refinement_results(self) -> OptimizedCandidate | None:
172-
"""Process refinement results and add to queue."""
173-
if self.future_all_refinements:
185+
"""Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
186+
future_refinements: list[concurrent.futures.Future] = []
187+
188+
if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
189+
for data in self.all_refinements_data:
190+
future_refinements.append(self.refine_optimizations([data])) # noqa: PERF401
191+
else:
192+
diff_lens_list = []
193+
runtimes_list = []
194+
for c in self.all_refinements_data:
195+
diff_lens_list.append(diff_length(c.original_source_code, c.optimized_source_code))
196+
runtimes_list.append(c.optimized_code_runtime)
197+
198+
runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
199+
weights = choose_weights(runtime=runtime_w, diff=diff_w)
200+
201+
runtime_norm = normalize(runtimes_list)
202+
diffs_norm = normalize(diff_lens_list)
203+
# the lower the better
204+
score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
205+
top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
206+
top_indecies = sorted(score_dict, key=score_dict.get)[:top_n_candidates]
207+
208+
for idx in top_indecies:
209+
data = self.all_refinements_data[idx]
210+
future_refinements.append(self.refine_optimizations([data]))
211+
212+
if future_refinements:
174213
logger.info("loading|Refining generated code for improved quality and performance...")
175-
concurrent.futures.wait(self.future_all_refinements)
214+
215+
concurrent.futures.wait(future_refinements)
176216
refinement_response = []
177217

178-
for future_refinement in self.future_all_refinements:
179-
possible_refinement = future_refinement.result()
218+
for f in future_refinements:
219+
possible_refinement = f.result()
180220
if len(possible_refinement) > 0:
181221
refinement_response.append(possible_refinement[0])
182222

@@ -684,15 +724,14 @@ def process_single_candidate(
684724
original_helper_code: dict[Path, str],
685725
file_path_to_helper_classes: dict[Path, set[str]],
686726
eval_ctx: CandidateEvaluationContext,
687-
future_all_refinements: list[concurrent.futures.Future],
688-
ai_service_client: AiServiceClient,
727+
all_refinements_data: list[AIServiceRefinerRequest],
689728
exp_type: str,
690729
function_references: str,
691730
) -> BestOptimization | None:
692731
"""Process a single optimization candidate.
693732
694733
Returns the BestOptimization if the candidate is successful, None otherwise.
695-
Updates eval_ctx with results and may append to future_all_refinements.
734+
Updates eval_ctx with results and may append to all_refinements_data.
696735
"""
697736
# Cleanup temp files
698737
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
@@ -787,14 +826,19 @@ def process_single_candidate(
787826

788827
# Queue refinement for non-refined candidates
789828
if not candidate.optimization_id.endswith("refi"):
790-
future_all_refinements.append(
791-
self.refine_optimizations(
792-
valid_optimizations=[best_optimization],
793-
original_code_baseline=original_code_baseline,
794-
code_context=code_context,
829+
all_refinements_data.append(
830+
AIServiceRefinerRequest(
831+
optimization_id=best_optimization.candidate.optimization_id,
832+
original_source_code=code_context.read_writable_code.markdown,
833+
read_only_dependency_code=code_context.read_only_context_code,
834+
original_code_runtime=original_code_baseline.runtime,
835+
optimized_source_code=best_optimization.candidate.source_code.markdown,
836+
optimized_explanation=best_optimization.candidate.explanation,
837+
optimized_code_runtime=best_optimization.runtime,
838+
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
795839
trace_id=self.get_trace_id(exp_type),
796-
ai_service_client=ai_service_client,
797-
executor=self.executor,
840+
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
841+
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
798842
function_references=function_references,
799843
)
800844
)
@@ -830,7 +874,7 @@ def determine_best_candidate(
830874

831875
# Initialize evaluation context and async tasks
832876
eval_ctx = CandidateEvaluationContext()
833-
future_all_refinements: list[concurrent.futures.Future] = []
877+
all_refinements_data: list[AIServiceRefinerRequest] = []
834878
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
835879
assert ai_service_client is not None, "AI service client must be set for optimization"
836880

@@ -848,7 +892,9 @@ def determine_best_candidate(
848892
else None,
849893
)
850894

851-
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
895+
processor = CandidateProcessor(
896+
candidates, future_line_profile_results, all_refinements_data, self.aiservice_client, self.executor
897+
)
852898
candidate_index = 0
853899

854900
# Process candidates using queue-based approach
@@ -869,8 +915,7 @@ def determine_best_candidate(
869915
original_helper_code=original_helper_code,
870916
file_path_to_helper_classes=file_path_to_helper_classes,
871917
eval_ctx=eval_ctx,
872-
future_all_refinements=future_all_refinements,
873-
ai_service_client=ai_service_client,
918+
all_refinements_data=all_refinements_data,
874919
exp_type=exp_type,
875920
function_references=function_references,
876921
)
@@ -903,35 +948,6 @@ def determine_best_candidate(
903948

904949
return best_optimization
905950

906-
def refine_optimizations(
907-
self,
908-
valid_optimizations: list[BestOptimization],
909-
original_code_baseline: OriginalCodeBaseline,
910-
code_context: CodeOptimizationContext,
911-
trace_id: str,
912-
ai_service_client: AiServiceClient,
913-
executor: concurrent.futures.ThreadPoolExecutor,
914-
function_references: str | None = None,
915-
) -> concurrent.futures.Future:
916-
request = [
917-
AIServiceRefinerRequest(
918-
optimization_id=opt.candidate.optimization_id,
919-
original_source_code=code_context.read_writable_code.markdown,
920-
read_only_dependency_code=code_context.read_only_context_code,
921-
original_code_runtime=humanize_runtime(original_code_baseline.runtime),
922-
optimized_source_code=opt.candidate.source_code.markdown,
923-
optimized_explanation=opt.candidate.explanation,
924-
optimized_code_runtime=humanize_runtime(opt.runtime),
925-
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%",
926-
trace_id=trace_id,
927-
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
928-
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
929-
function_references=function_references,
930-
)
931-
for opt in valid_optimizations
932-
]
933-
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
934-
935951
def log_successful_optimization(
936952
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
937953
) -> None:

0 commit comments

Comments
 (0)