
Commit 09fa96c

Authored by KRRT7, codeflash-ai[bot], misrasaurabh1, and warp-agent
tracer improvements (#970)
* Consolidate FunctionRanker: merge rank/rerank/filter methods into a single rank_functions
* calculate time in own file; remove unittest remnants
* implement suggestions
* cleanup code
* let's make it clear it's an sqlite3 db
* forgot this one
* cleanup
* tessl add
* improve filtering
* cleanup
* Optimize FunctionRanker.get_function_stats_summary (#971)

  The optimization replaces an O(N) linear search through all functions with an O(1) hash table lookup followed by iteration over only the matching function names.

  **Key Changes:**
  - Added a `_function_stats_by_name` index in `__init__` that maps function names to lists of (key, stats) tuples
  - Modified `get_function_stats_summary` to first look up candidates by function name, then iterate only over those candidates

  **Why This is Faster:** The original code iterates through ALL function stats (22,603 iterations in the profiler results) for every lookup. The optimized version uses a hash table to instantly find only the functions with matching names, then iterates through just those candidates (typically 1-2 functions).

  **Performance Impact:**
  - **Small datasets**: 15-30% speedup, as shown in the basic test cases
  - **Large datasets**: Dramatic improvement - the `test_large_scale_performance` case with 900 functions shows a **3085% speedup** (66.7μs → 2.09μs)
  - **Overall benchmark**: A 2061% speedup demonstrates the optimization scales excellently with dataset size

  **When This Optimization Shines:**
  - Large codebases with many profiled functions (where the linear search becomes expensive)
  - Repeated function lookups (if this method is called frequently)
  - Cases with many unique function names but few duplicates per name

  The optimization maintains identical behavior while transforming the algorithm from O(N) per lookup to O(average functions per name) per lookup, which is typically O(1) in practice.

  Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>

* Revert "let's make it clear it's an sqlite3 db"

  This reverts commit 713f135.

* cleanup trace file
* cleanup
* addressable time
* Optimize TestResults.add

  The optimization applies **local variable caching** to eliminate repeated attribute lookups on `self.test_result_idx` and `self.test_results`.

  **Key Changes:**
  - Added `test_result_idx = self.test_result_idx` and `test_results = self.test_results` to cache the references locally
  - Used these local variables instead of accessing the `self.*` attributes multiple times

  **Why This Works:** In Python, attribute access (e.g., `self.test_result_idx`) involves dictionary lookups in the object's `__dict__`, which is slower than accessing local variables. By caching these references, we eliminate redundant attribute-resolution overhead on each access.

  **Performance Impact:** The line profiler run reports total execution time going from 12.771ms to 19.482ms under instrumentation, while the actual runtime improved from 2.13ms to 1.89ms (a 12% speedup). The test results consistently show 10-20% improvements across various scenarios, particularly benefiting:
  - Large-scale operations (500+ items): 14-16% faster
  - Multiple unique additions: 15-20% faster
  - Mixed workloads with duplicates: 7-15% faster

  **Real-World Benefits:** This optimization is especially valuable for high-frequency test result collection scenarios where the `add` method is called repeatedly in tight loops, as the cumulative effect of eliminating attribute lookups becomes significant at scale.
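  For readers unfamiliar with the attribute-caching pattern, a minimal sketch follows. The real `TestResults` class is not part of this commit page, so the fields and `add()` body below are illustrative assumptions only, not the codeflash implementation.

  ```python
  # Illustrative sketch of local-variable caching in a hot method.
  # The field layout and duplicate handling are assumptions for demonstration.
  class TestResults:
      def __init__(self) -> None:
          self.test_results: list[object] = []
          self.test_result_idx: dict[str, int] = {}  # assumed: maps a result id to its list position

      def add(self, result_id: str, result: object) -> None:
          # Cache the attribute lookups once; locals are faster than repeated `self.` access
          test_result_idx = self.test_result_idx
          test_results = self.test_results

          if result_id in test_result_idx:  # duplicate: keep the existing entry
              return
          test_result_idx[result_id] = len(test_results)
          test_results.append(result)
  ```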
* bugfix
* cleanup
* type checks
* pre-commit
* ⚡️ Speed up function `get_cached_gh_event_data` by 13% (#975)
* Optimize get_cached_gh_event_data

  The optimization replaces `Path(event_path).open(encoding="utf-8")` with the built-in `open(event_path, encoding="utf-8")`, achieving a **12% speedup** by eliminating unnecessary object allocation overhead.

  **Key optimization:**
  - **Removed Path object creation**: The original code creates a `pathlib.Path` object just to call `.open()` on it, when the built-in `open()` function can directly accept the string path from `event_path`.
  - **Reduced memory allocation**: Avoiding the intermediate `Path` object saves both allocation time and memory overhead.

  **Why this works:** In Python, `pathlib.Path().open()` internally calls the same file-opening mechanism as the built-in `open()`, but with additional overhead from object instantiation and method dispatch. Since `event_path` is already a string from `os.getenv()`, passing it directly to `open()` is more efficient.

  **Performance impact:** The test results show consistent improvements across all file-reading scenarios:
  - Simple JSON files: 12-20% faster
  - Large files (1000+ elements): 3-27% faster
  - Error cases (missing files): up to 71% faster
  - Cached calls remain unaffected (0% change, as expected)

  **Workload benefits:** Based on the function references, `get_cached_gh_event_data()` is called by multiple GitHub-related utility functions (`get_pr_number()`, `is_repo_a_fork()`, `is_pr_draft()`). While the `@lru_cache(maxsize=1)` means the file is only read once per program execution, this optimization reduces the initial cold-start latency for GitHub Actions workflows or CI/CD pipelines where these functions are commonly used. The optimization is particularly effective for larger JSON files and error-handling scenarios, making it valuable for robust CI/CD environments that may encounter various file conditions.

* ignore

---------

  Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
  Co-authored-by: Kevin Turcios <[email protected]>
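  A small sketch of the before/after shape described above. Only the `Path(event_path).open()` to `open(event_path)` swap and the `@lru_cache(maxsize=1)` decorator come from the description; the environment variable name and the fallback behaviour are assumptions.

  ```python
  # Hedged sketch, not the exact codeflash source.
  import json
  import os
  from functools import lru_cache


  @lru_cache(maxsize=1)
  def get_cached_gh_event_data() -> dict:
      event_path = os.getenv("GITHUB_EVENT_PATH")  # assumed env var name
      if not event_path:
          return {}
      # Before: pathlib.Path(event_path).open(encoding="utf-8") built a Path object just to open it.
      # After: pass the string path straight to the built-in open().
      with open(event_path, encoding="utf-8") as f:
          return json.load(f)
  ```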
* ⚡️ Speed up function `function_is_a_property` by 60% (#974)
* Optimize function_is_a_property

  The optimized version achieves a **60% speedup** by replacing Python's `any()` generator expression with a manual loop and making three key micro-optimizations:

  **What was optimized:**
  1. **Replaced `isinstance()` with `type() is`**: Direct type comparison (`type(node) is ast_Name`) is faster than `isinstance(node, ast.Name)` for AST nodes, where subclassing is rare
  2. **Eliminated repeated lookups**: Cached `"property"` as `property_id` and `ast.Name` as `ast_Name` in local variables to avoid global/attribute lookups in the loop
  3. **Manual loop with early return**: Replaced the `any()` generator with an explicit `for` loop that returns `True` immediately upon finding a match, avoiding generator overhead

  **Why it's faster:**
  - The `any()` function creates generator machinery that adds overhead, especially for small decorator lists
  - `isinstance()` performs multiple checks while `type() is` does a single identity comparison
  - Local variable access is significantly faster than repeated global/attribute lookups in tight loops

  **Performance characteristics from tests:**
  - **Small decorator lists** (1-3 decorators): 50-80% faster due to reduced per-iteration overhead
  - **Large decorator lists** (1000+ decorators): 55-60% consistent speedup, with early termination providing additional benefits when `@property` appears early
  - **Empty decorator lists**: 77% faster due to avoiding `any()` generator setup entirely

  **Impact on workloads:** Based on the function references, this function is called during AST traversal in `visit_FunctionDef` and `visit_AsyncFunctionDef` methods - likely part of a code analysis pipeline that processes many functions. The 60% speedup will be particularly beneficial when analyzing codebases with many decorated functions, as this optimization reduces overhead in a hot path that's called once per function definition.

* format

---------

  Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
  Co-authored-by: Kevin Turcios <[email protected]>

* Optimize function_is_a_property (#976)

  The optimization achieves an **11% speedup** through two key changes:

  **1. Constant Hoisting:** The original code repeatedly assigns `property_id = "property"` and `ast_name = ast.Name` on every function call. The optimized version moves these to module-level constants `_property_id` and `_ast_name`, eliminating 4,130 redundant assignments per profiling run (saving ~2.12ms total time).

  **2. isinstance() vs type() comparison:** Replaced `type(node) is ast_name` with `isinstance(node, _ast_name)`. While both are correct for AST nodes (which use single inheritance), `isinstance()` is slightly more efficient for type checking in Python's implementation.

  **Performance Impact:** The function is called in AST traversal loops when discovering functions to optimize (`visit_FunctionDef` and `visit_AsyncFunctionDef`). Since these visitors process entire codebases, the 11% per-call improvement compounds significantly across large projects.

  **Test Case Performance:** The optimization shows consistent gains across all test scenarios:
  - **Simple cases** (no decorators): 29-42% faster due to eliminated constant assignments
  - **Property detection cases**: 11-26% faster from the combined optimizations
  - **Large-scale tests** (500-1000 functions): 18.5% faster, demonstrating the cumulative benefit when processing many functions

  The optimizations are particularly effective for codebases with many function definitions, where this function gets called repeatedly during AST analysis.
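  Putting #974 and #976 together, the final shape of the check is roughly the sketch below. The constant names `_property_id` and `_ast_name` come from the commit text; the function signature and `decorator_list` iteration are assumptions.

  ```python
  # Rough sketch of the described technique, not the exact codeflash source.
  import ast

  _property_id = "property"   # hoisted to module level (#976)
  _ast_name = ast.Name        # hoisted to module level (#976)


  def function_is_a_property(node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
      # Explicit loop with early return instead of an any() generator expression (#974)
      for decorator in node.decorator_list:
          if isinstance(decorator, _ast_name) and decorator.id == _property_id:
              return True
      return False
  ```

  The intermediate #974 version used `type(decorator) is _ast_name`; #976 switched back to `isinstance` while keeping the hoisted constants and the manual loop.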
  Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>

* Address PR review comments

  - Add mkdir for the test file directory to prevent FileNotFoundError
  - Use addressable_time_ns for importance filtering instead of own_time_ns
  - Remove unnecessary list() wrappers in make_pstats_compatible
  - Remove old .sqlite3 file with the wrong extension

  Co-Authored-By: Warp <[email protected]>

* Check addressable_time_ns instead of own_time_ns for filtering

  This ensures we consider functions that may have low own_time but high time in first-order dependent functions (callees).

  Co-Authored-By: Warp <[email protected]>

---------

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: Saurabh Misra <[email protected]>
Co-authored-by: Warp <[email protected]>
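To make the own_time vs addressable_time distinction concrete, here is a small worked example with made-up numbers, using the formula from function_ranker.py (addressable_time = own_time + time_in_callees / call_count):

```python
# Worked example with illustrative numbers only.
own_time_ns = 1_000_000          # 1 ms spent in the function body itself
time_in_callees_ns = 50_000_000  # 50 ms spent in its immediate callees overall
call_count = 10

addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count)
print(addressable_time_ns)  # 6_000_000.0 ns, i.e. 6 ms of addressable time

# Filtering on own_time_ns sees only 1 ms and may drop this function below the
# importance threshold; filtering on addressable_time_ns sees 6 ms and keeps it.
```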
1 parent 872ec28 commit 09fa96c

File tree: 14 files changed, +474 -266 lines


.gitignore

Lines changed: 6 additions & 1 deletion
```diff
@@ -254,4 +254,9 @@ fabric.properties
 
 # Mac
 .DS_Store
-WARP.MD
+WARP.MD
+
+.mcp.json
+.tessl/
+CLAUDE.md
+tessl.json
```

AGENTS.md

Lines changed: 5 additions & 1 deletion
```diff
@@ -315,4 +315,8 @@ Language Server Protocol support in `codeflash/lsp/` enables IDE integration dur
 ### Performance Optimization
 - Profile before and after changes
 - Use benchmarks to validate improvements
-- Generate detailed performance reports
+- Generate detailed performance reports
+
+# Agent Rules <!-- tessl-managed -->
+
+@.tessl/RULES.md follow the [instructions](.tessl/RULES.md)
```

codeflash/benchmarking/function_ranker.py

Lines changed: 132 additions & 65 deletions
```diff
@@ -2,7 +2,7 @@
 
 from typing import TYPE_CHECKING
 
-from codeflash.cli_cmds.console import console, logger
+from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.tracing.profile_stats import ProfileStats
@@ -12,29 +12,63 @@
 
     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 
+pytest_patterns = {
+    "<frozen",  # Frozen modules like runpy
+    "<string>",  # Dynamically evaluated code
+    "_pytest/",  # Pytest internals
+    "pytest",  # Pytest files
+    "pluggy/",  # Plugin system
+    "_pydev",  # PyDev debugger
+    "runpy.py",  # Python module runner
+}
+pytest_func_patterns = {"pytest_", "_pytest", "runtest"}
+
+
+def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
+    """Check if a function is part of pytest infrastructure that should be excluded from ranking.
+
+    This filters out pytest internal functions, hooks, and test framework code that
+    would otherwise dominate the ranking but aren't candidates for optimization.
+    """
+    # Check filename patterns
+    for pattern in pytest_patterns:
+        if pattern in filename:
+            return True
+
+    return any(pattern in function_name.lower() for pattern in pytest_func_patterns)
+
 
 class FunctionRanker:
-    """Ranks and filters functions based on a ttX score derived from profiling data.
+    """Ranks and filters functions based on % of addressable time derived from profiling data.
 
-    The ttX score is calculated as:
-    ttX = own_time + (time_spent_in_callees / call_count)
+    The % of addressable time is calculated as:
+    addressable_time = own_time + (time_spent_in_callees / call_count)
 
-    This score prioritizes functions that are computationally heavy themselves (high `own_time`)
-    or that make expensive calls to other functions (high average `time_spent_in_callees`).
+    This represents the runtime of a function plus the runtime of its immediate dependent functions,
+    as a fraction of overall runtime. It prioritizes functions that are computationally heavy themselves
+    (high `own_time`) or that make expensive calls to other functions (high average `time_spent_in_callees`).
 
     Functions are first filtered by an importance threshold based on their `own_time` as a
-    fraction of the total runtime. The remaining functions are then ranked by their ttX score
+    fraction of the total runtime. The remaining functions are then ranked by their % of addressable time
    to identify the best candidates for optimization.
    """

     def __init__(self, trace_file_path: Path) -> None:
         self.trace_file_path = trace_file_path
         self._profile_stats = ProfileStats(trace_file_path.as_posix())
         self._function_stats: dict[str, dict] = {}
+        self._function_stats_by_name: dict[str, list[tuple[str, dict]]] = {}
         self.load_function_stats()
 
+        # Build index for faster lookups: map function_name to list of (key, stats)
+        for key, stats in self._function_stats.items():
+            func_name = stats.get("function_name")
+            if func_name:
+                self._function_stats_by_name.setdefault(func_name, []).append((key, stats))
+
     def load_function_stats(self) -> None:
         try:
+            pytest_filtered_count = 0
             for (filename, line_number, func_name), (
                 call_count,
                 _num_callers,
@@ -45,6 +79,10 @@ def load_function_stats(self) -> None:
                 if call_count <= 0:
                     continue
 
+                if is_pytest_infrastructure(filename, func_name):
+                    pytest_filtered_count += 1
+                    continue
+
                 # Parse function name to handle methods within classes
                 class_name, qualified_name, base_function_name = (None, func_name, func_name)
                 if "." in func_name and not func_name.startswith("<"):
@@ -56,8 +94,8 @@ def load_function_stats(self) -> None:
                 own_time_ns = total_time_ns
                 time_in_callees_ns = cumulative_time_ns - total_time_ns
 
-                # Calculate ttX score
-                ttx_score = own_time_ns + (time_in_callees_ns / call_count)
+                # Calculate addressable time (own time + avg time in immediate callees)
+                addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count)
 
                 function_key = f"{filename}:{qualified_name}"
                 self._function_stats[function_key] = {
@@ -70,89 +108,118 @@ def load_function_stats(self) -> None:
                     "own_time_ns": own_time_ns,
                     "cumulative_time_ns": cumulative_time_ns,
                     "time_in_callees_ns": time_in_callees_ns,
-                    "ttx_score": ttx_score,
+                    "addressable_time_ns": addressable_time_ns,
                 }
 
-            logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")
+            logger.debug(
+                f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats "
+                f"(filtered {pytest_filtered_count} pytest infrastructure functions)"
+            )
 
         except Exception as e:
             logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
             self._function_stats = {}
 
-    def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
+    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
         target_filename = function_to_optimize.file_path.name
-        for key, stats in self._function_stats.items():
-            if stats.get("function_name") == function_to_optimize.function_name and (
-                key.endswith(f"/{target_filename}") or target_filename in key
-            ):
+        candidates = self._function_stats_by_name.get(function_to_optimize.function_name)
+        if not candidates:
+            logger.debug(
+                f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
+            )
+            return None
+
+        for key, stats in candidates:
+            # The check preserves exact logic: "key.endswith(f"/{target_filename}") or target_filename in key"
+            if key.endswith(f"/{target_filename}") or target_filename in key:
                 return stats
 
         logger.debug(
             f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
         )
         return None
 
-    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
-        stats = self._get_function_stats(function_to_optimize)
-        return stats["ttx_score"] if stats else 0.0
+    def get_function_addressable_time(self, function_to_optimize: FunctionToOptimize) -> float:
+        """Get the addressable time in nanoseconds for a function.
 
-    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
-        logger.debug(
-            f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
-        )
-        return ranked
+        Addressable time = own_time + (time_in_callees / call_count)
+        This represents the runtime of the function plus runtime of immediate dependent functions.
+        """
+        stats = self.get_function_stats_summary(function_to_optimize)
+        return stats["addressable_time_ns"] if stats else 0.0
 
-    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
-        return self._get_function_stats(function_to_optimize)
+    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
+        """Ranks and filters functions based on their % of addressable time and importance.
 
-    def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        """Ranks functions based on their ttX score.
+        Filters out functions whose own_time is less than DEFAULT_IMPORTANCE_THRESHOLD
+        of file-relative runtime, then ranks the remaining functions by addressable time.
 
-        This method calculates the ttX score for each function and returns
-        the functions sorted in descending order of their ttX score.
-        """
-        if not self._function_stats:
-            logger.warning("No function stats available to rank functions.")
-            return []
+        Importance is calculated relative to functions in the same file(s) rather than
+        total program time. This avoids filtering out functions due to test infrastructure
+        overhead.
 
-        return self.rank_functions(functions_to_optimize)
+        The addressable time metric (own_time + avg time in immediate callees) prioritizes
+        functions that are computationally heavy themselves or that make expensive calls
+        to other functions.
 
-    def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        """Reranks and filters functions based on their impact on total runtime.
+        Args:
+            functions_to_optimize: List of functions to rank.
 
-        This method first calculates the total runtime of all profiled functions.
-        It then filters out functions whose own_time is less than a specified
-        percentage of the total runtime (importance_threshold).
+        Returns:
+            Important functions sorted in descending order of their addressable time.
 
-        The remaining 'important' functions are then ranked by their ttX score.
         """
-        stats_map = self._function_stats
-        if not stats_map:
+        if not self._function_stats:
+            logger.warning("No function stats available to rank functions.")
             return []
 
-        total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)
+        # Calculate total time from functions in the same file(s) as functions to optimize
+        if functions_to_optimize:
+            # Get unique files from functions to optimize
+            target_files = {func.file_path.name for func in functions_to_optimize}
+            # Calculate total time only from functions in these files
+            total_program_time = sum(
+                s["own_time_ns"]
+                for s in self._function_stats.values()
+                if s.get("own_time_ns", 0) > 0
+                and any(
+                    str(s.get("filename", "")).endswith("/" + target_file) or s.get("filename") == target_file
+                    for target_file in target_files
+                )
+            )
+            logger.debug(
+                f"Using file-relative importance for {len(target_files)} file(s): {target_files}. "
+                f"Total file time: {total_program_time:,} ns"
+            )
+        else:
+            total_program_time = sum(
+                s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0
+            )
 
         if total_program_time == 0:
             logger.warning("Total program time is zero, cannot determine function importance.")
-            return self.rank_functions(functions_to_optimize)
-
-        important_functions = []
-        for func in functions_to_optimize:
-            func_stats = self._get_function_stats(func)
-            if func_stats and func_stats.get("own_time_ns", 0) > 0:
-                importance = func_stats["own_time_ns"] / total_program_time
-                if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
-                    important_functions.append(func)
-                else:
-                    logger.debug(
-                        f"Filtering out function {func.qualified_name} with importance "
-                        f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
-                    )
-
-        logger.info(
-            f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
+            functions_to_rank = functions_to_optimize
+        else:
+            functions_to_rank = []
+            for func in functions_to_optimize:
+                func_stats = self.get_function_stats_summary(func)
+                if func_stats and func_stats.get("addressable_time_ns", 0) > 0:
+                    importance = func_stats["addressable_time_ns"] / total_program_time
+                    if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
+                        functions_to_rank.append(func)
+                    else:
+                        logger.debug(
+                            f"Filtering out function {func.qualified_name} with importance "
+                            f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
+                        )
+
+        logger.info(
+            f"Filtered down to {len(functions_to_rank)} important functions "
+            f"from {len(functions_to_optimize)} total functions"
+        )
+
+        ranked = sorted(functions_to_rank, key=self.get_function_addressable_time, reverse=True)
+        logger.debug(
+            f"Function ranking order: {[f'{func.function_name} (addressable_time={self.get_function_addressable_time(func):.2f}ns)' for func in ranked]}"
        )
-        console.rule()
-
-        return self.rank_functions(important_functions)
+        return ranked
```
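For orientation, a hedged usage sketch of the reworked FunctionRanker API, based only on the signatures visible in this diff; the trace path and the functions list are placeholders, not taken from the codeflash test suite.

```python
# Usage sketch under stated assumptions; not a verbatim example from the repository.
from pathlib import Path

from codeflash.benchmarking.function_ranker import FunctionRanker

ranker = FunctionRanker(trace_file_path=Path("codeflash.trace"))  # sqlite3 trace produced by the tracer

# Placeholder: normally a list[FunctionToOptimize] produced by codeflash's discovery step.
functions: list = []

ranked = ranker.rank_functions(functions)  # filter by importance, then sort by addressable time
for func in ranked:
    stats = ranker.get_function_stats_summary(func)  # per-function timing dict, or None
    print(func.function_name, ranker.get_function_addressable_time(func), stats)
```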

codeflash/benchmarking/replay_test.py

Lines changed: 8 additions & 28 deletions
```diff
@@ -64,30 +64,23 @@ def get_unique_test_name(module: str, function_name: str, benchmark_name: str, c
 
 
 def create_trace_replay_test_code(
-    trace_file: str,
-    functions_data: list[dict[str, Any]],
-    test_framework: str = "pytest",
-    max_run_count=256,  # noqa: ANN001
+    trace_file: str, functions_data: list[dict[str, Any]], max_run_count: int = 256
 ) -> str:
     """Create a replay test for functions based on trace data.
 
     Args:
     ----
         trace_file: Path to the SQLite database file
         functions_data: List of dictionaries with function info extracted from DB
-        test_framework: 'pytest' or 'unittest'
         max_run_count: Maximum number of runs to include in the test
 
     Returns:
     -------
         A string containing the test code
 
     """
-    assert test_framework in ["pytest", "unittest"]
-
     # Create Imports
-    imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
-{"import unittest" if test_framework == "unittest" else ""}
+    imports = """from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
 from codeflash.benchmarking.replay_test import get_next_arg_and_return
 """
 
@@ -158,13 +151,7 @@ def create_trace_replay_test_code(
     )
 
     # Create main body
-
-    if test_framework == "unittest":
-        self = "self"
-        test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
-    else:
-        test_template = ""
-        self = ""
+    test_template = ""
 
     for func in functions_data:
         module_name = func.get("module_name")
@@ -223,30 +210,26 @@ def create_trace_replay_test_code(
             filter_variables=filter_variables,
         )
 
-        formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")
+        formatted_test_body = textwrap.indent(test_body, "    ")
 
-        test_template += "    " if test_framework == "unittest" else ""
         unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
-        test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
+        test_template += f"def test_{unique_test_name}():\n{formatted_test_body}\n"
 
     return imports + "\n" + metadata + "\n" + test_template
 
 
-def generate_replay_test(
-    trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100
-) -> int:
+def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count: int = 100) -> int:
     """Generate multiple replay tests from the traced function calls, grouped by benchmark.
 
     Args:
     ----
         trace_file_path: Path to the SQLite database file
         output_dir: Directory to write the generated tests (if None, only returns the code)
-        test_framework: 'pytest' or 'unittest'
         max_run_count: Maximum number of runs to include per function
 
     Returns:
     -------
-        Dictionary mapping benchmark names to generated test code
+        The number of replay tests generated
 
     """
     count = 0
@@ -293,10 +276,7 @@ def generate_replay_test(
             continue
         # Generate the test code for this benchmark
         test_code = create_trace_replay_test_code(
-            trace_file=trace_file_path.as_posix(),
-            functions_data=functions_data,
-            test_framework=test_framework,
-            max_run_count=max_run_count,
+            trace_file=trace_file_path.as_posix(), functions_data=functions_data, max_run_count=max_run_count
         )
        test_code = sort_imports(code=test_code)
        output_file = get_test_file_path(
```
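A hedged call sketch showing the simplified generate_replay_test signature from this diff (the test_framework parameter is gone and the return value is a count); the paths here are placeholders, not values from the repository.

```python
# Usage sketch under stated assumptions; not a verbatim example from the repository.
from pathlib import Path

from codeflash.benchmarking.replay_test import generate_replay_test

num_tests = generate_replay_test(
    trace_file_path=Path("codeflash.trace"),  # sqlite3 trace with recorded benchmark calls
    output_dir=Path("tests/replay"),          # where the generated pytest files are written
    max_run_count=100,                        # cap on recorded invocations replayed per function
)
print(f"generated {num_tests} replay tests")  # return value is now a count, not a dict
```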
