diff --git a/.github/workflows/e2e-futurehouse-structure.yaml b/.github/workflows/e2e-futurehouse-structure.yaml index 71bf8804..8d33e9ea 100644 --- a/.github/workflows/e2e-futurehouse-structure.yaml +++ b/.github/workflows/e2e-futurehouse-structure.yaml @@ -12,7 +12,10 @@ jobs: # Dynamically determine if environment is needed only when workflow files change and contributor is external environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} - runs-on: ubuntu-latest + strategy: + matrix: + os: [ubuntu-latest, windows-latest] + runs-on: ${{ matrix.os }} env: CODEFLASH_AIS_SERVER: prod POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} @@ -22,6 +25,7 @@ jobs: RETRY_DELAY: 5 EXPECTED_IMPROVEMENT_PCT: 10 CODEFLASH_END_TO_END: 1 + PYTHONIOENCODING: "utf8" steps: - name: 🛎️ Checkout uses: actions/checkout@v4 @@ -31,6 +35,7 @@ jobs: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} - name: Validate PR + shell: bash run: | # Check for any workflow changes if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then @@ -56,13 +61,16 @@ jobs: - name: Set up Python 3.11 for CLI uses: astral-sh/setup-uv@v5 with: - python-version: 3.11.6 + python-version: 3.11 - name: Install dependencies (CLI) run: | uv sync + uv add ruff - name: Run Codeflash to optimize code id: optimize_code + env: + PYTHONUTF8: 1 run: | - uv run python tests/scripts/end_to_end_test_futurehouse.py + uv run tests/scripts/end_to_end_test_futurehouse.py diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index 4a04b105..c0d2dcdd 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -12,8 +12,9 @@ jobs: fail-fast: false matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-latest] continue-on-error: true - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 with: @@ -24,10 +25,29 @@ jobs: uses: astral-sh/setup-uv@v5 with: python-version: ${{ matrix.python-version }} - version: "0.5.30" - name: install dependencies run: uv sync - name: Unit tests - run: uv run pytest tests/ --benchmark-skip -m "not ci_skip" \ No newline at end of file + run: uv run pytest tests/ --benchmark-skip -m "not ci_skip" + + # unit-tests-windows: + # runs-on: windows-latest + # continue-on-error: true + # steps: + # - uses: actions/checkout@v4 + # with: + # fetch-depth: 0 + # token: ${{ secrets.GITHUB_TOKEN }} + + # - name: Install uv + # uses: astral-sh/setup-uv@v5 + # with: + # python-version: "3.11" + + # - name: install dependencies + # run: uv sync + + # - name: Unit tests + # run: uv run pytest tests/ --benchmark-skip -m "not ci_skip" diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index c2e1889d..a645b1b8 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -162,7 +162,7 @@ def create_trace_replay_test_code( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, function_name=alias, - file_path=file_path, + file_path=Path(file_path).as_posix(), max_run_count=max_run_count, ) else: @@ -176,7 +176,7 @@ def create_trace_replay_test_code( test_body = test_class_method_body.format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, - file_path=file_path, + file_path=Path(file_path).as_posix(), class_name_alias=class_name_alias, class_name=class_name, method_name=method_name, @@ -187,7 +187,7 @@ def create_trace_replay_test_code( test_body = test_static_method_body.format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, - file_path=file_path, + file_path=Path(file_path).as_posix(), class_name_alias=class_name_alias, class_name=class_name, method_name=method_name, @@ -198,7 +198,7 @@ def create_trace_replay_test_code( test_body = test_method_body.format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, - file_path=file_path, + file_path=Path(file_path).as_posix(), class_name_alias=class_name_alias, class_name=class_name, method_name=method_name, diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 82a5b979..22107f47 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -171,8 +171,9 @@ def get_run_tmp_file(file_path: Path) -> Path: def path_belongs_to_site_packages(file_path: Path) -> bool: + file_path_resolved = file_path.resolve() site_packages = [Path(p) for p in site.getsitepackages()] - return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages) + return any(file_path_resolved.is_relative_to(site_package_path) for site_package_path in site_packages) def is_class_defined_in_file(class_name: str, file_path: Path) -> bool: diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 21aa06ad..456685d4 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -45,7 +45,7 @@ def generate_candidates(source_code_path: Path) -> list[str]: current_path = source_code_path.parent while current_path != current_path.parent: - candidate_path = str(Path(current_path.name) / candidates[-1]) + candidate_path = (Path(current_path.name) / candidates[-1]).as_posix() candidates.append(candidate_path) current_path = current_path.parent diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 2c8ca22a..fa917f3d 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -18,10 +18,9 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = if formatter_cmds[0] == "disabled": return return_code tmp_code = """print("hello world")""" - with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f: - f.write(tmp_code) - f.flush() - tmp_file = Path(f.name) + with tempfile.TemporaryDirectory() as tmpdir: + tmp_file = Path(tmpdir) / "test_codeflash_formatter.py" + tmp_file.write_text(tmp_code, encoding="utf-8") try: format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure) except Exception: @@ -29,7 +28,7 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = "⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.", error_on_exit=True, ) - return return_code + return return_code @lru_cache(maxsize=1) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 6eac5280..9a737298 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -212,7 +212,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = args=[ ast.JoinedStr( values=[ - ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"), + ast.Constant( + value=f"{get_run_tmp_file(Path('test_return_values_')).as_posix()}" + ), ast.FormattedValue( value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index e96d1242..3d66bf8c 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -147,7 +147,7 @@ def markdown(self) -> str: """Returns the markdown representation of the code, including the file path where possible.""" return "\n".join( [ - f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```" + f"```python{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```" for code_string in self.code_strings ] ) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 2134cee0..50e549ca 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -101,7 +101,7 @@ def existing_tests_source_for( if greater: rows.append( [ - f"`{print_filename}::{qualified_name}`", + f"`{print_filename.as_posix()}::{qualified_name}`", f"{print_original_runtime}", f"{print_optimized_runtime}", f"⚠️{perf_gain}%", @@ -110,7 +110,7 @@ def existing_tests_source_for( else: rows.append( [ - f"`{print_filename}::{qualified_name}`", + f"`{print_filename.as_posix()}::{qualified_name}`", f"{print_original_runtime}", f"{print_optimized_runtime}", f"✅{perf_gain}%", diff --git a/codeflash/verification/instrument_codeflash_capture.py b/codeflash/verification/instrument_codeflash_capture.py index d4db6d26..d1f9816d 100644 --- a/codeflash/verification/instrument_codeflash_capture.py +++ b/codeflash/verification/instrument_codeflash_capture.py @@ -33,7 +33,7 @@ def instrument_codeflash_capture( modified_code = add_codeflash_capture_to_init( target_classes={class_parent.name}, fto_name=function_to_optimize.function_name, - tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), code=original_code, tests_root=tests_root, is_fto=True, @@ -46,7 +46,7 @@ def instrument_codeflash_capture( modified_code = add_codeflash_capture_to_init( target_classes=helper_classes, fto_name=function_to_optimize.function_name, - tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), code=original_code, tests_root=tests_root, is_fto=False, @@ -124,7 +124,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: keywords=[ ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)), - ast.keyword(arg="tests_root", value=ast.Constant(value=str(self.tests_root))), + ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())), ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)), ], ) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 6dc1aba4..9b4ca40a 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -94,6 +94,7 @@ def run_codeflash_command( output = [] for line in process.stdout: + line = line.strip().encode("utf-8").decode("utf-8") logging.info(line.strip()) output.append(line) @@ -122,7 +123,7 @@ def build_command( ) -> list[str]: python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py" - base_command = ["python", python_path, "--file", config.file_path, "--no-pr"] + base_command = ["uv", "run", "--no-project", python_path, "--file", config.file_path, "--no-pr"] if config.function_name: base_command.extend(["--function", config.function_name]) @@ -187,13 +188,14 @@ def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) -> def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool: test_root = cwd / "tests" / (config.test_framework or "") clear_directory(test_root) - command = ["python", "-m", "codeflash.main", "optimize", "workload.py"] + command = ["uv", "run", "--no-project", "-m", "codeflash.main", "optimize", "workload.py"] process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() ) output = [] for line in process.stdout: + line = line.strip().encode("utf-8").decode("utf-8") logging.info(line.strip()) output.append(line) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 25200cb9..a9701b2d 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -15,6 +15,12 @@ from codeflash.code_utils.code_extractor import add_global_assignments +@pytest.fixture(scope="module") +def temp_dir(): + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + class HelperClass: def __init__(self, name): self.name = name @@ -263,7 +269,7 @@ def sort_from_another_file(arr): assert hashing_context.strip() == expected_hashing_context.strip() -def test_flavio_typed_code_helper() -> None: +def test_flavio_typed_code_helper(temp_dir: Path) -> None: code = ''' _P = ParamSpec("_P") @@ -429,33 +435,35 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: lifespan=self.__duration__, ) ''' - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: + file_path = temp_dir / "test_flavio_typed_code_helper.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="__call__", - file_path=file_path, - parents=[FunctionParent(name="_PersistentCache", type="ClassDef")], - starting_line=None, - ending_line=None, + file_path = file_path.resolve() + + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) + ) + function_to_optimize = FunctionToOptimize( + function_name="__call__", + file_path=file_path, + parents=[FunctionParent(name="_PersistentCache", type="ClassDef")], + starting_line=None, + ending_line=None, + ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + + expected_read_write_context = """ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): def __init__(self) -> None: ... @@ -551,7 +559,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: lifespan=self.__duration__, ) """ - expected_read_only_context = f''' + expected_read_only_context = f''' ```python:{file_path.relative_to(opt.args.project_root)} _P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") @@ -607,7 +615,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]): __backend__: _CacheBackendT ``` ''' - expected_hashing_context = f""" + expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -644,12 +652,12 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class() -> None: +def test_example_class(temp_dir: Path) -> None: code = """ class MyClass: \"\"\"A class with a helper method.\"\"\" @@ -669,7 +677,7 @@ def __repr__(self): def helper_method(self): return self.x """ - with tempfile.NamedTemporaryFile(mode="w") as f: + with (temp_dir / "test_example_class.py").open("w") as f: f.write(code) f.flush() file_path = Path(f.name).resolve() @@ -741,7 +749,7 @@ def helper_method(self): assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class_token_limit_1() -> None: +def test_example_class_token_limit_1(temp_dir: Path) -> None: docstring_filler = " ".join( ["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)] ) @@ -766,11 +774,11 @@ def __repr__(self): def helper_method(self): return self.x """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_example_class_token_limit_1.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( + file_path = file_path.resolve() + opt = Optimizer( Namespace( project_root=file_path.parent.resolve(), disable_telemetry=True, @@ -781,7 +789,7 @@ def helper_method(self): test_project_root=Path().resolve(), ) ) - function_to_optimize = FunctionToOptimize( + function_to_optimize = FunctionToOptimize( function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")], @@ -789,11 +797,11 @@ def helper_method(self): ending_line=None, ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. - expected_read_write_context = """ + expected_read_write_context = """ class MyClass: def __init__(self): self.x = 1 @@ -808,7 +816,7 @@ def __init__(self): def helper_method(self): return self.x """ - expected_read_only_context = f""" + expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: pass @@ -818,7 +826,7 @@ def __repr__(self): return "HelperClass" + str(self.x) ``` """ - expected_hashing_context = f""" + expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -831,12 +839,12 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class_token_limit_2() -> None: +def test_example_class_token_limit_2(temp_dir: Path) -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] ) @@ -861,11 +869,11 @@ def __repr__(self): def helper_method(self): return self.x """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_example_class_token_limit_2.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( + file_path = file_path.resolve() + opt = Optimizer( Namespace( project_root=file_path.parent.resolve(), disable_telemetry=True, @@ -876,7 +884,7 @@ def helper_method(self): test_project_root=Path().resolve(), ) ) - function_to_optimize = FunctionToOptimize( + function_to_optimize = FunctionToOptimize( function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")], @@ -884,11 +892,11 @@ def helper_method(self): ending_line=None, ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. - expected_read_write_context = """ + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. + expected_read_write_context = """ class MyClass: def __init__(self): self.x = 1 @@ -903,8 +911,8 @@ def __init__(self): def helper_method(self): return self.x """ - expected_read_only_context = "" - expected_hashing_context = f""" + expected_read_only_context = "" + expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -917,12 +925,12 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class_token_limit_3() -> None: +def test_example_class_token_limit_3(temp_dir: Path) -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] ) @@ -946,11 +954,11 @@ def __repr__(self): def helper_method(self): return self.x """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_example_class_token_limit_3.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( + file_path = file_path.resolve() + opt = Optimizer( Namespace( project_root=file_path.parent.resolve(), disable_telemetry=True, @@ -961,7 +969,7 @@ def helper_method(self): test_project_root=Path().resolve(), ) ) - function_to_optimize = FunctionToOptimize( + function_to_optimize = FunctionToOptimize( function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")], @@ -969,11 +977,11 @@ def helper_method(self): ending_line=None, ) # In this scenario, the read-writable code is too long, so we abort. - with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): + with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) -def test_example_class_token_limit_4() -> None: +def test_example_class_token_limit_4(temp_dir: Path) -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] ) @@ -998,11 +1006,11 @@ def __repr__(self): def helper_method(self): return self.x """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_example_class_token_limit_4.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( + file_path = file_path.resolve() + opt = Optimizer( Namespace( project_root=file_path.parent.resolve(), disable_telemetry=True, @@ -1013,7 +1021,7 @@ def helper_method(self): test_project_root=Path().resolve(), ) ) - function_to_optimize = FunctionToOptimize( + function_to_optimize = FunctionToOptimize( function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")], @@ -1022,8 +1030,8 @@ def helper_method(self): ) # In this scenario, the testgen code context is too long, so we abort. - with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) def test_repo_helper() -> None: @@ -1507,7 +1515,7 @@ def circular_dependency(self, data): assert hashing_context.strip() == expected_hashing_context.strip() -def test_indirect_init_helper() -> None: +def test_indirect_init_helper(temp_dir: Path) -> None: code = """ class MyClass: def __init__(self): @@ -1519,33 +1527,33 @@ def target_method(self): def outside_method(): return 1 """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_indirect_init_helper.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="MyClass", type="ClassDef")], - starting_line=None, - ending_line=None, + file_path = file_path.resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + expected_read_write_context = """ class MyClass: def __init__(self): self.x = 1 @@ -1553,13 +1561,13 @@ def __init__(self): def target_method(self): return self.x + self.y """ - expected_read_only_context = f""" + expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} def outside_method(): return 1 ``` """ - expected_hashing_context = f""" + expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -1567,9 +1575,9 @@ def target_method(self): return self.x + self.y ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_direct_module_import() -> None: @@ -1653,7 +1661,7 @@ def function_to_optimize(): assert hashing_context.strip() == expected_hashing_context.strip() -def test_module_import_optimization() -> None: +def test_module_import_optimization(temp_dir: Path) -> None: main_code = """ import utility_module @@ -2102,11 +2110,11 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): CALCULATION_BACKEND = "python" ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() -def test_hashing_code_context_removes_imports_docstrings_and_init() -> None: +def test_hashing_code_context_removes_imports_docstrings_and_init(temp_dir: Path) -> None: """Test that hashing context removes imports, docstrings, and __init__ methods properly.""" code = ''' import os @@ -2144,7 +2152,7 @@ def standalone_function(): """Standalone function.""" return "standalone" ''' - with tempfile.NamedTemporaryFile(mode="w") as f: + with open(temp_dir / "test_hashing_code_context.py", "w") as f: f.write(code) f.flush() file_path = Path(f.name).resolve() @@ -2203,8 +2211,7 @@ def standalone_function(): "Docstrings should be removed from helper class methods" ) - -def test_hashing_code_context_with_nested_classes() -> None: +def test_hashing_code_context_with_nested_classes(temp_dir: Path) -> None: """Test that hashing context handles nested classes properly (should exclude them).""" code = ''' class OuterClass: @@ -2225,98 +2232,98 @@ def __init__(self): def nested_method(self): return self.nested_value ''' - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_hashing_code_context_with_nested_classes.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="OuterClass", type="ClassDef")], - starting_line=None, - ending_line=None, + file_path = file_path.resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="OuterClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - hashing_context = code_ctx.hashing_code_context + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = code_ctx.hashing_code_context - # Test basic requirements - assert hashing_context.startswith("```python:") - assert hashing_context.endswith("```") - assert "target_method" in hashing_context - assert "__init__" not in hashing_context # Should not contain __init__ methods + # Test basic requirements + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + assert "target_method" in hashing_context + assert "__init__" not in hashing_context # Should not contain __init__ methods - # Verify nested classes are excluded from the hashing context - # The prune_cst_for_code_hashing function should not recurse into nested classes - assert "class NestedClass:" not in hashing_context # Nested class definition should not be present + # Verify nested classes are excluded from the hashing context + # The prune_cst_for_code_hashing function should not recurse into nested classes + assert "class NestedClass:" not in hashing_context # Nested class definition should not be present - # The target method will reference NestedClass, but the actual nested class definition should not be included - # The call to self.NestedClass().nested_method() should be in the target method but the nested class itself excluded - target_method_call_present = "self.NestedClass().nested_method()" in hashing_context - assert target_method_call_present, "The target method should contain the call to nested class" + # The target method will reference NestedClass, but the actual nested class definition should not be included + # The call to self.NestedClass().nested_method() should be in the target method but the nested class itself excluded + target_method_call_present = "self.NestedClass().nested_method()" in hashing_context + assert target_method_call_present, "The target method should contain the call to nested class" - # But the actual nested method definition should not be present - nested_method_definition_present = "def nested_method(self):" in hashing_context - assert not nested_method_definition_present, "Nested method definition should not be present in hashing context" + # But the actual nested method definition should not be present + nested_method_definition_present = "def nested_method(self):" in hashing_context + assert not nested_method_definition_present, "Nested method definition should not be present in hashing context" -def test_hashing_code_context_hash_consistency() -> None: +def test_hashing_code_context_hash_consistency(temp_dir: Path) -> None: """Test that the same code produces the same hash.""" code = """ class TestClass: def target_method(self): return "test" """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_hashing_code_context_hash_consistency.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="TestClass", type="ClassDef")], - starting_line=None, - ending_line=None, + file_path = file_path.resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) - # Generate context twice - code_ctx1 = get_code_optimization_context(function_to_optimize, opt.args.project_root) - code_ctx2 = get_code_optimization_context(function_to_optimize, opt.args.project_root) + # Generate context twice + code_ctx1 = get_code_optimization_context(function_to_optimize, opt.args.project_root) + code_ctx2 = get_code_optimization_context(function_to_optimize, opt.args.project_root) - # Hash should be consistent - assert code_ctx1.hashing_code_context_hash == code_ctx2.hashing_code_context_hash - assert code_ctx1.hashing_code_context == code_ctx2.hashing_code_context + # Hash should be consistent + assert code_ctx1.hashing_code_context_hash == code_ctx2.hashing_code_context_hash + assert code_ctx1.hashing_code_context == code_ctx2.hashing_code_context - # Hash should be valid SHA256 - import hashlib + # Hash should be valid SHA256 + import hashlib - expected_hash = hashlib.sha256(code_ctx1.hashing_code_context.encode("utf-8")).hexdigest() - assert code_ctx1.hashing_code_context_hash == expected_hash + expected_hash = hashlib.sha256(code_ctx1.hashing_code_context.encode("utf-8")).hexdigest() + assert code_ctx1.hashing_code_context_hash == expected_hash -def test_hashing_code_context_different_code_different_hash() -> None: +def test_hashing_code_context_different_code_different_hash(temp_dir: Path) -> None: """Test that different code produces different hashes.""" code1 = """ class TestClass: @@ -2329,113 +2336,114 @@ def target_method(self): return "test2" """ - with tempfile.NamedTemporaryFile(mode="w") as f1, tempfile.NamedTemporaryFile(mode="w") as f2: + file_path1 = temp_dir / "test_file1.py" + with open(file_path1, "w") as f1: f1.write(code1) - f1.flush() - f2.write(code2) - f2.flush() + file_path1 = file_path1.resolve() - file_path1 = Path(f1.name).resolve() - file_path2 = Path(f2.name).resolve() - - opt1 = Optimizer( - Namespace( - project_root=file_path1.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) + file_path2 = temp_dir / "test_file2.py" + with open(file_path2, "w") as f2: + f2.write(code2) + file_path2 = file_path2.resolve() + + opt1 = Optimizer( + Namespace( + project_root=file_path1.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) - opt2 = Optimizer( - Namespace( - project_root=file_path2.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) + ) + opt2 = Optimizer( + Namespace( + project_root=file_path2.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) + ) - function_to_optimize1 = FunctionToOptimize( - function_name="target_method", - file_path=file_path1, - parents=[FunctionParent(name="TestClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - function_to_optimize2 = FunctionToOptimize( - function_name="target_method", - file_path=file_path2, - parents=[FunctionParent(name="TestClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) + function_to_optimize1 = FunctionToOptimize( + function_name="target_method", + file_path=file_path1, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + function_to_optimize2 = FunctionToOptimize( + function_name="target_method", + file_path=file_path2, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) - code_ctx1 = get_code_optimization_context(function_to_optimize1, opt1.args.project_root) - code_ctx2 = get_code_optimization_context(function_to_optimize2, opt2.args.project_root) + code_ctx1 = get_code_optimization_context(function_to_optimize1, opt1.args.project_root) + code_ctx2 = get_code_optimization_context(function_to_optimize2, opt2.args.project_root) - # Different code should produce different hashes - assert code_ctx1.hashing_code_context_hash != code_ctx2.hashing_code_context_hash - assert code_ctx1.hashing_code_context != code_ctx2.hashing_code_context + # Different code should produce different hashes + assert code_ctx1.hashing_code_context_hash != code_ctx2.hashing_code_context_hash + assert code_ctx1.hashing_code_context != code_ctx2.hashing_code_context -def test_hashing_code_context_format_is_markdown() -> None: +def test_hashing_code_context_format_is_markdown(temp_dir: Path) -> None: """Test that hashing context is formatted as markdown.""" code = """ class SimpleClass: def simple_method(self): return 42 """ - with tempfile.NamedTemporaryFile(mode="w") as f: + file_path = temp_dir / "test_hashing_code_context_format_is_markdown.py" + with open(file_path, "w") as f: f.write(code) - f.flush() - file_path = Path(f.name).resolve() - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) + file_path = file_path.resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), ) - function_to_optimize = FunctionToOptimize( - function_name="simple_method", - file_path=file_path, - parents=[FunctionParent(name="SimpleClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - hashing_context = code_ctx.hashing_code_context + ) + function_to_optimize = FunctionToOptimize( + function_name="simple_method", + file_path=file_path, + parents=[FunctionParent(name="SimpleClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) - # Should be formatted as markdown code block - assert hashing_context.startswith("```python:") - assert hashing_context.endswith("```") + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = code_ctx.hashing_code_context - # Should contain the relative file path in the markdown header - relative_path = file_path.relative_to(opt.args.project_root) - assert str(relative_path) in hashing_context - - # Should contain the actual code between the markdown markers - lines = hashing_context.strip().split("\n") - assert lines[0].startswith("```python:") - assert lines[-1] == "```" - - # Code should be between the markers - code_lines = lines[1:-1] - code_content = "\n".join(code_lines) - assert "class SimpleClass:" in code_content - assert "def simple_method(self):" in code_content - assert "return 42" in code_content + # Should be formatted as markdown code block + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + + # Should contain the relative file path in the markdown header + relative_path = file_path.relative_to(opt.args.project_root) + assert str(relative_path) in hashing_context + + # Should contain the actual code between the markdown markers + lines = hashing_context.strip().split("\n") + assert lines[0].startswith("```python:") + assert lines[-1] == "```" + + # Code should be between the markers + code_lines = lines[1:-1] + code_content = "\n".join(code_lines) + assert "class SimpleClass:" in code_content + assert "def simple_method(self):" in code_content + assert "return 42" in code_content diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index a10f50a5..2e7a0efb 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -254,7 +254,7 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None: def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None: - site_packages = [Path("/usr/local/lib/python3.9/site-packages")] + site_packages = [Path("/usr/local/lib/python3.9/site-packages").resolve()] monkeypatch.setattr(site, "getsitepackages", lambda: site_packages) file_path = Path("/usr/local/lib/python3.9/site-packages/some_package") diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 469d1be6..03fdf94e 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -42,7 +42,7 @@ def test_example_test_3(self): class MyClass: def __init__(self): self.x = 2 - print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END") + print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END") """ test_file_name = "test_stack_info_temp.py" @@ -117,7 +117,7 @@ def test_example_test_3(self): class MyClass: def __init__(self): self.x = 2 - print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END") + print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END") """ test_file_name = "test_stack_info_temp.py" @@ -181,7 +181,7 @@ def test_example_test_3(self): class MyClass: def __init__(self): self.x = 2 - print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END") + print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END") """ test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() test_file_name = "test_stack_info_temp.py" @@ -261,7 +261,7 @@ class MyClass: def __init__(self): self.x = 2 # Print out the detected test info each time we instantiate MyClass - print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END") + print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END") """ test_file_name = "test_stack_info_recursive_temp.py" @@ -343,7 +343,7 @@ def test_example_test(): class MyClass: def __init__(self): self.x = 2 - print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END") + print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END") """ test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() test_file_name = "test_stack_info_temp.py" @@ -410,10 +410,11 @@ def test_example_test_3(self): self.assertTrue(True) """ test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + tmp_dir_path = get_run_tmp_file(Path("test_return_values")) sample_code = f""" from codeflash.verification.codeflash_capture import codeflash_capture class MyClass: - @codeflash_capture(function_name="some_function", tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", tests_root="{test_dir!s}") + @codeflash_capture(function_name="some_function", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}") def __init__(self, x=2): self.x = x """ @@ -528,6 +529,7 @@ def test_example_test_3(self): self.assertTrue(True) """ test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + tmp_dir_path = get_run_tmp_file(Path("test_return_values")) # MyClass did not have an init function, we created the init function with the codeflash_capture decorator using instrumentation sample_code = f""" from codeflash.verification.codeflash_capture import codeflash_capture @@ -536,7 +538,7 @@ def __init__(self): self.x = 2 class MyClass(ParentClass): - @codeflash_capture(function_name="some_function", tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", tests_root="{test_dir!s}") + @codeflash_capture(function_name="some_function", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) """ @@ -648,14 +650,15 @@ def test_example_test(): """ test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + tmp_dir_path = get_run_tmp_file(Path("test_return_values")) sample_code = f""" from codeflash.verification.codeflash_capture import codeflash_capture class MyClass: @codeflash_capture( function_name="some_function", - tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", - tests_root="{test_dir!s}" + tmp_dir_path="{tmp_dir_path.as_posix()}", + tests_root="{test_dir.as_posix()}" ) def __init__(self, x=2): self.x = x @@ -765,13 +768,14 @@ def test_helper_classes(): assert MyClass().target_function() == 6 """ test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + tmp_dir_path = get_run_tmp_file(Path("test_return_values")) original_code = f""" from codeflash.verification.codeflash_capture import codeflash_capture from code_to_optimize.tests.pytest.helper_file_1 import HelperClass1 from code_to_optimize.tests.pytest.helper_file_2 import HelperClass2, AnotherHelperClass class MyClass: - @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}" , is_fto=True) + @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}" , is_fto=True) def __init__(self): self.x = 1 @@ -785,7 +789,7 @@ def target_function(self): from codeflash.verification.codeflash_capture import codeflash_capture class HelperClass1: - @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False) + @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False) def __init__(self): self.y = 1 @@ -797,7 +801,7 @@ def helper1(self): from codeflash.verification.codeflash_capture import codeflash_capture class HelperClass2: - @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False) + @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False) def __init__(self): self.z = 2 @@ -805,7 +809,7 @@ def helper2(self): return 2 class AnotherHelperClass: - @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False) + @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index fbd7d0b9..79ad1438 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -13,6 +13,11 @@ from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield Path(tmpdirname) + def test_remove_duplicate_imports(): """Test that duplicate imports are removed when should_sort_imports is True.""" original_code = "import os\nimport os\n" @@ -36,17 +41,15 @@ def test_sorting_imports(): assert new_code == "import os\nimport sys\nimport unittest\n" -def test_sort_imports_without_formatting(): +def test_sort_imports_without_formatting(temp_dir): """Test that imports are sorted when formatting is disabled and should_sort_imports is True.""" - with tempfile.NamedTemporaryFile() as tmp: - tmp.write(b"import sys\nimport unittest\nimport os\n") - tmp.flush() - tmp_path = Path(tmp.name) + temp_file = temp_dir / "test_file.py" + temp_file.write_text("import sys\nimport unittest\nimport os\n") - new_code = format_code(formatter_cmds=["disabled"], path=tmp_path) - assert new_code is not None - new_code = sort_imports(new_code) - assert new_code == "import os\nimport sys\nimport unittest\n" + new_code = format_code(formatter_cmds=["disabled"], path=temp_file) + assert new_code is not None + new_code = sort_imports(new_code) + assert new_code == "import os\nimport sys\nimport unittest\n" def test_dedup_and_sort_imports_deduplicates(): @@ -100,7 +103,7 @@ def foo(): assert actual == expected -def test_formatter_cmds_non_existent(): +def test_formatter_cmds_non_existent(temp_dir): """Test that default formatter-cmds is used when it doesn't exist in the toml.""" config_data = """ [tool.codeflash] @@ -109,24 +112,18 @@ def test_formatter_cmds_non_existent(): test-framework = "pytest" ignore-paths = [] """ + config_file = temp_dir / "config.toml" + config_file.write_text(config_data) - with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp: - tmp.write(config_data.encode()) - tmp.flush() - tmp_path = Path(tmp.name) - - try: - config, _ = parse_config_file(tmp_path) - assert config["formatter_cmds"] == ["black $file"] - finally: - os.remove(tmp_path) + config, _ = parse_config_file(config_file) + assert config["formatter_cmds"] == ["black $file"] try: import black except ImportError: pytest.skip("black is not installed") - original_code = b""" + original_code = """ import os import sys def foo(): @@ -136,23 +133,21 @@ def foo(): def foo(): - return os.path.join(sys.path[0], "bar") + return os.path.join(sys.path[0], \"bar\") """ - with tempfile.NamedTemporaryFile() as tmp: - tmp.write(original_code) - tmp.flush() - tmp_path = tmp.name + temp_file = temp_dir / "test_file.py" + temp_file.write_text(original_code) - actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path)) - assert actual == expected + actual = format_code(formatter_cmds=["black $file"], path=temp_file) + assert actual == expected -def test_formatter_black(): +def test_formatter_black(temp_dir): try: import black except ImportError: pytest.skip("black is not installed") - original_code = b""" + original_code = """ import os import sys def foo(): @@ -162,23 +157,21 @@ def foo(): def foo(): - return os.path.join(sys.path[0], "bar") + return os.path.join(sys.path[0], \"bar\") """ - with tempfile.NamedTemporaryFile() as tmp: - tmp.write(original_code) - tmp.flush() - tmp_path = tmp.name + temp_file = temp_dir / "test_file.py" + temp_file.write_text(original_code) - actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path)) - assert actual == expected + actual = format_code(formatter_cmds=["black $file"], path=temp_file) + assert actual == expected -def test_formatter_ruff(): +def test_formatter_ruff(temp_dir): try: import ruff # type: ignore except ImportError: pytest.skip("ruff is not installed") - original_code = b""" + original_code = """ import os import sys def foo(): @@ -188,32 +181,29 @@ def foo(): def foo(): - return os.path.join(sys.path[0], "bar") + return os.path.join(sys.path[0], \"bar\") """ - with tempfile.NamedTemporaryFile(suffix=".py") as tmp: - tmp.write(original_code) - tmp.flush() - tmp_path = tmp.name + temp_file = temp_dir / "test_file.py" + temp_file.write_text(original_code) - actual = format_code( - formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=Path(tmp_path) - ) - assert actual == expected + actual = format_code( + formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file + ) + assert actual == expected -def test_formatter_error(): +def test_formatter_error(temp_dir): original_code = """ import os import sys def foo(): return os.path.join(sys.path[0], 'bar')""" expected = original_code - with tempfile.NamedTemporaryFile("w") as tmp: - tmp.write(original_code) - tmp.flush() - tmp_path = tmp.name - with pytest.raises(FileNotFoundError): - format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + temp_file = temp_dir / "test_file.py" + temp_file.write_text(original_code) + + with pytest.raises(FileNotFoundError): + format_code(formatter_cmds=["exit 1"], path=temp_file) def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""): diff --git a/tests/test_function_discovery.py b/tests/test_function_discovery.py index 291b4270..49a67ba9 100644 --- a/tests/test_function_discovery.py +++ b/tests/test_function_discovery.py @@ -21,11 +21,15 @@ def test_function_eligible_for_optimization() -> None: return a**2 """ functions_found = {} - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write(function) - f.flush() - functions_found = find_all_functions_in_file(Path(f.name)) - assert functions_found[Path(f.name)][0].function_name == "test_function_eligible_for_optimization" + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write(function) + + functions_found = find_all_functions_in_file(file_path) + assert functions_found[file_path][0].function_name == "test_function_eligible_for_optimization" # Has no return statement function = """def test_function_not_eligible_for_optimization(): @@ -33,28 +37,40 @@ def test_function_eligible_for_optimization() -> None: print(a) """ functions_found = {} - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write(function) - f.flush() - functions_found = find_all_functions_in_file(Path(f.name)) - assert len(functions_found[Path(f.name)]) == 0 + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write(function) + + functions_found = find_all_functions_in_file(file_path) + assert len(functions_found[file_path]) == 0 # we want to trigger an error in the function discovery function = """def test_invalid_code():""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write(function) - f.flush() - functions_found = find_all_functions_in_file(Path(f.name)) + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write(function) + + functions_found = find_all_functions_in_file(file_path) assert functions_found == {} def test_find_top_level_function_or_method(): - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write( - """def functionA(): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write( + """def functionA(): def functionB(): return 5 class E: @@ -76,42 +92,48 @@ def functionE(cls, num): def non_classmethod_function(cls, name): return cls.name """ - ) - f.flush() - path_obj_name = Path(f.name) - assert inspect_top_level_functions_or_methods(path_obj_name, "functionA").is_top_level - assert not inspect_top_level_functions_or_methods(path_obj_name, "functionB").is_top_level - assert inspect_top_level_functions_or_methods(path_obj_name, "functionC", class_name="A").is_top_level - assert not inspect_top_level_functions_or_methods(path_obj_name, "functionD", class_name="A").is_top_level - assert not inspect_top_level_functions_or_methods(path_obj_name, "functionF", class_name="E").is_top_level - assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA").has_args + ) + + assert inspect_top_level_functions_or_methods(file_path, "functionA").is_top_level + assert not inspect_top_level_functions_or_methods(file_path, "functionB").is_top_level + assert inspect_top_level_functions_or_methods(file_path, "functionC", class_name="A").is_top_level + assert not inspect_top_level_functions_or_methods(file_path, "functionD", class_name="A").is_top_level + assert not inspect_top_level_functions_or_methods(file_path, "functionF", class_name="E").is_top_level + assert not inspect_top_level_functions_or_methods(file_path, "functionA").has_args staticmethod_func = inspect_top_level_functions_or_methods( - path_obj_name, "handle_record_counts", class_name=None, line_no=15 + file_path, "handle_record_counts", class_name=None, line_no=15 ) assert staticmethod_func.is_staticmethod assert staticmethod_func.staticmethod_class_name == "AirbyteEntrypoint" assert inspect_top_level_functions_or_methods( - path_obj_name, "functionE", class_name="AirbyteEntrypoint" + file_path, "functionE", class_name="AirbyteEntrypoint" ).is_classmethod assert not inspect_top_level_functions_or_methods( - path_obj_name, "non_classmethod_function", class_name="AirbyteEntrypoint" + file_path, "non_classmethod_function", class_name="AirbyteEntrypoint" ).is_top_level # needed because this will be traced with a class_name being passed # we want to write invalid code to ensure that the function discovery does not crash - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write( - """def functionA(): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write( + """def functionA(): """ - ) - f.flush() - path_obj_name = Path(f.name) - assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA") + ) + + assert not inspect_top_level_functions_or_methods(file_path, "functionA") def test_class_method_discovery(): - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write( - """class A: + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write( + """class A: def functionA(): return True def functionB(): @@ -123,21 +145,20 @@ def functionB(): return False def functionA(): return True""" - ) - f.flush() + ) + test_config = TestConfig( tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) - path_obj_name = Path(f.name) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, - file=path_obj_name, + file=file_path, only_get_this_function="A.functionA", test_cfg=test_config, ignore_paths=[Path("/bruh/")], - project_root=path_obj_name.parent, - module_root=path_obj_name.parent, + project_root=file_path.parent, + module_root=file_path.parent, ) assert len(functions) == 1 for file in functions: @@ -148,12 +169,12 @@ def functionA(): functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, - file=path_obj_name, + file=file_path, only_get_this_function="X.functionA", test_cfg=test_config, ignore_paths=[Path("/bruh/")], - project_root=path_obj_name.parent, - module_root=path_obj_name.parent, + project_root=file_path.parent, + module_root=file_path.parent, ) assert len(functions) == 1 for file in functions: @@ -164,12 +185,12 @@ def functionA(): functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, - file=path_obj_name, + file=file_path, only_get_this_function="functionA", test_cfg=test_config, ignore_paths=[Path("/bruh/")], - project_root=path_obj_name.parent, - module_root=path_obj_name.parent, + project_root=file_path.parent, + module_root=file_path.parent, ) assert len(functions) == 1 for file in functions: @@ -178,8 +199,12 @@ def functionA(): def test_nested_function(): - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write( + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write( """ import copy @@ -223,28 +248,31 @@ def traverse(node_id): traverse(source_node_id) return modified_nodes """ - ) - f.flush() + ) + test_config = TestConfig( tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) - path_obj_name = Path(f.name) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, - file=path_obj_name, + file=file_path, test_cfg=test_config, only_get_this_function=None, ignore_paths=[Path("/bruh/")], - project_root=path_obj_name.parent, - module_root=path_obj_name.parent, + project_root=file_path.parent, + module_root=file_path.parent, ) assert len(functions) == 1 assert functions_count == 1 - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write( + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write( """ def outer_function(): def inner_function(): @@ -252,28 +280,31 @@ def inner_function(): return inner_function """ - ) - f.flush() + ) + test_config = TestConfig( tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) - path_obj_name = Path(f.name) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, - file=path_obj_name, + file=file_path, test_cfg=test_config, only_get_this_function=None, ignore_paths=[Path("/bruh/")], - project_root=path_obj_name.parent, - module_root=path_obj_name.parent, + project_root=file_path.parent, + module_root=file_path.parent, ) assert len(functions) == 1 assert functions_count == 1 - with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: - f.write( + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + file_path = temp_dir_path / "test_function.py" + + with file_path.open("w") as f: + f.write( """ def outer_function(): def inner_function(): @@ -283,21 +314,20 @@ def another_inner_function(): pass return inner_function, another_inner_function """ - ) - f.flush() + ) + test_config = TestConfig( tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() ) - path_obj_name = Path(f.name) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, replay_test=None, - file=path_obj_name, + file=file_path, test_cfg=test_config, only_get_this_function=None, ignore_paths=[Path("/bruh/")], - project_root=path_obj_name.parent, - module_root=path_obj_name.parent, + project_root=file_path.parent, + module_root=file_path.parent, ) assert len(functions) == 1 diff --git a/tests/test_get_code.py b/tests/test_get_code.py index 25706f70..03916859 100644 --- a/tests/test_get_code.py +++ b/tests/test_get_code.py @@ -3,13 +3,20 @@ from codeflash.code_utils.code_extractor import get_code from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent +import pytest +from pathlib import Path +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield Path(tmpdirname) -def test_get_code_function() -> None: + +def test_get_code_function(temp_dir: Path) -> None: code = """def test(self): return self._test""" - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() @@ -18,14 +25,14 @@ def test_get_code_function() -> None: assert contextual_dunder_methods == set() -def test_get_code_property() -> None: +def test_get_code_property(temp_dir: Path) -> None: code = """class TestClass: def __init__(self): self._test = 5 @property def test(self): return self._test""" - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() @@ -36,7 +43,7 @@ def test(self): assert contextual_dunder_methods == {("TestClass", "__init__")} -def test_get_code_class() -> None: +def test_get_code_class(temp_dir: Path) -> None: code = """ class TestClass: def __init__(self): @@ -54,7 +61,7 @@ def __init__(self): @property def test(self): return self._test""" - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() @@ -65,7 +72,7 @@ def test(self): assert contextual_dunder_methods == {("TestClass", "__init__")} -def test_get_code_bubble_sort_class() -> None: +def test_get_code_bubble_sort_class(temp_dir: Path) -> None: code = """ def hi(): pass @@ -105,7 +112,7 @@ def sorter(self, arr): arr[j + 1] = temp return arr """ - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() @@ -116,7 +123,7 @@ def sorter(self, arr): assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")} -def test_get_code_indent() -> None: +def test_get_code_indent(temp_dir: Path) -> None: code = """def hi(): pass @@ -168,7 +175,7 @@ def sorter(self, arr): def helper(self, arr, j): return arr[j] > arr[j + 1] """ - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() new_code, contextual_dunder_methods = get_code( @@ -198,7 +205,7 @@ def helper(self, arr, j): def unsorter(self, arr): return shuffle(arr) """ - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() new_code, contextual_dunder_methods = get_code( @@ -212,7 +219,7 @@ def unsorter(self, arr): assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")} -def test_get_code_multiline_class_def() -> None: +def test_get_code_multiline_class_def(temp_dir: Path) -> None: code = """class StatementAssignmentVariableConstantMutable( StatementAssignmentVariableMixin, StatementAssignmentVariableConstantMutableBase ): @@ -235,7 +242,7 @@ def hasVeryTrustedValue(): def computeStatement(self, trace_collection): return self, None, None """ - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() @@ -252,13 +259,13 @@ def computeStatement(self, trace_collection): assert contextual_dunder_methods == set() -def test_get_code_dataclass_attribute(): +def test_get_code_dataclass_attribute(temp_dir: Path) -> None: code = """@dataclass class CustomDataClass: name: str = "" data: List[int] = field(default_factory=list)""" - with tempfile.NamedTemporaryFile("w") as f: + with (temp_dir / "temp_file.py").open(mode="w") as f: f.write(code) f.flush() diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index 36359d3e..a6c30031 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -213,11 +213,12 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: lifespan=self.__duration__, ) ''' - with tempfile.NamedTemporaryFile(mode="w") as f: - f.write(code) - f.flush() - file_path = Path(f.name).resolve() - project_root_path = file_path.parent.resolve() + with tempfile.TemporaryDirectory() as tempdir: + tempdir_path = Path(tempdir) + file_path = (tempdir_path / "typed_code_helper.py").resolve() + file_path.write_text(code, encoding="utf-8") + project_root_path = tempdir_path.resolve() + project_root_path = tempdir_path.resolve() function_to_optimize = FunctionToOptimize( function_name="__call__", file_path=file_path, @@ -437,4 +438,4 @@ def sorter_deps(arr): code_context.helper_functions[0].fully_qualified_name == "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer" ) - assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap" + assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap" \ No newline at end of file diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 7e1a20f4..a54f1060 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -123,7 +123,7 @@ def test_sort(): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") with test_path.open("w") as f: @@ -276,16 +276,16 @@ def test_sort(): fto = FunctionToOptimize( function_name="sorter", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path) ) - with tempfile.NamedTemporaryFile(mode="w") as f: - f.write(code) - f.flush() + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_test_path = Path(tmpdirname) / "test_class_method_behavior_results_temp.py" + tmp_test_path.write_text(code, encoding="utf-8") success, new_test = inject_profiling_into_existing_test( - Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest" + tmp_test_path, [CodePosition(7, 13), CodePosition(12, 13)], fto, tmp_test_path.parent, "pytest" ) assert success assert new_test.replace('"', "'") == expected.format( - module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) + module_path=tmp_test_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix() ).replace('"', "'") tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() test_path = tests_root / "test_class_method_behavior_results_temp.py" @@ -295,7 +295,7 @@ def test_sort(): try: new_test = expected.format( module_path="code_to_optimize.tests.pytest.test_class_method_behavior_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ) with test_path.open("w") as f: @@ -486,4 +486,4 @@ def sorter(self, arr): finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) - test_path_perf.unlink(missing_ok=True) + test_path_perf.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/test_instrument_codeflash_capture.py b/tests/test_instrument_codeflash_capture.py index fe5a6bcd..df5bdbee 100644 --- a/tests/test_instrument_codeflash_capture.py +++ b/tests/test_instrument_codeflash_capture.py @@ -22,7 +22,7 @@ def target_function(self): class MyClass: - @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True) + @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True) def __init__(self): self.x = 1 @@ -86,7 +86,7 @@ def target_function(self): class MyClass(ParentClass): - @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True) + @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -128,7 +128,7 @@ def helper(self): class MyClass: - @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True) + @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True) def __init__(self): self.x = 1 @@ -184,7 +184,7 @@ def helper(self): class MyClass: - @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True) + @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True) def __init__(self): self.x = 1 @@ -197,7 +197,7 @@ def target_function(self): class HelperClass: - @codeflash_capture(function_name='HelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False) + @codeflash_capture(function_name='HelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False) def __init__(self): self.y = 1 @@ -271,7 +271,7 @@ def another_helper(self): class MyClass: - @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True) + @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True) def __init__(self): self.x = 1 @@ -289,7 +289,7 @@ def target_function(self): class HelperClass1: - @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False) + @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False) def __init__(self): self.y = 1 @@ -304,7 +304,7 @@ def helper1(self): class HelperClass2: - @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False) + @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False) def __init__(self): self.z = 2 @@ -313,7 +313,7 @@ def helper2(self): class AnotherHelperClass: - @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False) + @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index a117c220..f8141086 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -173,25 +173,28 @@ def test_sort(self): self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input), list(range(5000))) codeflash_con.close() """ - with tempfile.NamedTemporaryFile(mode="w") as f: - f.write(code) - f.flush() - func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path(f.name)) + with tempfile.TemporaryDirectory() as tmpdir: + p = Path(tmpdir) + test_file_path = p / "test_bubble_sort.py" + with open(test_file_path, "w") as f: + f.write(code) + + func = FunctionToOptimize(function_name="sorter", parents=[], file_path=test_file_path) original_cwd = Path.cwd() run_cwd = Path(__file__).parent.parent.resolve() os.chdir(run_cwd) success, new_test = inject_profiling_into_existing_test( - Path(f.name), + test_file_path, [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, - Path(f.name).parent, + test_file_path.parent, "unittest", ) os.chdir(original_cwd) - assert success - assert new_test == expected.format( - module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) - ) + assert success + assert new_test == expected.format( + module_path=test_file_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix() + ) def test_perfinjector_only_replay_test() -> None: @@ -277,21 +280,22 @@ def test_prepare_image_for_yolo(): assert compare_results(return_val_1, ret) codeflash_con.close() """ - with tempfile.NamedTemporaryFile(mode="w") as f: - f.write(code) - f.flush() + with tempfile.TemporaryDirectory() as tmpdir: + test_file_path = Path(tmpdir) / "test_prepare_image_for_yolo.py" + with open(test_file_path, "w") as f: + f.write(code) func = FunctionToOptimize(function_name="prepare_image_for_yolo", parents=[], file_path=Path("module.py")) original_cwd = Path.cwd() run_cwd = Path(__file__).parent.parent.resolve() os.chdir(run_cwd) success, new_test = inject_profiling_into_existing_test( - Path(f.name), [CodePosition(10, 14)], func, Path(f.name).parent, "pytest" + test_file_path, [CodePosition(10, 14)], func, test_file_path.parent, "pytest" ) os.chdir(original_cwd) - assert success - assert new_test == expected.format( - module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) - ) + assert success + assert new_test == expected.format( + module_path=test_file_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix() + ) def test_perfinjector_bubble_sort_results() -> None: @@ -397,7 +401,7 @@ def test_sort(): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") success, new_perf_test = inject_profiling_into_existing_test( @@ -412,7 +416,7 @@ def test_sort(): assert new_perf_test is not None assert new_perf_test.replace('"', "'") == expected_perfonly.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") with test_path.open("w") as f: @@ -651,11 +655,11 @@ def test_sort_parametrized(input, expected_output): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") assert new_test_perf.replace('"', "'") == expected_perfonly.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # # Overwrite old test with new instrumented test @@ -924,7 +928,7 @@ def test_sort_parametrized_loop(input, expected_output): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # Overwrite old test with new instrumented test @@ -933,7 +937,7 @@ def test_sort_parametrized_loop(input, expected_output): assert new_test_perf.replace('"', "'") == expected_perf.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # Overwrite old test with new instrumented test @@ -1279,12 +1283,12 @@ def test_sort(): assert new_test_behavior is not None assert new_test_behavior.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") assert new_test_perf.replace('"', "'") == expected_perf.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # Overwrite old test with new instrumented test @@ -1588,11 +1592,11 @@ def test_sort(self): assert new_test_behavior is not None assert new_test_behavior.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") assert new_test_perf.replace('"', "'") == expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # # Overwrite old test with new instrumented test @@ -1845,13 +1849,13 @@ def test_sort(self, input, expected_output): assert new_test_behavior is not None assert new_test_behavior.replace('"', "'") == expected_behavior.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") assert new_test_perf is not None assert new_test_perf.replace('"', "'") == expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # @@ -2111,11 +2115,11 @@ def test_sort(self): assert new_test_behavior is not None assert new_test_behavior.replace('"', "'") == expected_behavior.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") assert new_test_perf.replace('"', "'") == expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # # # Overwrite old test with new instrumented test @@ -2370,11 +2374,11 @@ def test_sort(self, input, expected_output): assert new_test_behavior is not None assert new_test_behavior.replace('"', "'") == expected_behavior.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") assert new_test_perf.replace('"', "'") == expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # # Overwrite old test with new instrumented test @@ -2671,7 +2675,7 @@ def test_class_name_A_function_name(): assert success assert new_test is not None assert new_test.replace('"', "'") == expected.format( - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), module_path="tests.pytest.test_class_function_instrumentation_temp", ).replace('"', "'") @@ -2742,7 +2746,7 @@ def test_common_tags_1(): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="tests.pytest.test_wrong_function_instrumentation_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") finally: test_path.unlink(missing_ok=True) @@ -2805,7 +2809,7 @@ def test_sort(): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="tests.pytest.test_conditional_instrumentation_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") finally: test_path.unlink(missing_ok=True) @@ -2882,7 +2886,7 @@ def test_sort(): assert success formatted_expected = expected.format( module_path="tests.pytest.test_perfinjector_bubble_sort_results_temp", - tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))), + tmp_dir_path=str(get_run_tmp_file(Path("test_return_values")).as_posix()), ) assert new_test is not None assert new_test.replace('"', "'") == formatted_expected.replace('"', "'") @@ -2960,24 +2964,25 @@ def test_code_replacement10() -> None: """ ) - with tempfile.NamedTemporaryFile(mode="w") as f: - f.write(code) - f.flush() + with tempfile.TemporaryDirectory() as tmpdir: + test_file_path = Path(tmpdir) / "test_code_replacement10.py" + with open(test_file_path, "w") as f: + f.write(code) func = FunctionToOptimize( function_name="get_code_optimization_context", parents=[FunctionParent("Optimizer", "ClassDef")], - file_path=Path(f.name), + file_path=test_file_path, ) original_cwd = Path.cwd() run_cwd = Path(__file__).parent.parent.resolve() os.chdir(run_cwd) success, new_test = inject_profiling_into_existing_test( - Path(f.name), [CodePosition(22, 28), CodePosition(28, 28)], func, Path(f.name).parent, "pytest" + test_file_path, [CodePosition(22, 28), CodePosition(28, 28)], func, test_file_path.parent, "pytest" ) os.chdir(original_cwd) assert success assert new_test.replace('"', "'") == expected.replace('"', "'").format( - module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) + module_path=test_file_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix() ) @@ -3042,7 +3047,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.pytest.test_time_correction_instrumentation_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # Overwrite old test with new instrumented test with test_path.open("w") as f: @@ -3161,7 +3166,7 @@ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time): assert new_test is not None assert new_test.replace('"', "'") == expected.format( module_path="code_to_optimize.tests.unittest.test_time_correction_instrumentation_unittest_temp", - tmp_dir_path=get_run_tmp_file(Path("test_return_values")), + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), ).replace('"', "'") # Overwrite old test with new instrumented test with test_path.open("w") as f: @@ -3216,4 +3221,4 @@ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time): assert math.isclose(test_result.runtime, ((i % 2) + 1) * 100_000_000, rel_tol=0.01) finally: - test_path.unlink(missing_ok=True) + test_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/test_shell_utils.py b/tests/test_shell_utils.py index 51d48406..d7cb987c 100644 --- a/tests/test_shell_utils.py +++ b/tests/test_shell_utils.py @@ -38,6 +38,12 @@ def setUp(self): self.test_rc_path = "test_shell_rc" self.api_key = "cf-1234567890abcdef" os.environ["SHELL"] = "/bin/bash" # Set a default shell for testing + + # Set up platform-specific export syntax + if os.name == "nt": # Windows + self.api_key_export = f'set CODEFLASH_API_KEY={self.api_key}' + else: # Unix-like systems + self.api_key_export = f'export CODEFLASH_API_KEY="{self.api_key}"' def tearDown(self): """Cleanup the temporary shell configuration file after testing.""" @@ -50,7 +56,7 @@ def test_valid_api_key(self): with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path: mock_get_shell_rc_path.return_value = self.test_rc_path with patch( - "builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n') + "builtins.open", mock_open(read_data=f'{self.api_key_export}\n') ) as mock_file: self.assertEqual(read_api_key_from_shell_config(), self.api_key) mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8") @@ -91,9 +97,15 @@ def test_malformed_api_key_export(self, mock_get_shell_rc_path): def test_multiple_api_key_exports(self, mock_get_shell_rc_path): """Test with multiple API key exports.""" mock_get_shell_rc_path.return_value = self.test_rc_path + if os.name == "nt": # Windows + first_export = 'set CODEFLASH_API_KEY=cf-firstkey' + second_export = f'set CODEFLASH_API_KEY={self.api_key}' + else: # Unix-like systems + first_export = 'export CODEFLASH_API_KEY="cf-firstkey"' + second_export = f'export CODEFLASH_API_KEY="{self.api_key}"' with patch( "builtins.open", - mock_open(read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'), + mock_open(read_data=f'{first_export}\n{second_export}\n'), ): self.assertEqual(read_api_key_from_shell_config(), self.api_key) @@ -103,7 +115,7 @@ def test_api_key_export_with_extra_text(self, mock_get_shell_rc_path): mock_get_shell_rc_path.return_value = self.test_rc_path with patch( "builtins.open", - mock_open(read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'), + mock_open(read_data=f'# Setting API Key\n{self.api_key_export}\n# Done\n'), ): self.assertEqual(read_api_key_from_shell_config(), self.api_key) @@ -111,7 +123,7 @@ def test_api_key_export_with_extra_text(self, mock_get_shell_rc_path): def test_api_key_in_comment(self, mock_get_shell_rc_path): """Test with API key export in a comment.""" mock_get_shell_rc_path.return_value = self.test_rc_path - with patch("builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')): + with patch("builtins.open", mock_open(read_data=f'# {self.api_key_export}\n')): self.assertIsNone(read_api_key_from_shell_config()) @patch("codeflash.code_utils.shell_utils.get_shell_rc_path") diff --git a/tests/test_trace_benchmarks.py b/tests/test_trace_benchmarks.py index 2d5a3c6e..cc037c0b 100644 --- a/tests/test_trace_benchmarks.py +++ b/tests/test_trace_benchmarks.py @@ -154,7 +154,7 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__(): from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle functions = ['compute_and_sort', 'sorter'] -trace_file_path = r"{output_file}" +trace_file_path = r"{output_file.as_posix()}" def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(): for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100): @@ -196,6 +196,8 @@ def test_trace_multithreaded_benchmark() -> None: "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name") function_calls = cursor.fetchall() + conn.close() + # Assert the length of function calls assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}" function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) @@ -204,9 +206,9 @@ def test_trace_multithreaded_benchmark() -> None: assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0] - assert total_time > 0.0 - assert function_time > 0.0 - assert percent > 0.0 + assert total_time >= 0.0 + assert function_time >= 0.0 + assert percent >= 0.0 bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix() # Expected function calls diff --git a/tests/test_tracer.py b/tests/test_tracer.py index f9c2ae23..2b29f5d8 100644 --- a/tests/test_tracer.py +++ b/tests/test_tracer.py @@ -72,8 +72,8 @@ def trace_config(self) -> Generator[Path, None, None]: with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False, dir=temp_dir) as f: f.write(f""" [tool.codeflash] -module-root = "{current_dir}" -tests-root = "{tests_dir}" +module-root = "{current_dir.as_posix()}" +tests-root = "{tests_dir.as_posix()}" test-framework = "pytest" ignore-paths = [] """)