diff --git a/cli/cli.py b/cli/cli.py index 4861347..dd3755e 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -12,18 +12,21 @@ from .constants import ( CONSOLE, - ALL_PKGS, VALID, CROSS, WORKDIR, TOOLS_DIR, UNCHANGED, + tla2tools, + community_modules, + tlc, + repl, ) from .utils import AliasedGroup, error_handler # Create a mapping for faster lookup -pkg_map = {p.name: p for p in ALL_PKGS} +pkg_map = {p.name: p for p in [tla2tools, community_modules]} @click.group( @@ -52,8 +55,8 @@ def cli(ctx: click.Context, manifest: Path) -> None: TOOLS_DIR.mkdir(exist_ok=True) if ctx.invoked_subcommand is None: - if manifest is None: - pkg_map["TLA2Tools"].tools["REPL"].start() + if manifest is None and repl.is_available(): + repl.start() else: from .models import Manifest @@ -90,7 +93,7 @@ def tla_package_list() -> None: table.add_column("Version", justify="center") table.add_column("Up-to-date", justify="center") - for pkg in ALL_PKGS: + for pkg in pkg_map.values(): if pkg.is_installed: table.add_row( pkg.name, @@ -110,7 +113,7 @@ def tla_package_list() -> None: metavar="PACKAGE_SPEC", nargs=-1, type=str, - default=[p.name for p in ALL_PKGS], + default=[p.name for p in pkg_map.values()], ) @error_handler def tla_package_install(pkg_specs: list[str]) -> None: @@ -165,7 +168,7 @@ def tla_package_install(pkg_specs: list[str]) -> None: metavar="PACKAGE_NAME", nargs=-1, type=str, - default=[p.name for p in ALL_PKGS], + default=[p.name for p in pkg_map.values()], ) @error_handler def tla_package_upgrade(pkg_names: list[str]) -> None: @@ -207,7 +210,7 @@ def tla_package_upgrade(pkg_names: list[str]) -> None: metavar="PACKAGE_NAME", nargs=-1, type=str, - default=[p.name for p in ALL_PKGS], + default=[p.name for p in pkg_map.values()], ) @error_handler def tla_package_uninstall(pkg_names: list[str]) -> None: @@ -283,6 +286,9 @@ def tla_package_uninstall(pkg_names: list[str]) -> None: type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=Path), help="Path to the TLC configuration file (.cfg). If not provided, it is assumed to be alongside the module file with a .cfg extension.", ) +@click.option( + "--save-states", "-s", is_flag=True, help="Wheither to save the state graph." +) @error_handler def tla_model_check( module_path: Path, @@ -291,6 +297,7 @@ def tla_model_check( max_heap_size: str, community_modules: bool, external_module: list[Path], + save_states: bool, ) -> None: """ Run the TLC model checker on a given TLA+ module file. @@ -301,7 +308,6 @@ def tla_model_check( f"External module '{ext_module}' must be a .jar file." ) - tlc = pkg_map["TLA2Tools"].tools["TLC"] model_path = model_path if model_path else module_path.with_suffix(".cfg") tlc.start( module_path, @@ -310,6 +316,7 @@ def tla_model_check( max_heap_size=max_heap_size, community_modules=community_modules, external_modules=external_module, + save_states=save_states, ) diff --git a/cli/constants.py b/cli/constants.py index 5197367..69cfeee 100644 --- a/cli/constants.py +++ b/cli/constants.py @@ -5,12 +5,18 @@ from rich.console import Console from rich.logging import RichHandler -from .packages import Package, TLA2Tools, CommunityModules +from .packages import TLA2Tools, CommunityModules +from .tools import TLC, REPL +VALID = "[green]✓[/green]" +CROSS = "[red]✗[/red]" +UNCHANGED = "[yellow]~[/yellow]" + PROJECT_DIR = Path(__file__).parent.parent WORKDIR = PROJECT_DIR / ".tla" TOOLS_DIR = WORKDIR / "tools" +RUN_DATA_DIR = WORKDIR / "data" CONSOLE = Console() @@ -22,11 +28,17 @@ ) LOGGER = logging.getLogger("rich") -ALL_PKGS: list[Package] = [ - TLA2Tools(location=TOOLS_DIR), - CommunityModules(location=TOOLS_DIR), -] -VALID = "[green]✓[/green]" -CROSS = "[red]✗[/red]" -UNCHANGED = "[yellow]~[/yellow]" +tla2tools = TLA2Tools(location=TOOLS_DIR, logger=LOGGER, console=CONSOLE) +community_modules = CommunityModules(location=TOOLS_DIR, logger=LOGGER, console=CONSOLE) + +repl = REPL(main_class="tlc2.REPL", pkg=tla2tools, logger=LOGGER, console=CONSOLE) + +tlc = TLC( + main_class="tlc2.TLC", + data_path=RUN_DATA_DIR, + community_modules=community_modules, + pkg=tla2tools, + logger=LOGGER, + console=CONSOLE, +) diff --git a/cli/models.py b/cli/models.py index cf99909..c864ac7 100644 --- a/cli/models.py +++ b/cli/models.py @@ -2,11 +2,12 @@ from datetime import timedelta from pathlib import Path -from typing import List, Optional, Literal +from typing import Optional, Literal +from typing_extensions import Self -from pydantic import BaseModel, FilePath, Field +from pydantic import BaseModel, DirectoryPath, Field, model_validator -from .constants import ALL_PKGS, CONSOLE +from .constants import tlc, CONSOLE class Checks(BaseModel): @@ -43,11 +44,11 @@ class Configuration(BaseModel): Attributes: max_heap_size: Maximum heap size allocated for the JVM in Mio. - cores: Number of CPU cores allocated for model checking. + workers: Number workers used for model checking. """ - max_heap_size: Optional[str] = None - cores: int = 1 + max_heap_size: Optional[str] = "1G" + workers: int = 1 class Model(BaseModel): @@ -65,11 +66,11 @@ class Model(BaseModel): """ name: str - path: FilePath + path: Path runtime: Optional[timedelta] = None type: Literal["explicit", "symbolic"] mode: Literal["exhaustive", "simulation"] - configuration: Optional[Configuration] = None + configuration: Configuration = Configuration() checks: Checks @@ -83,7 +84,7 @@ class Dependencies(BaseModel): """ community_modules: bool = False - additional_modules: List[FilePath] = Field(default_factory=list) + additional_modules: list[Path] = Field(default_factory=list) class Module(BaseModel): @@ -97,9 +98,9 @@ class Module(BaseModel): models: List of models associated with the module. """ - path: FilePath - dependencies: Dependencies - models: List[Model] = [] + path: Path + dependencies: Dependencies = Dependencies() + models: list[Model] = [] class Manifest(BaseModel): @@ -112,12 +113,13 @@ class Manifest(BaseModel): modules: List of TLA+ modules to be processed. """ + base_path: DirectoryPath tlc_version: Optional[str] = None total_duration: Optional[timedelta] = None - modules: List[Module] + modules: list[Module] @classmethod - def load_manifest(cls, path: FilePath) -> "Manifest": + def load_manifest(cls, path: Path) -> "Manifest": """Load a manifest from a JSON file. Args: @@ -127,8 +129,7 @@ def load_manifest(cls, path: FilePath) -> "Manifest": An instance of the Manifest class populated with data from the file. """ with path.open("r") as f: - data = replace_paths(json.loads(f.read()), path.parent) - return cls(**data) + return cls(base_path=path.parent, **json.loads(f.read())) def process(self) -> dict[Path, bool]: """Process all modules and their models as specified in the manifest. @@ -139,8 +140,6 @@ def process(self) -> dict[Path, bool]: """ processing_results = {} - tlc = [p for p in ALL_PKGS if p.name == "TLA2Tools"][0].tools["TLC"] - for module in self.modules: for model in module.models: if model.type == "explicit": @@ -151,9 +150,9 @@ def process(self) -> dict[Path, bool]: external_modules=module.dependencies.additional_modules, # timeout=model.runtime, # mode=model.mode, - workers=model.configuration.cores, + workers=model.configuration.workers, max_heap_size=model.configuration.max_heap_size, - show_log=False, + show_log=True, ) if tlc_run.success == model.checks.success: assertions = ( @@ -198,23 +197,22 @@ def process(self) -> dict[Path, bool]: ) return processing_results + def _check_modify_relative_path(self, path: Path) -> Path: + full_path = path + if not full_path.is_absolute(): + full_path = self.base_path / full_path + if not full_path.is_file(): + raise ValueError(f"File not found: {path}.") + return full_path -def replace_paths(data, base_path: Path): - """Recursively replace 'path' fields in a nested data structure with absolute paths. - - Args: - data: The nested data structure (dicts and lists). - base_path: The base path to prepend to relative paths. - - Returns: - The modified data structure with updated paths. - """ - if isinstance(data, dict): - return { - k: replace_paths(base_path / v if k == "path" else v, base_path) - for k, v in data.items() - } - elif isinstance(data, list): - return [replace_paths(item, base_path) for item in data] - else: - return data + @model_validator(mode="after") + def check_relative_paths(self) -> Self: + for module in self.modules: + module.path = self._check_modify_relative_path(module.path) + module.dependencies.additional_modules = [ + self._check_modify_relative_path(additional_module) + for additional_module in module.dependencies.additional_modules + ] + for model in module.models: + model.path = self._check_modify_relative_path(model.path) + return self diff --git a/cli/packages/__init__.py b/cli/packages/__init__.py index f426e07..df1ce45 100644 --- a/cli/packages/__init__.py +++ b/cli/packages/__init__.py @@ -1,3 +1,4 @@ -from .github import Package, TLA2Tools, CommunityModules +from .base import Package +from .github import GithubReleasePackage, TLA2Tools, CommunityModules -__all__ = ["Package", "TLA2Tools", "CommunityModules"] +__all__ = ["Package", "GithubReleasePackage", "TLA2Tools", "CommunityModules"] diff --git a/cli/packages/base.py b/cli/packages/base.py index 6378f9b..d9c7273 100644 --- a/cli/packages/base.py +++ b/cli/packages/base.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod +from logging import Logger from pathlib import Path +from rich.console import Console from packaging.version import Version -from ..tools import Tool - class Package(ABC): """ @@ -15,7 +15,9 @@ class Package(ABC): location: The installation location of the package. """ - def __init__(self, name: str, location: Path) -> None: + def __init__( + self, name: str, location: Path, logger: Logger, console: Console + ) -> None: """ Initialize a new Package instance. @@ -25,11 +27,8 @@ def __init__(self, name: str, location: Path) -> None: """ self.name = name self.location = location - - @property - @abstractmethod - def tools(self) -> dict[str, Tool]: - pass + self.logger = logger + self.console = console @property @abstractmethod diff --git a/cli/packages/github.py b/cli/packages/github.py index 16674ab..81bf732 100644 --- a/cli/packages/github.py +++ b/cli/packages/github.py @@ -4,15 +4,16 @@ import requests from abc import ABC, abstractmethod +from logging import Logger from pathlib import Path from github import Github from github.Repository import Repository from github.GithubException import UnknownObjectException +from rich.console import Console from packaging.version import parse as parse_version, Version from .base import Package -from ..tools import Tool, REPL, TLC class GithubReleasePackage(Package, ABC): @@ -32,6 +33,8 @@ def __init__( location: Path, repo_name: str, asset_name: str, + logger: Logger, + console: Console, prerelease: bool = False, ) -> None: """ @@ -44,7 +47,7 @@ def __init__( asset_name: The name of the asset to download. prerelease: Whether or not to consider pre-releases when searching for the latest version. """ - super().__init__(name, location) + super().__init__(name=name, location=location, logger=logger, console=console) self.repo_name = repo_name self.asset_name = asset_name self.versions = [] @@ -189,33 +192,18 @@ class TLA2Tools(GithubReleasePackage): Concrete class for the TLA2Tools package. """ - def __init__(self, location: Path) -> None: + def __init__(self, location: Path, logger: Logger, console: Console) -> None: asset_name = "tla2tools.jar" super().__init__( name="TLA2Tools", location=(location / asset_name), repo_name="tlaplus/tlaplus", asset_name=asset_name, + logger=logger, + console=console, prerelease=True, ) - @property - def tools(self) -> dict[str, Tool]: - return { - "REPL": REPL( - classpath=self.location, - main_class="tlc2.REPL", - tla2tools_version=self.current_version, - ), - "TLC": TLC( - classpath=self.location, - main_class="tlc2.TLC", - run_path=self.location.parent, - community_modules_classpath=self.location.parent - / "CommunityModules-deps.jar", - ), - } - def version_to_tag(self, version: Version) -> str: """ Format a version for use in GitHub API calls. @@ -234,19 +222,17 @@ class CommunityModules(GithubReleasePackage): Concrete class for the CommunityModules package. """ - def __init__(self, location: Path) -> None: + def __init__(self, location: Path, logger: Logger, console: Console) -> None: asset_name = "CommunityModules-deps.jar" super().__init__( name="CommunityModules", location=(location / asset_name), repo_name="tlaplus/CommunityModules", asset_name=asset_name, + logger=logger, + console=console, ) - @property - def tools(self) -> dict[str, Tool]: - return {} - def version_to_tag(self, version: Version) -> str: """ Format a version for use in GitHub API calls. diff --git a/cli/tools/base.py b/cli/tools/base.py index 92cca4d..3b55723 100644 --- a/cli/tools/base.py +++ b/cli/tools/base.py @@ -1,6 +1,19 @@ from abc import ABC +from logging import Logger + +from rich.console import Console + +from ..packages import Package class Tool(ABC): - def __init__(self, name: str) -> None: + def __init__( + self, name: str, pkg: Package, logger: Logger, console: Console + ) -> None: self.name = name + self.pkg = pkg + self.logger = logger + self.console = console + + def is_available(self) -> bool: + return self.pkg.is_installed diff --git a/cli/tools/java.py b/cli/tools/java.py index 9187850..4caaa0a 100644 --- a/cli/tools/java.py +++ b/cli/tools/java.py @@ -1,11 +1,14 @@ -from abc import ABC - +from logging import Logger from pathlib import Path from typing import Optional +from rich.console import Console + +from ..packages import Package +from .base import Tool -class JavaClassTool(ABC): +class JavaClassTool(Tool): """Base class for tools that wrap Java command-line applications. Attributes: @@ -16,8 +19,16 @@ class JavaClassTool(ABC): parallel_gc: Whether to enable parallel garbage collection. """ - def __init__(self, name: str, classpath: Path, main_class: str) -> None: - self.name = name + def __init__( + self, + name: str, + classpath: Path, + main_class: str, + pkg: Package, + logger: Logger, + console: Console, + ) -> None: + super().__init__(name, pkg, logger, console) self.classpath = [classpath] self.main_class = main_class self.max_heap_size = "4G" diff --git a/cli/tools/repl.py b/cli/tools/repl.py index ef72078..5266c01 100644 --- a/cli/tools/repl.py +++ b/cli/tools/repl.py @@ -1,10 +1,13 @@ import subprocess import sys -from pathlib import Path +from logging import Logger +from typing import cast +from rich.console import Console from packaging.version import Version +from ..packages import GithubReleasePackage from .java import JavaClassTool @@ -12,25 +15,43 @@ class REPL(JavaClassTool): """TLA+ REPL tool for interactive constant expressions evaluation.""" def __init__( - self, classpath: Path, main_class: str, tla2tools_version: Version + self, + main_class: str, + pkg: GithubReleasePackage, + logger: Logger, + console: Console, ) -> None: - super().__init__(name="REPL", classpath=classpath, main_class=main_class) - self.tla2tools_version = tla2tools_version + super().__init__( + name="REPL", + classpath=pkg.location, + main_class=main_class, + pkg=pkg, + logger=logger, + console=console, + ) + + def is_available(self) -> bool: + if not super().is_available(): + return False + elif cast(Version, self.pkg.current_version) >= Version("v1.8.0"): + return True + return False def start(self) -> None: """Starts the REPL if the version of TLA2Tools supports it.""" - if self.tla2tools_version < Version("v1.8.0"): - return - process = subprocess.Popen( - self.get_java_command(), - stdin=sys.stdin, - stdout=sys.stdout, - stderr=subprocess.PIPE, - text=True, - bufsize=1, - ) - process.wait() - if process.returncode != 0: - raise RuntimeError( - process.stderr.read() if process.stderr else "REPL failed." + try: + process = subprocess.Popen( + self.get_java_command(), + stdin=sys.stdin, + stdout=sys.stdout, + stderr=subprocess.PIPE, + text=True, + bufsize=1, ) + process.wait() + if process.returncode != 0: + raise RuntimeError( + process.stderr.read() if process.stderr else "REPL failed." + ) + except KeyboardInterrupt: + pass diff --git a/cli/tools/tlc.py b/cli/tools/tlc.py index 26a418a..95c6ea2 100644 --- a/cli/tools/tlc.py +++ b/cli/tools/tlc.py @@ -4,9 +4,13 @@ from dataclasses import dataclass, asdict from datetime import datetime, timedelta +from logging import Logger from pathlib import Path from typing import Optional, Any +from rich.console import Console + +from ..packages import GithubReleasePackage from .java import JavaClassTool @@ -77,14 +81,23 @@ class TLC(JavaClassTool): def __init__( self, - classpath: Path, main_class: str, - run_path: Path, - community_modules_classpath: Path, + data_path: Path, + community_modules: GithubReleasePackage, + pkg: GithubReleasePackage, + logger: Logger, + console: Console, ) -> None: - super().__init__(name="TLC", classpath=classpath, main_class=main_class) - self.base_path = run_path - self.community_modules_classpath = community_modules_classpath + super().__init__( + name="TLC", + classpath=pkg.location, + main_class=main_class, + pkg=pkg, + logger=logger, + console=console, + ) + self.base_path = data_path + self.community_modules = community_modules def create_run_dir(self) -> Path: """Creates a new directory for the TLC run.""" @@ -103,6 +116,7 @@ def start( max_heap_size: str, community_modules: bool, external_modules: list[Path], + save_states: bool = False, show_log: bool = True, ) -> TLCRun: """Runs TLC and returns the results. @@ -114,6 +128,7 @@ def start( max_heap_size: Maximum heap size for the JVM. community_modules: Whether to include community modules. external_modules: list of paths to external modules. + save_states: Weither to save the state space graph in a Graphviz .dot file. show_log: Whether to print TLC output to the console. Returns: @@ -126,13 +141,17 @@ def start( self.parallel_gc = True self.max_heap_size = max_heap_size if community_modules: - self.classpath.append(self.community_modules_classpath) + self.classpath.append(self.community_modules.location) if external_modules: self.classpath.extend(external_modules) cmd = self.get_java_command( ["-workers", str(workers), "-config", str(model_path), str(module_path)] ) + + if save_states: + cmd.extend(["-dump", "dot", "states"]) + tlc_run.states_file = run_dir / "states.dot" process = subprocess.Popen( cmd, stdout=subprocess.PIPE, @@ -147,7 +166,7 @@ def start( tlc_output = [] for line in process.stdout: if show_log: - print(line.strip()) + print(line.replace("\n", "")) tlc_output.append(line) process.wait()