Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@

from .constants import (
CONSOLE,
ALL_PKGS,
VALID,
CROSS,
WORKDIR,
TOOLS_DIR,
UNCHANGED,
tla2tools,
community_modules,
tlc,
repl,
)
from .utils import AliasedGroup, error_handler


# Create a mapping for faster lookup
pkg_map = {p.name: p for p in ALL_PKGS}
pkg_map = {p.name: p for p in [tla2tools, community_modules]}


@click.group(
Expand Down Expand Up @@ -52,8 +55,8 @@ def cli(ctx: click.Context, manifest: Path) -> None:
TOOLS_DIR.mkdir(exist_ok=True)

if ctx.invoked_subcommand is None:
if manifest is None:
pkg_map["TLA2Tools"].tools["REPL"].start()
if manifest is None and repl.is_available():
repl.start()
else:
from .models import Manifest

Expand Down Expand Up @@ -90,7 +93,7 @@ def tla_package_list() -> None:
table.add_column("Version", justify="center")
table.add_column("Up-to-date", justify="center")

for pkg in ALL_PKGS:
for pkg in pkg_map.values():
if pkg.is_installed:
table.add_row(
pkg.name,
Expand All @@ -110,7 +113,7 @@ def tla_package_list() -> None:
metavar="PACKAGE_SPEC",
nargs=-1,
type=str,
default=[p.name for p in ALL_PKGS],
default=[p.name for p in pkg_map.values()],
)
@error_handler
def tla_package_install(pkg_specs: list[str]) -> None:
Expand Down Expand Up @@ -165,7 +168,7 @@ def tla_package_install(pkg_specs: list[str]) -> None:
metavar="PACKAGE_NAME",
nargs=-1,
type=str,
default=[p.name for p in ALL_PKGS],
default=[p.name for p in pkg_map.values()],
)
@error_handler
def tla_package_upgrade(pkg_names: list[str]) -> None:
Expand Down Expand Up @@ -207,7 +210,7 @@ def tla_package_upgrade(pkg_names: list[str]) -> None:
metavar="PACKAGE_NAME",
nargs=-1,
type=str,
default=[p.name for p in ALL_PKGS],
default=[p.name for p in pkg_map.values()],
)
@error_handler
def tla_package_uninstall(pkg_names: list[str]) -> None:
Expand Down Expand Up @@ -283,6 +286,9 @@ def tla_package_uninstall(pkg_names: list[str]) -> None:
type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=Path),
help="Path to the TLC configuration file (.cfg). If not provided, it is assumed to be alongside the module file with a .cfg extension.",
)
@click.option(
"--save-states", "-s", is_flag=True, help="Wheither to save the state graph."
)
@error_handler
def tla_model_check(
module_path: Path,
Expand All @@ -291,6 +297,7 @@ def tla_model_check(
max_heap_size: str,
community_modules: bool,
external_module: list[Path],
save_states: bool,
) -> None:
"""
Run the TLC model checker on a given TLA+ module file.
Expand All @@ -301,7 +308,6 @@ def tla_model_check(
f"External module '{ext_module}' must be a .jar file."
)

tlc = pkg_map["TLA2Tools"].tools["TLC"]
model_path = model_path if model_path else module_path.with_suffix(".cfg")
tlc.start(
module_path,
Expand All @@ -310,6 +316,7 @@ def tla_model_check(
max_heap_size=max_heap_size,
community_modules=community_modules,
external_modules=external_module,
save_states=save_states,
)


Expand Down
28 changes: 20 additions & 8 deletions cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from rich.console import Console
from rich.logging import RichHandler

from .packages import Package, TLA2Tools, CommunityModules
from .packages import TLA2Tools, CommunityModules
from .tools import TLC, REPL


VALID = "[green]✓[/green]"
CROSS = "[red]✗[/red]"
UNCHANGED = "[yellow]~[/yellow]"

PROJECT_DIR = Path(__file__).parent.parent
WORKDIR = PROJECT_DIR / ".tla"
TOOLS_DIR = WORKDIR / "tools"
RUN_DATA_DIR = WORKDIR / "data"

CONSOLE = Console()

Expand All @@ -22,11 +28,17 @@
)
LOGGER = logging.getLogger("rich")

ALL_PKGS: list[Package] = [
TLA2Tools(location=TOOLS_DIR),
CommunityModules(location=TOOLS_DIR),
]

VALID = "[green]✓[/green]"
CROSS = "[red]✗[/red]"
UNCHANGED = "[yellow]~[/yellow]"
tla2tools = TLA2Tools(location=TOOLS_DIR, logger=LOGGER, console=CONSOLE)
community_modules = CommunityModules(location=TOOLS_DIR, logger=LOGGER, console=CONSOLE)

repl = REPL(main_class="tlc2.REPL", pkg=tla2tools, logger=LOGGER, console=CONSOLE)

tlc = TLC(
main_class="tlc2.TLC",
data_path=RUN_DATA_DIR,
community_modules=community_modules,
pkg=tla2tools,
logger=LOGGER,
console=CONSOLE,
)
76 changes: 37 additions & 39 deletions cli/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from datetime import timedelta
from pathlib import Path
from typing import List, Optional, Literal
from typing import Optional, Literal
from typing_extensions import Self

from pydantic import BaseModel, FilePath, Field
from pydantic import BaseModel, DirectoryPath, Field, model_validator

from .constants import ALL_PKGS, CONSOLE
from .constants import tlc, CONSOLE


class Checks(BaseModel):
Expand Down Expand Up @@ -43,11 +44,11 @@ class Configuration(BaseModel):

Attributes:
max_heap_size: Maximum heap size allocated for the JVM in Mio.
cores: Number of CPU cores allocated for model checking.
workers: Number workers used for model checking.
"""

max_heap_size: Optional[str] = None
cores: int = 1
max_heap_size: Optional[str] = "1G"
workers: int = 1


class Model(BaseModel):
Expand All @@ -65,11 +66,11 @@ class Model(BaseModel):
"""

name: str
path: FilePath
path: Path
runtime: Optional[timedelta] = None
type: Literal["explicit", "symbolic"]
mode: Literal["exhaustive", "simulation"]
configuration: Optional[Configuration] = None
configuration: Configuration = Configuration()
checks: Checks


Expand All @@ -83,7 +84,7 @@ class Dependencies(BaseModel):
"""

community_modules: bool = False
additional_modules: List[FilePath] = Field(default_factory=list)
additional_modules: list[Path] = Field(default_factory=list)


class Module(BaseModel):
Expand All @@ -97,9 +98,9 @@ class Module(BaseModel):
models: List of models associated with the module.
"""

path: FilePath
dependencies: Dependencies
models: List[Model] = []
path: Path
dependencies: Dependencies = Dependencies()
models: list[Model] = []


class Manifest(BaseModel):
Expand All @@ -112,12 +113,13 @@ class Manifest(BaseModel):
modules: List of TLA+ modules to be processed.
"""

base_path: DirectoryPath
tlc_version: Optional[str] = None
total_duration: Optional[timedelta] = None
modules: List[Module]
modules: list[Module]

@classmethod
def load_manifest(cls, path: FilePath) -> "Manifest":
def load_manifest(cls, path: Path) -> "Manifest":
"""Load a manifest from a JSON file.

Args:
Expand All @@ -127,8 +129,7 @@ def load_manifest(cls, path: FilePath) -> "Manifest":
An instance of the Manifest class populated with data from the file.
"""
with path.open("r") as f:
data = replace_paths(json.loads(f.read()), path.parent)
return cls(**data)
return cls(base_path=path.parent, **json.loads(f.read()))

def process(self) -> dict[Path, bool]:
"""Process all modules and their models as specified in the manifest.
Expand All @@ -139,8 +140,6 @@ def process(self) -> dict[Path, bool]:
"""
processing_results = {}

tlc = [p for p in ALL_PKGS if p.name == "TLA2Tools"][0].tools["TLC"]

for module in self.modules:
for model in module.models:
if model.type == "explicit":
Expand All @@ -151,9 +150,9 @@ def process(self) -> dict[Path, bool]:
external_modules=module.dependencies.additional_modules,
# timeout=model.runtime,
# mode=model.mode,
workers=model.configuration.cores,
workers=model.configuration.workers,
max_heap_size=model.configuration.max_heap_size,
show_log=False,
show_log=True,
)
if tlc_run.success == model.checks.success:
assertions = (
Expand Down Expand Up @@ -198,23 +197,22 @@ def process(self) -> dict[Path, bool]:
)
return processing_results

def _check_modify_relative_path(self, path: Path) -> Path:
full_path = path
if not full_path.is_absolute():
full_path = self.base_path / full_path
if not full_path.is_file():
raise ValueError(f"File not found: {path}.")
return full_path

def replace_paths(data, base_path: Path):
"""Recursively replace 'path' fields in a nested data structure with absolute paths.

Args:
data: The nested data structure (dicts and lists).
base_path: The base path to prepend to relative paths.

Returns:
The modified data structure with updated paths.
"""
if isinstance(data, dict):
return {
k: replace_paths(base_path / v if k == "path" else v, base_path)
for k, v in data.items()
}
elif isinstance(data, list):
return [replace_paths(item, base_path) for item in data]
else:
return data
@model_validator(mode="after")
def check_relative_paths(self) -> Self:
for module in self.modules:
module.path = self._check_modify_relative_path(module.path)
module.dependencies.additional_modules = [
self._check_modify_relative_path(additional_module)
for additional_module in module.dependencies.additional_modules
]
for model in module.models:
model.path = self._check_modify_relative_path(model.path)
return self
5 changes: 3 additions & 2 deletions cli/packages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .github import Package, TLA2Tools, CommunityModules
from .base import Package
from .github import GithubReleasePackage, TLA2Tools, CommunityModules

__all__ = ["Package", "TLA2Tools", "CommunityModules"]
__all__ = ["Package", "GithubReleasePackage", "TLA2Tools", "CommunityModules"]
15 changes: 7 additions & 8 deletions cli/packages/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path

from rich.console import Console
from packaging.version import Version

from ..tools import Tool


class Package(ABC):
"""
Expand All @@ -15,7 +15,9 @@ class Package(ABC):
location: The installation location of the package.
"""

def __init__(self, name: str, location: Path) -> None:
def __init__(
self, name: str, location: Path, logger: Logger, console: Console
) -> None:
"""
Initialize a new Package instance.

Expand All @@ -25,11 +27,8 @@ def __init__(self, name: str, location: Path) -> None:
"""
self.name = name
self.location = location

@property
@abstractmethod
def tools(self) -> dict[str, Tool]:
pass
self.logger = logger
self.console = console

@property
@abstractmethod
Expand Down
Loading