Skip to content

Refinement #555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7b47a37
WIP refiner
misrasaurabh1 Jul 15, 2025
9da396b
some fixes
misrasaurabh1 Jul 15, 2025
9627f73
change url
aseembits93 Jul 15, 2025
9826b00
fixes
misrasaurabh1 Jul 15, 2025
50cf370
diff format not working yet
aseembits93 Jul 16, 2025
58e44d3
get some heuristic working for best optimization
aseembits93 Jul 16, 2025
77ed5c8
working dirty implementation of ranked choice voting for finding best…
aseembits93 Jul 16, 2025
5be61da
First working version of the refiner
misrasaurabh1 Jul 17, 2025
7e4ba6e
merge
misrasaurabh1 Jul 17, 2025
65d2971
add RO context
misrasaurabh1 Jul 17, 2025
e77da5c
bugfix
aseembits93 Jul 17, 2025
19cd5c8
bugfix
aseembits93 Jul 17, 2025
5aab3b8
bugfix
aseembits93 Jul 17, 2025
ed6b5b1
send tracked refinement optimization data
mohammedahmed18 Jul 18, 2025
1e7a7cb
marker for refinement pr
aseembits93 Jul 21, 2025
3eedbd2
refi optimization ids and original optimization ids
mohammedahmed18 Jul 22, 2025
3562394
Merge branch 'refinement' of github.com:codeflash-ai/codeflash into r…
mohammedahmed18 Jul 22, 2025
42f0ada
send the best optimization id only - not the whole metadata object
mohammedahmed18 Jul 23, 2025
e964ca6
works now, todo tiebreaking for same ranks
aseembits93 Jul 25, 2025
beb1ee0
precommit mypy fix
aseembits93 Jul 25, 2025
c4cf495
Merge remote-tracking branch 'origin/main' into refinement
aseembits93 Jul 25, 2025
c0b85ad
cleaning up
aseembits93 Jul 26, 2025
ef80323
further streamlining
aseembits93 Jul 26, 2025
65d766d
bugfix
aseembits93 Jul 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.models.models import OptimizedCandidate
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
from codeflash.telemetry.posthog_cf import ph
from codeflash.version import __version__ as codeflash_version

Expand All @@ -21,6 +22,7 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import AIServiceRefinerRequest


class AiServiceClient:
Expand All @@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
return "https://app.codeflash.ai"

def make_ai_service_request(
self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
self,
endpoint: str,
method: str = "POST",
payload: dict[str, Any] | list[dict[str, Any]] | None = None,
timeout: float | None = None,
) -> requests.Response:
"""Make an API request to the given endpoint on the AI service.

Expand Down Expand Up @@ -98,11 +104,7 @@ def optimize_python_code( # noqa: D417

"""
start_time = time.perf_counter()
try:
git_repo_owner, git_repo_name = get_repo_owner_and_name()
except Exception as e:
logger.warning(f"Could not determine repo owner and name: {e}")
git_repo_owner, git_repo_name = None, None
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()

payload = {
"source_code": source_code,
Expand Down Expand Up @@ -219,13 +221,71 @@ def optimize_python_code_line_profiler( # noqa: D417
console.rule()
return []

def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance by making a request to the Django endpoint.

Args:
request: A list of optimization candidate details for refinement

Returns:
-------
- List[OptimizationCandidate]: A list of Optimization Candidates.

"""
payload = [
{
"optimization_id": opt.optimization_id,
"original_source_code": opt.original_source_code,
"read_only_dependency_code": opt.read_only_dependency_code,
"original_line_profiler_results": opt.original_line_profiler_results,
"original_code_runtime": opt.original_code_runtime,
"optimized_source_code": opt.optimized_source_code,
"optimized_explanation": opt.optimized_explanation,
"optimized_line_profiler_results": opt.optimized_line_profiler_results,
"optimized_code_runtime": opt.optimized_code_runtime,
"speedup": opt.speedup,
"trace_id": opt.trace_id,
}
for opt in request
]
logger.info(f"Refining {len(request)} optimizations…")
console.rule()
try:
response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating optimization refinements: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
return []

if response.status_code == 200:
refined_optimizations = response.json()["refinements"]
logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
console.rule()
return [
OptimizedCandidate(
source_code=opt["source_code"],
explanation=opt["explanation"],
optimization_id=opt["optimization_id"][:-4] + "refi",
)
for opt in refined_optimizations
]
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
console.rule()
return []

def log_results( # noqa: D417
self,
function_trace_id: str,
speedup_ratio: dict[str, float | None] | None,
original_runtime: float | None,
optimized_runtime: dict[str, float | None] | None,
is_correct: dict[str, bool] | None,
optimized_line_profiler_results: dict[str, str] | None,
) -> None:
"""Log features to the database.

Expand All @@ -236,6 +296,7 @@ def log_results( # noqa: D417
- original_runtime (Optional[Dict[str, float]]): The original runtime.
- optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
-optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id

"""
payload = {
Expand All @@ -245,6 +306,7 @@ def log_results( # noqa: D417
"optimized_runtime": optimized_runtime,
"is_correct": is_correct,
"codeflash_version": codeflash_version,
"optimized_line_profiler_results": optimized_line_profiler_results,
}
try:
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
Expand Down Expand Up @@ -331,3 +393,12 @@ class LocalAiServiceClient(AiServiceClient):
def get_aiservice_base_url(self) -> str:
"""Get the base URL for the local AI service."""
return "http://localhost:8000"


def safe_get_repo_owner_and_name() -> tuple[str | None, str | None]:
try:
git_repo_owner, git_repo_name = get_repo_owner_and_name()
except Exception as e:
logger.warning(f"Could not determine repo owner and name: {e}")
git_repo_owner, git_repo_name = None, None
return git_repo_owner, git_repo_name
45 changes: 45 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import difflib
import os
import re
import shutil
Expand All @@ -19,6 +20,50 @@
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)


def diff_length(a: str, b: str) -> int:
"""Compute the length (in characters) of the unified diff between two strings.

Args:
a (str): Original string.
b (str): Modified string.

Returns:
int: Total number of characters in the diff.

"""
# Split input strings into lines for line-by-line diff
a_lines = a.splitlines(keepends=True)
b_lines = b.splitlines(keepends=True)

# Compute unified diff
diff_lines = list(difflib.unified_diff(a_lines, b_lines, lineterm=""))

# Join all lines with newline to calculate total diff length
diff_text = "\n".join(diff_lines)

return len(diff_text)


def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
"""Create a dictionary from a list of ints, mapping the original index to its rank.

This version uses a more compact, "Pythonic" implementation.

Args:
int_array: A list of integers.

Returns:
A dictionary where keys are original indices and values are the
rank of the element in ascending order.

"""
# Sort the indices of the array based on their corresponding values
sorted_indices = sorted(range(len(int_array)), key=lambda i: int_array[i])

# Create a dictionary mapping the original index to its rank (its position in the sorted list)
return {original_index: rank for rank, original_index in enumerate(sorted_indices)}


@contextmanager
def custom_addopts() -> None:
pyproject_file = find_pyproject_toml()
Expand Down
18 changes: 18 additions & 0 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.verification.comparator import comparator


@dataclass(frozen=True)
class AIServiceRefinerRequest:
optimization_id: str
original_source_code: str
read_only_dependency_code: str
original_code_runtime: str
optimized_source_code: str
optimized_explanation: str
optimized_code_runtime: str
speedup: str
trace_id: str
original_line_profiler_results: str
optimized_line_profiler_results: str


# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
# of the module is foo.eggs.
Expand Down Expand Up @@ -76,11 +92,13 @@ def __hash__(self) -> int:
class BestOptimization(BaseModel):
candidate: OptimizedCandidate
helper_functions: list[FunctionSource]
code_context: CodeOptimizationContext
runtime: int
replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
winning_behavior_test_results: TestResults
winning_benchmarking_test_results: TestResults
winning_replay_benchmarking_test_results: Optional[TestResults] = None
line_profiler_test_results: dict


@dataclass(frozen=True)
Expand Down
Loading
Loading