Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
# > External packages
import nox

# > Ignore certain dirs
ASSETS_DIR = "src/oet/assets"

# > Making sure Nox session only see their packages and not any globally installed packages.
os.environ.pop("PYTHONPATH", None)
# > Hiding any virtual environments from outside.
Expand Down Expand Up @@ -37,7 +40,7 @@
@nox.session(tags=["static_check"])
def type_check(session):
session.install(".[type-check]")
session.run("mypy")
session.run("mypy", "--exclude", f"^{ASSETS_DIR}/")


# //////////////////////////////////////////////////
Expand All @@ -47,7 +50,7 @@ def type_check(session):
def remove_unused_imports(session):
session.install(".[lint]")
# > Sorting imports with ruff instead of isort
session.run("ruff", "check", "--fix", "--select", "F401")
session.run("ruff", "check", "--fix", "--select", "F401", "--exclude", ASSETS_DIR)


# //////////////////////////////////////////
Expand All @@ -57,7 +60,7 @@ def remove_unused_imports(session):
def sort_imports(session):
session.install(".[lint]")
# > Sorting imports with ruff instead of isort
session.run("ruff", "check", "--fix", "--select", "I")
session.run("ruff", "check", "--fix", "--select", "I", "--exclude", ASSETS_DIR)


# ////////////////////////////////////////
Expand All @@ -66,7 +69,7 @@ def sort_imports(session):
@nox.session(tags=["style", "static_check"])
def lint(session):
session.install(".[lint]")
session.run("ruff", "check", "--fix")
session.run("ruff", "check", "--fix", "--exclude", ASSETS_DIR)


# //////////////////////////////////////////
Expand All @@ -76,7 +79,7 @@ def lint(session):
def format_code(session):
# Installs the project + the "lint" extra into this nox venv using pip
session.install(".[lint]")
session.run("ruff", "format")
session.run("ruff", "format", "--exclude", ASSETS_DIR)


# ////////////////////////////////////////////////////
Expand All @@ -85,7 +88,7 @@ def format_code(session):
@nox.session(tags=["static_check"])
def spell_check(session):
session.install(".[spell-check]")
session.run("codespell", "src/oet")
session.run("codespell", "src/oet", "--skip", ASSETS_DIR)


# //////////////////////////////////////////////
Expand All @@ -94,4 +97,4 @@ def spell_check(session):
@nox.session(tags=["static_check"], default=True)
def dead_code(session):
session.install(".[dead-code]")
session.run("vulture")
session.run("vulture", "src", "--exclude", ASSETS_DIR)
115 changes: 115 additions & 0 deletions src/oet/core/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
Utilities used in the test suite
"""

import multiprocessing as mp
import subprocess
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
from multiprocessing.queues import Queue

WATER = [
("O", 0.0000, 0.0000, 0.0000),
Expand Down Expand Up @@ -195,3 +201,112 @@ def clear_files(basename: str) -> None:
for f in dir_path.glob(basename + "*"):
if f.is_file():
f.unlink() # remove file


def _worker(
fn: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], q: "Queue[bool]"
) -> None:
"""
Helper for executing a function.

Parameters
----------
fn: Callable[..., T]
Callable function that is executed.
args: tuple[Any]
Any positional arguments that are given to the function call.
kwargs: dict[str, Any]
Any keyword arguments that are given to the function call.
q: Queue[bool]
Queue used to put the function call.
"""
try:
# Call the function
_ = fn(*args, **kwargs)
# Don't check what it did, just return ok, if the function didn't crash
q.put(True)
except Exception:
q.put(False)


class TimeoutCallError(StrEnum):
"""Possible errors that are returned by TimeoutCall"""

# Function timed out
TIMEOUT = "timeout"
# Function crashed
CRASH = "crash"
# General error
ERROR = "error"


class TimeoutCall:
"""
Class for calling a function with a certain timeout.
Useful for functions that, e.g., download files.
Doesn't return the result of the function as it might not be pickled.
"""

def __init__(self, fn: Callable[..., Any]) -> None:
"""
Initialization of the class.

Parameters
----------
fn: Callable[..., T]
Callable function that is executed.
"""
self.fn = fn
self.timed_out = False

def __call__(
self, *args: Any, timeout: float = 10, **kwargs: Any
) -> tuple[bool, TimeoutCallError | None]:
"""
Execute the function set in __init__ with the timeout defined there.

Parameters
----------
args: Any
Any positional arguments that are given to the function call.
timeout: float, default: 10 sec.
Timeout in sec.
kwargs: Any
Any keyword arguments that are given to the function call.

Returns
-------
bool
True, if everything was ok. False otherwise.
TimeoutCallError | None
Either the error type if failed or None.
"""

# Start process and wait the timeout
q: "Queue[bool]" = mp.Queue()
p: mp.Process = mp.Process(target=_worker, args=(self.fn, args, kwargs, q))
p.start()
p.join(timeout)

# Check if there was any error
if p.exitcode not in (0, None):
return False, TimeoutCallError.CRASH

# Check if the process is still alive. If yes, it has timed out.
if p.is_alive():
p.terminate()
p.join()
return False, TimeoutCallError.TIMEOUT

# Check if the worker provides the correct result
try:
status_ok = q.get(timeout=1)
except Exception:
return False, TimeoutCallError.CRASH

# Check if there was a general error
if not status_ok:
return False, TimeoutCallError.ERROR

# If everything went well, return True and no Error.
return True, None
1 change: 1 addition & 0 deletions tests/aimnet2/test_aiment2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
write_xyz_file,
)

# Path to the scripts, adjust if needed.
aimnet2_script_path = ROOT_DIR / "../../bin/oet_client"
aimnet2_server_path = ROOT_DIR / "../../bin/oet_server"
# Default ID and port of server. Change if needed
Expand Down
1 change: 1 addition & 0 deletions tests/aimnet2/test_aiment2_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
write_xyz_file,
)

# Path to the script, adjust if needed.
aimnet2_script_path = ROOT_DIR / "../../bin/oet_aimnet2"


Expand Down
1 change: 1 addition & 0 deletions tests/g-xtb/test_gxtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
write_xyz_file,
)

# Path to the scripts, adjust if needed.
gxtb_script_path = ROOT_DIR / "../../bin/oet_gxtb"
# Leave uma_executable_path empty, if gxtb from system path should be called
gxtb_executable_path = ""
Expand Down
30 changes: 29 additions & 1 deletion tests/mlatom/test_mlatom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
from oet.core.test_utilities import (
OH,
WATER,
TimeoutCall,
TimeoutCallError,
get_filenames,
read_result_file,
run_wrapper,
write_input_file,
write_xyz_file,
)

# Path to the script, adjust if needed.
mlatom_script_path = ROOT_DIR / "../../bin/oet_mlatom"
# Default maximum time (in sec) to download the model files if not present
timeout = 300
# Leave mlatom_executable_path empty, if mlatom from system path should be called
mlatom_executable_path = ""

Expand All @@ -33,7 +38,30 @@ class MLatomTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Force download / initialization of ANI-1ccx once
torchani.models.ANI1ccx(periodic_table_index=True)
print("Checking the model files and downloading them if necessary.")
# Make a timeout call to avoid hanging forever
get_ani1ccx_timeout = TimeoutCall(fn=torchani.models.ANI1ccx)
ok, payload = get_ani1ccx_timeout(timeout=timeout, periodic_table_index=True)
# Check if the model files could not be loaded
if not ok:
# Timeout
if payload == TimeoutCallError.TIMEOUT:
print(
"Loading the model files timed out. "
"Please check your internet connection and consider increasing the time before timing out."
)
raise unittest.SkipTest("Timed out.")
# General errors and crashes
elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR:
print(
"Loading the model files failed. Make sure that "
"the virtual environment with MLAtom installed is active."
)
raise unittest.SkipTest("Loading failed.")
# Unresolved error
else:
print("Could not load the model files.")
raise unittest.SkipTest("Loading failed.")

def test_H2O_engrad(self):
xyz_file, input_file, engrad_out, output_file = get_filenames("H2O")
Expand Down
1 change: 1 addition & 0 deletions tests/mopac/test_mopac.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
write_xyz_file,
)

# Path to the script, adjust if needed.
mopac_script_path = ROOT_DIR / "../../bin/oet_mopac"
# Leave moppac_executable_path empty, if mopac from system path should be called
mopac_executable_path = ""
Expand Down
54 changes: 53 additions & 1 deletion tests/uma/test_uma_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,55 @@
import unittest

from oet import ROOT_DIR
from oet.calculator.uma import DEFAULT_CACHE_DIR, UmaCalc
from oet.core.test_utilities import (
OH,
WATER,
TimeoutCall,
TimeoutCallError,
get_filenames,
read_result_file,
run_wrapper,
write_input_file,
write_xyz_file,
)

# Path to the scripts, adjust if needed.
uma_script_path = ROOT_DIR / "../../bin/oet_client"
uma_server_path = ROOT_DIR / "../../bin/oet_server"
# Default maximum time (in sec) to download the model files if not present
timeout = 600
# Default ID and port of server. Change if needed
id_port = "127.0.0.1:9000"
# UMA model to use
uma_model = "uma-s-1p1"


def cache_model_files(
basemodel: str, param: str = "omol", device: str = "cpu", cache_dir: str = DEFAULT_CACHE_DIR
) -> None:
"""
Wrapper to set up an UMA calculator that downloads the model files into the same cache-directory used for actual oet calculations.

basemodel: str
Basemodel used to calculate the test cases
param: str, default: omol
Parameter set.
device str, default: cpu
Device used for the calculations.
cache_dir: str, default: DEFAULT_CACHE_DIR
The cache directory used to store the model data.
"""
calculator = UmaCalc()
calculator.set_calculator(param=param, basemodel=basemodel, device=device, cache_dir=cache_dir)


def run_uma(inputfile: str, output_file: str) -> None:
run_wrapper(
inputfile=inputfile,
script_path=uma_script_path,
outfile=output_file,
args=["--bind", id_port],
args=["--bind", id_port, "--model", uma_model],
)


Expand All @@ -36,6 +63,31 @@ def setUpClass(cls):
"""
Test starting the server
"""
# Pre-download UMA model files
print("Checking the model files and downloading them if necessary.")
# Make a timeout call to avoid hanging forever
get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files)
ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout)
# Check if the model files could not be loaded
if not ok:
# Timeout
if payload == TimeoutCallError.TIMEOUT:
print(
"Loading the model files timed out. "
"Please check your internet connection and consider increasing the time before timing out."
)
raise unittest.SkipTest("Timed out.")
# General errors and crashes
elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR:
print(
"Loading the model files failed. Make sure that "
"the virtual environment with UMA installed is active."
)
raise unittest.SkipTest("Loading failed.")
# Unresolved error
else:
print("Could not load the model files.")
raise unittest.SkipTest("Loading failed.")
print("Starting the server. A detailed server log can be found on file server.out")
with open("server.out", "a") as f:
cls.server = subprocess.Popen(
Expand Down
Loading