Skip to content

Commit

Permalink
improved client connection retry logic (#68)
Browse files Browse the repository at this point in the history
* improved client connection retry logic

* fixing tests
  • Loading branch information
wjcunningham7 authored Oct 20, 2023
1 parent d16460a commit 8632dfb
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ jobs:
shell: python

- name: Run tests
env:
COVALENT_PLUGIN_LOAD: false
run: |
PYTHONPATH=$PWD/tests pytest -m "not functional_tests" -vv tests/ --cov=covalent_ssh_plugin
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Changed

- Improved connection retry logic

## [0.22.0] - 2023-09-20

### Changed
Expand Down
39 changes: 26 additions & 13 deletions covalent_ssh_plugin/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
create_unique_workdir: Optional[bool] = None,
poll_freq: int = 15,
do_cleanup: bool = True,
retry_connect: bool = True,
) -> None:

remote_cache = (
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(
)

self.do_cleanup = do_cleanup
self.retry_connect = retry_connect

ssh_key_file = ssh_key_file or get_config("executors.ssh.ssh_key_file")
self.ssh_key_file = str(Path(ssh_key_file).expanduser().resolve())
Expand Down Expand Up @@ -168,10 +170,10 @@ def _write_function_files(
"",
f"with open('{remote_function_file}', 'rb') as f_in:",
" fn, args, kwargs = pickle.load(f_in)",
" current_dir = os.getcwd()",
" try:",
f" Path({current_remote_workdir}).mkdir(parents=True, exist_ok=True)",
" current_dir = os.getcwd()",
f" os.chdir({current_remote_workdir})",
f" Path('{current_remote_workdir}').mkdir(parents=True, exist_ok=True)",
f" os.chdir('{current_remote_workdir}')",
" result = fn(*args, **kwargs)",
" except Exception as e:",
" exception = e",
Expand Down Expand Up @@ -240,21 +242,32 @@ async def _client_connect(self) -> Tuple[bool, asyncssh.SSHClientConnection]:
ssh_success = False
conn = None
if os.path.exists(self.ssh_key_file):
try:
conn = await asyncssh.connect(
self.hostname,
username=self.username,
client_keys=[self.ssh_key_file],
known_hosts=None,
)
retries = 6 if self.retry_connect else 1
for _ in range(retries):
try:
conn = await asyncssh.connect(
self.hostname,
username=self.username,
client_keys=[self.ssh_key_file],
known_hosts=None,
)

ssh_success = True
except (socket.gaierror, ValueError, TimeoutError, ConnectionRefusedError) as e:
app_log.error(e)

if conn is not None:
break

ssh_success = True
except (socket.gaierror, ValueError, TimeoutError) as e:
app_log.error(e)
await asyncio.sleep(5)

if conn is None and not self.run_local_on_ssh_fail:
raise RuntimeError("Could not connect to remote host.")

else:
message = f"no SSH key file found at {self.ssh_key_file}. Cannot connect to host."
app_log.error(message)
raise RuntimeError(message)

return ssh_success, conn

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def git_match_requirement(req):

setup_info = {
"name": "covalent-ssh-plugin",
"packages": find_packages("."),
"packages": find_packages(exclude=["tests", "tests.*"]),
"version": version,
"maintainer": "Agnostiq",
"url": "https://github.com/AgnostiqHQ/covalent-ssh-plugin",
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiofiles>=22.1.0
flake8==3.9.2
isort==5.7.0
mock==4.0.3
Expand Down
54 changes: 27 additions & 27 deletions tests/ssh_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tempfile
from unittest.mock import AsyncMock, MagicMock, mock_open, patch

import aiofiles
import pytest

from covalent_ssh_plugin import SSHExecutor
Expand Down Expand Up @@ -73,37 +74,41 @@ async def test_on_ssh_fail(mocker):
"""Test that the process runs locally upon connection errors."""

mocker.patch("covalent_ssh_plugin.ssh.get_config", side_effect=get_config_mock)

executor = SSHExecutor(
username="user",
hostname="host",
ssh_key_file="key_file",
run_local_on_ssh_fail=True,
)
mocker.patch("covalent_ssh_plugin.SSHExecutor._validate_credentials", return_value=True)

def simple_task(x):
return x**2

executor.node_id = (0,)
executor.dispatch_id = (0,)
mocker.patch("covalent_ssh_plugin.SSHExecutor._validate_credentials", return_value=True)
result = await executor.run(
function=simple_task,
args=[5],
kwargs={},
task_metadata={"dispatch_id": -1, "node_id": -1},
)
assert result == 25
async with aiofiles.tempfile.NamedTemporaryFile("w") as f:
executor = SSHExecutor(
username="user",
hostname="host",
ssh_key_file=f.name,
run_local_on_ssh_fail=True,
retry_connect=False,
)

executor.node_id = (0,)
executor.dispatch_id = (0,)

executor.run_local_on_ssh_fail = False
with pytest.raises(RuntimeError):
result = await executor.run(
function=simple_task,
args=[5],
kwargs={},
task_metadata={"dispatch_id": -1, "node_id": -1},
)

assert result == 25

executor.run_local_on_ssh_fail = False
with pytest.raises(RuntimeError):
result = await executor.run(
function=simple_task,
args=[5],
kwargs={},
task_metadata={"dispatch_id": -1, "node_id": -1},
)


@pytest.mark.asyncio
async def test_client_connect(mocker):
Expand All @@ -115,16 +120,11 @@ async def test_client_connect(mocker):
username="user",
hostname="host",
ssh_key_file="non-existent_key",
retry_connect=False,
)

connected, _ = await executor._client_connect()
assert connected is False

# Patch to fake the existence of a valid SSH keyfile. Connection should still fail due to
# the invalid username/hostname.
mocker.patch("builtins.open", mock_open(read_data="data"))
connected, _ = await executor._client_connect()
assert connected is False
with pytest.raises(RuntimeError):
connected, _ = await executor._client_connect()

# Patch to make the call to asyncssh.connect not fail with an incorrect user/host/keyfile.
mocker.patch("os.path.exists", return_value=True)
Expand Down

0 comments on commit 8632dfb

Please sign in to comment.