Skip to content

Commit

Permalink
Updated solution: better API
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiannis128 committed Feb 4, 2025
1 parent 2d6371c commit e94e342
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
22 changes: 8 additions & 14 deletions esbmc_ai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,31 @@

# Enables arrow key functionality for input(). Do not remove import.
import readline
import argparse

from esbmc_ai.addon_loader import AddonLoader
from esbmc_ai.commands.user_chat_command import UserChatCommand
from esbmc_ai.solution import Solution
from esbmc_ai.verifier_runner import VerifierRunner
from esbmc_ai.verifiers.esbmc import ESBMC

_ = readline

from esbmc_ai.command_runner import CommandRunner
from esbmc_ai.commands.fix_code_command import FixCodeCommandResult


import argparse

from esbmc_ai import Config
from esbmc_ai import __author__, __version__
from esbmc_ai.solution import get_solution

from esbmc_ai.commands import (
ChatCommand,
FixCodeCommand,
HelpCommand,
ExitCommand,
FixCodeCommandResult,
)

from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget
from esbmc_ai.chats import UserChat
from esbmc_ai.logging import printv, printvv
from esbmc_ai.ai_models import _ai_model_names

_ = readline


help_command: HelpCommand = HelpCommand()
fix_code_command: FixCodeCommand = FixCodeCommand()
exit_command: ExitCommand = ExitCommand()
Expand Down Expand Up @@ -104,15 +98,15 @@ def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
# Basic fix mode: Supports only 1 file repair.
case fix_code_command.command_name:
print("Reading source code...")
get_solution().load_source_files(Config().filenames)
solution: Solution = Solution(Config().filenames)
print(f"Running ESBMC with {Config().get_value('verifier.esbmc.params')}\n")

anim: BaseLoadingWidget = (
LoadingWidget()
if Config().get_value("loading_hints")
else BaseLoadingWidget()
)
for source_file in get_solution().files:
for source_file in solution.files:
result: FixCodeCommandResult = (
UserChatCommand._execute_fix_code_command_one_file(
fix_code_command,
Expand Down Expand Up @@ -238,7 +232,7 @@ def main() -> None:
# ===========================================
# Check if command is called and call it.
# If not, then continue to user mode.
if args.command != None:
if args.command is not None:
command = args.command
command_names: list[str] = command_runner.command_names
if command in command_names:
Expand Down
42 changes: 23 additions & 19 deletions esbmc_ai/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,22 @@ def apply_line_patch(source_code: str, patch: str, start: int, end: int) -> str:
lines = lines[:start] + [patch] + lines[end + 1 :]
return "\n".join(lines)

def __init__(self, file_path: Path, content: str) -> None:
@staticmethod
def load(file_path: Path, base_path: Path) -> "SourceFile":
with open(base_path / file_path, "r") as file:
return SourceFile(file_path, base_path, file.read())

def __init__(self, file_path: Path, base_path: Path, content: str) -> None:
self.file_path: Path = file_path
self.base_path: Path = base_path
self.content: str = content
self.verifier_output: Optional[VerifierOutput] = None

@property
def abs_path(self) -> Path:
"""Returns the abs path"""
return self.base_path / self.file_path

@property
def file_extension(self) -> str:
"""Returns the file extension to the file."""
Expand Down Expand Up @@ -126,16 +137,19 @@ def from_dir(path: Path) -> "Solution":

def __init__(
self,
files: Optional[list[Path]] = None,
base_dir: Optional[Path] = None,
files: list[Path],
base_dir: Path = Path(getcwd()),
) -> None:
"""Creates a new solution with a base directory."""
self.base_dir: Path = base_dir if base_dir else Path(getcwd())
files = files if files else []
self.base_dir: Path = base_dir

self._files: list[SourceFile] = []

for file_path in files:
# Get the relative path to the base dir.
rel_path: Path = file_path.relative_to(self.base_dir)
with open(file_path, "r") as file:
self._files.append(SourceFile(file_path, file.read()))
self._files.append(SourceFile(rel_path, self.base_dir, file.read()))

@property
def files(self) -> list[SourceFile]:
Expand All @@ -147,7 +161,7 @@ def files_mapped(self) -> dict[str, SourceFile]:
"""Will return the files mapped to their directory. Returns by value."""
return {str(source_file.file_path): source_file for source_file in self._files}

def get_files(self, included_ext: list[str]) -> list[SourceFile]:
def get_files_by_ext(self, included_ext: list[str]) -> list[SourceFile]:
"""Gets the files that are only specified in the included extensions. File
extensions that have a . prefix are trimmed so they still work."""
return [s for s in self.files if s.file_extension.strip(".") in included_ext]
Expand All @@ -166,7 +180,7 @@ def save_solution(self, path: Path) -> "Solution":
base_dir_path: Path = path
new_file_paths: list[Path] = []
for source_file in self.files:
relative_path: Path = source_file.file_path.relative_to(self.base_dir)
relative_path: Path = source_file.file_path
new_path: Path = base_dir_path / relative_path
# Write new file
new_file_paths.append(new_path)
Expand Down Expand Up @@ -199,14 +213,4 @@ def load_source_file(self, file_path: Path) -> None:
not be loaded."""
assert file_path
with open(file_path, "r") as file:
self._files.append(SourceFile(file_path, file.read()))


# Define a global solution (is not required to be used)

_solution: Solution = Solution()


def get_solution() -> Solution:
"""Returns the global default solution object."""
return _solution
self._files.append(SourceFile(file_path, self.base_dir, file.read()))

0 comments on commit e94e342

Please sign in to comment.