diff --git a/.gitignore b/.gitignore index c359fed3..165be9fb 100644 --- a/.gitignore +++ b/.gitignore @@ -154,3 +154,4 @@ cython_debug/ # setuptools_scm version file _version.py +.aider* diff --git a/examples/nodeps/.gitignore b/examples/nodeps/.gitignore new file mode 100644 index 00000000..23e52e05 --- /dev/null +++ b/examples/nodeps/.gitignore @@ -0,0 +1,4 @@ +dummy*.txt +remote_job.json +stderr.txt +*.log diff --git a/examples/nodeps/cleanup.sh b/examples/nodeps/cleanup.sh new file mode 100755 index 00000000..d4694e5e --- /dev/null +++ b/examples/nodeps/cleanup.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +rm -f {dummy8,dummy24,dummy48}/{dummy.txt,dummy2.txt,remote_job.json,stderr.txt} diff --git a/examples/nodeps/dirlist1.txt b/examples/nodeps/dirlist1.txt new file mode 100644 index 00000000..fa58d712 --- /dev/null +++ b/examples/nodeps/dirlist1.txt @@ -0,0 +1 @@ +dummy8 diff --git a/examples/nodeps/dirlist3.txt b/examples/nodeps/dirlist3.txt new file mode 100644 index 00000000..9c39fe92 --- /dev/null +++ b/examples/nodeps/dirlist3.txt @@ -0,0 +1,3 @@ +dummy8 +dummy24 +dummy48 diff --git a/examples/nodeps/dummy24/rjm_downloads.txt b/examples/nodeps/dummy24/rjm_downloads.txt new file mode 100644 index 00000000..7bb5ba72 --- /dev/null +++ b/examples/nodeps/dummy24/rjm_downloads.txt @@ -0,0 +1,2 @@ +dummy.txt +dummy2.txt diff --git a/examples/nodeps/dummy24/rjm_uploads.txt b/examples/nodeps/dummy24/rjm_uploads.txt new file mode 100644 index 00000000..965efab7 --- /dev/null +++ b/examples/nodeps/dummy24/rjm_uploads.txt @@ -0,0 +1 @@ +run.sl diff --git a/examples/nodeps/dummy24/run.sl b/examples/nodeps/dummy24/run.sl new file mode 100644 index 00000000..00c46efb --- /dev/null +++ b/examples/nodeps/dummy24/run.sl @@ -0,0 +1,10 @@ +#!/bin/bash +#SBATCH --job-name=testfuncx +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --mem=128 +#SBATCH --time=00:10:00 + +touch dummy.txt +sleep 50 +touch dummy2.txt diff --git a/examples/nodeps/dummy48/rjm_downloads.txt 
b/examples/nodeps/dummy48/rjm_downloads.txt new file mode 100644 index 00000000..7bb5ba72 --- /dev/null +++ b/examples/nodeps/dummy48/rjm_downloads.txt @@ -0,0 +1,2 @@ +dummy.txt +dummy2.txt diff --git a/examples/nodeps/dummy48/rjm_uploads.txt b/examples/nodeps/dummy48/rjm_uploads.txt new file mode 100644 index 00000000..965efab7 --- /dev/null +++ b/examples/nodeps/dummy48/rjm_uploads.txt @@ -0,0 +1 @@ +run.sl diff --git a/examples/nodeps/dummy48/run.sl b/examples/nodeps/dummy48/run.sl new file mode 100644 index 00000000..4eab7025 --- /dev/null +++ b/examples/nodeps/dummy48/run.sl @@ -0,0 +1,10 @@ +#!/bin/bash +#SBATCH --job-name=testfuncx +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --mem=128 +#SBATCH --time=00:10:00 + +touch dummy.txt +sleep 70 +touch dummy2.txt diff --git a/examples/nodeps/dummy8/rjm_downloads.txt b/examples/nodeps/dummy8/rjm_downloads.txt new file mode 100644 index 00000000..7bb5ba72 --- /dev/null +++ b/examples/nodeps/dummy8/rjm_downloads.txt @@ -0,0 +1,2 @@ +dummy.txt +dummy2.txt diff --git a/examples/nodeps/dummy8/rjm_uploads.txt b/examples/nodeps/dummy8/rjm_uploads.txt new file mode 100644 index 00000000..965efab7 --- /dev/null +++ b/examples/nodeps/dummy8/rjm_uploads.txt @@ -0,0 +1 @@ +run.sl diff --git a/examples/nodeps/dummy8/run.sl b/examples/nodeps/dummy8/run.sl new file mode 100644 index 00000000..6940dcf4 --- /dev/null +++ b/examples/nodeps/dummy8/run.sl @@ -0,0 +1,10 @@ +#!/bin/bash +#SBATCH --job-name=testfuncx +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --mem=128 +#SBATCH --time=00:05:00 + +touch dummy.txt +sleep 30 +touch dummy2.txt diff --git a/examples/nodeps/run1.sh b/examples/nodeps/run1.sh new file mode 100755 index 00000000..7d288e0e --- /dev/null +++ b/examples/nodeps/run1.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +set -e + +./cleanup.sh +echo "" +echo "================================================================================" +echo "Start of rjm_batch_submit..." 
+echo "" +rjm_batch_submit -f dirlist1.txt -ll debug -n +echo "" +echo "================================================================================" +echo "Start of rjm_batch_wait..." +echo "" +rjm_batch_wait -f dirlist1.txt -ll debug -n diff --git a/examples/nodeps/run3.sh b/examples/nodeps/run3.sh new file mode 100755 index 00000000..89131235 --- /dev/null +++ b/examples/nodeps/run3.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +set -e + +./cleanup.sh +rjm_batch_submit -f dirlist3.txt -ll debug +rjm_batch_wait -f dirlist3.txt -ll debug diff --git a/pyproject.toml b/pyproject.toml index 344e3ecf..ab1f864c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "fair-research-login", "globus-compute-sdk==3.7.0", "globus-sdk", + "paramiko", "requests", "retry", ] diff --git a/src/rjm/auth.py b/src/rjm/auth.py index c6c5748c..35e0bdd4 100644 --- a/src/rjm/auth.py +++ b/src/rjm/auth.py @@ -21,6 +21,18 @@ def do_authentication(force=False, verbose=False, retry=True): sys.stderr.write("ERROR: configuration file must be created with rjm_configure before running rjm_authenticate" + os.linesep) sys.exit(1) + # Load the configuration + config = config_helper.load_config() + runner = config.get("COMPONENTS", "runner") + transferer = config.get("COMPONENTS", "transferer") + + # Disallow paramiko SSH runner and SFTP transferer for authentication + if runner == "paramiko_ssh_runner" and transferer == "paramiko_sftp_transferer": + logger.debug("Authentication not required for paramiko SSH") + if verbose: + print("RJM authentication not required for paramiko SSH") + return + # delete token file if exists if force: if os.path.isfile(utils.TOKEN_FILE_LOCATION): diff --git a/src/rjm/cli/rjm_config.py b/src/rjm/cli/rjm_config.py index 2ecb7c0a..5804f813 100644 --- a/src/rjm/cli/rjm_config.py +++ b/src/rjm/cli/rjm_config.py @@ -32,6 +32,8 @@ def make_parser(): parser.add_argument('-ll', '--loglevel', required=False, help="level of log verbosity (setting 
the level here overrides the config file)", choices=['debug', 'info', 'warn', 'error', 'critical']) + parser.add_argument('-s', '--ssh', action='store_true', + help='Generate an SSH key pair (stored under ~/.rjm) for use with the Paramiko runner. When this option is chosen, Globus Transfer/Compute setup is skipped.') parser.add_argument('-w', '--where-config', action="store_true", help="Print location of the config file and exit") parser.add_argument('-v', '--version', action="version", version='%(prog)s ' + __version__) @@ -46,6 +48,9 @@ def nesi_setup(): # command line args parser = make_parser() args = parser.parse_args() + # Determine whether Globus setup should be performed. + # If SSH option is chosen, we skip Globus (mutually exclusive behavior). + no_globus = args.ssh if args.where_config: # print location of config file and exit @@ -68,61 +73,121 @@ def nesi_setup(): logger = logging.getLogger(__name__) logger.info(f"Running rjm_config v{__version__}") - print() - print("="*120) - print() - print("This is an interactive script to configure RJM for accessing NeSI. " - "You will be required to enter information along the way, including your NeSI username and project code.") - print() - print("="*120) - print() - print("At times either a browser window will be automatically opened, or you will be asked to copy a link and open it " - "in a browser, where you will be asked to authenticate and allow RJM to have access. 
" - "Please ensure the default browser on your system is set to a modern and reasonably up to date browser.") - print() - print("="*120) - print() - print("In some situations a new link will be opened in your browser immediately after you authenticated the last one, " - "which can be easy to miss, so if it looks like nothing is happening, please check your browser window for a pending authentication.") - print() - print("="*120) - print() + if not no_globus: + print() + print("="*120) + print() + print("This is an interactive script to configure RJM for accessing NeSI. " + "You will be required to enter information along the way, including your NeSI username and project code.") + print() + print("="*120) + print() + print("At times either a browser window will be automatically opened, or you will be asked to copy a link and open it " + "in a browser, where you will be asked to authenticate and allow RJM to have access. " + "Please ensure the default browser on your system is set to a modern and reasonably up to date browser.") + print() + print("="*120) + print() + print("In some situations a new link will be opened in your browser immediately after you authenticated the last one, " + "which can be easy to miss, so if it looks like nothing is happening, please check your browser window for a pending authentication.") + print() + print("="*120) + print() - # get extra info from user - username = input(f"Enter NeSI username or press enter to accept default [{getpass.getuser()}]: ").strip() or getpass.getuser() - account = input("Enter NeSI project code or press enter to accept default (you must belong to it) [uoa00106]: ").strip() or "uoa00106" - print("="*120) + # get extra info from user + username = input(f"Enter NeSI username or press enter to accept default [{getpass.getuser()}]: ").strip() or getpass.getuser() + account = input("Enter NeSI project code or press enter to accept default (you must belong to it) [uoa00106]: ").strip() or "uoa00106" + print("="*120) 
+ + else: + print() + print("="*120) + print() + print("This is an interactive script to configure RJM for accessing a remote machine via SSH. " + "You will be required to enter information along the way, including your username on the remote machine.") + print() + print("="*120) + print() + + # get extra info from user + username = input(f"Enter remote username or press enter to accept default [{getpass.getuser()}]: ").strip() or getpass.getuser() + account = None + print("="*120) # create the setup object nesi = NeSISetup(username, account) - # do the globus setup first because it is more interactive - nesi.setup_globus_transfer() + # do the globus setup unless the user asked to skip it + if not no_globus: + # This step is interactive and may open a browser + nesi.setup_globus_transfer() + + # If the user asked for an SSH key‑pair, generate it now via the new helper + if args.ssh: + # This will prompt for the remote base path and create the key pair + nesi.setup_paramiko() + paramiko_cfg = nesi.get_paramiko_config() # write values to config file - req_opts = copy.deepcopy(config_helper.CONFIG_OPTIONS_REQUIRED) + req_opts = copy.deepcopy(config_helper.CONFIG_OPTIONS) - # get config values - globus_ep, globus_path = nesi.get_globus_transfer_config() - funcx_ep = nesi.get_globus_compute_config() + # get config values (only if globus setup was performed) + if not no_globus: + globus_ep, globus_path = nesi.get_globus_transfer_config() + funcx_ep = nesi.get_globus_compute_config() # modify dict to set values as defaults done_globus_ep = False done_globus_path = False done_funcx_ep = False + + # Populate overrides – Globus overrides are applied only when Globus was run. + # Paramiko overrides are applied only when the SSH option was chosen. 
for optd in req_opts: - if optd["section"] == "GLOBUS" and optd["name"] == "remote_endpoint": - optd["override"] = globus_ep - done_globus_ep = True - elif optd["section"] == "GLOBUS" and optd["name"] == "remote_path": - optd["override"] = globus_path - done_globus_path = True - elif optd["section"] == "FUNCX" and optd["name"] == "remote_endpoint": - optd["override"] = funcx_ep - done_funcx_ep = True - assert done_globus_ep - assert done_globus_path - assert done_funcx_ep + # ----- Globus overrides (only when Globus setup was run) ----- + if not no_globus: + if optd["section"] == "GLOBUS_TRANSFER" and optd["name"] == "remote_endpoint": + optd["override"] = globus_ep + done_globus_ep = True + elif optd["section"] == "GLOBUS_TRANSFER" and optd["name"] == "remote_path": + optd["override"] = globus_path + done_globus_path = True + elif optd["section"] == "GLOBUS_COMPUTE" and optd["name"] == "remote_endpoint": + optd["override"] = funcx_ep + done_funcx_ep = True + + # ----- set the transferer and runner based on whether ssh option was chosen or not + if optd["section"] == "COMPONENTS": + if optd["name"] == "runner": + optd["override"] = "globus_compute_slurm_runner" + elif optd["name"] == "transferer": + optd["override"] = "globus_https_transferer" + + # ----- Paramiko overrides (only when SSH key pair was generated) ----- + if args.ssh: + if optd["section"] == "PARAMIKO": + if optd["name"] == "private_key_file": + optd["override"] = paramiko_cfg["private_key_file"] + elif optd["name"] == "remote_user": + optd["override"] = paramiko_cfg["remote_user"] + elif optd["name"] == "remote_base_path": + optd["override"] = paramiko_cfg["remote_base_path"] + elif optd["name"] == "remote_address": + # Store the remote machine address entered during Paramiko setup + optd["override"] = paramiko_cfg["remote_address"] + + # ----- set the transferer and runner based on whether ssh option was chosen or not + if optd["section"] == "COMPONENTS": + if optd["name"] == "runner": + 
optd["override"] = "paramiko_ssh_runner" + elif optd["name"] == "transferer": + optd["override"] = "paramiko_sftp_transferer" + + # sanity checks – only required when Globus overrides were attempted + if not no_globus: + assert done_globus_ep + assert done_globus_path + assert done_funcx_ep # backup current config if any if os.path.exists(config_helper.CONFIG_FILE_LOCATION): @@ -131,15 +196,17 @@ def nesi_setup(): print("="*120) # call method to set config file - config_helper.do_configuration(required_options=req_opts, accept_defaults=True) + config_helper.do_configuration(config_options=req_opts) print("="*120) print("Configuration file has been updated") print("="*120) - print("Running authenticate next...") - # force fresh authentication - do_authentication(force=True, verbose=True) + # Run authentication only if Globus steps were not skipped + if not no_globus: + print("Running authenticate next...") + # force fresh authentication + do_authentication(force=True, verbose=True) print("="*120) print("You should be ready to start using rjm now") diff --git a/src/rjm/cli/rjm_health_check.py b/src/rjm/cli/rjm_health_check.py index f55642f0..d6e84088 100644 --- a/src/rjm/cli/rjm_health_check.py +++ b/src/rjm/cli/rjm_health_check.py @@ -15,8 +15,10 @@ import tempfile from rjm import __version__ +from rjm.errors import RemoteJobRunnerError from rjm.remote_job import RemoteJob from rjm import utils +from rjm.runners.paramiko_ssh_runner import ParamikoSSHRunner def make_parser(): @@ -26,7 +28,7 @@ def make_parser(): parser.add_argument('-ll', '--loglevel', default="critical", help="level of log verbosity (default: %(default)s)", choices=['debug', 'info', 'warn', 'error', 'critical']) - parser.add_argument('-le', '--logextra', action='store_true', help='Also log funcx and globus at the chosen loglevel') + parser.add_argument('-le', '--logextra', action='store_true', help='Also log globus and paramiko at the chosen loglevel') parser.add_argument('-k', '--keep', 
action="store_true", help="Keep health check files on remote system, i.e. do not delete them after completing the check (default=%(default)s)") parser.add_argument('-r', '--retries', action='store_true', help='Allow retries on function failures') @@ -54,6 +56,53 @@ def _remote_health_check(remote_dir, remote_file, keep): os.rmdir(remote_dir) +def _remote_health_check_paramiko(runner, remote_dir, remote_file, keep): + """ + Verify that a remote directory and a file inside it exist when using + ParamikoSSHRunner. Checks are performed via runner.run_command with POSIX test commands. + Returns None on success; re-raises RemoteJobRunnerError if either check fails (cleanup failures are only logged). + """ + logger = logging.getLogger(__name__) + + # Check directory exists + cmd_dir = f"test -d '{remote_dir}'" + logger.debug(f"Testing for remote directory existence: {remote_dir}") + try: + runner.run_command(cmd_dir, background=False, retries=False) + except RemoteJobRunnerError as exc: + logger.error(f"Remote directory does not exist: '{remote_dir}'") + raise exc + else: + logger.debug("Remote directory exists") + + # Check file exists + remote_path = os.path.join(remote_dir, remote_file) + cmd_file = f"test -f '{remote_path}'" + logger.debug(f"Testing remote file exists: {remote_path}") + try: + runner.run_command(cmd_file, background=False, retries=False) + except RemoteJobRunnerError as exc: + logger.error(f"Remote file does not exist: '{remote_path}'") + raise exc + else: + logger.debug("Remote file exists") + + # Optional cleanup + if not keep: + # remove file + cmd_rm = f"rm -f '{remote_path}'" + try: + runner.run_command(cmd_rm, background=False, retries=False) + except RemoteJobRunnerError as exc: + logger.warning(f"Failed to delete remote file '{remote_path}' ({exc})") + # remove directory + cmd_rmdir = f"rmdir '{remote_dir}'" + try: + runner.run_command(cmd_rmdir, background=False, retries=False) + except RemoteJobRunnerError as exc: + logger.warning(f"Failed to delete remote directory '{remote_dir}' ({exc})") + + def
health_check(): # command line arguments parser = make_parser() @@ -102,8 +151,11 @@ def health_check(): print() print("Using runner to check directory and file exist...") logger.debug("Using runner to check directory and file exist...") - run_function = r.run_function_with_retries if args.retries else r.run_function - result = run_function(_remote_health_check, remote_dir, test_file_name, args.keep) + if isinstance(r, ParamikoSSHRunner): + result = _remote_health_check_paramiko(r, remote_dir, test_file_name, args.keep) + else: + run_function = r.run_function_with_retries if args.retries else r.run_function + result = run_function(_remote_health_check, remote_dir, test_file_name, args.keep) if result is None: print("Finished checking directory and file exist") logger.debug("Finished checking directory and file exist") diff --git a/src/rjm/cli/tests/test_rjm_batch_wait.py b/src/rjm/cli/tests/test_rjm_batch_wait.py index 31223b7e..ddc3d7c7 100644 --- a/src/rjm/cli/tests/test_rjm_batch_wait.py +++ b/src/rjm/cli/tests/test_rjm_batch_wait.py @@ -10,15 +10,17 @@ @pytest.fixture def configobj(): config = configparser.ConfigParser() - config["GLOBUS"] = { + config["GLOBUS_TRANSFER"] = { "remote_endpoint": "qwerty", "remote_path": "asdfg", } - config["FUNCX"] = { + config["GLOBUS_COMPUTE"] = { "remote_endpoint": "abcdefg", } config["SLURM"] = { "slurm_script": "run.sl", + } + config["POLLING"] = { "poll_interval": "2", "warmup_poll_interval": "1", "warmup_duration": "3", @@ -32,6 +34,10 @@ def configobj(): "uploads_file": "uploads.txt", "downloads_file": "downloads.txt", } + config["COMPONENTS"] = { + "runner": "globus_compute_slurm_runner", + "transferer": "globus_https_transferer", + } return config diff --git a/src/rjm/config.py b/src/rjm/config.py index f6123c37..35103af0 100644 --- a/src/rjm/config.py +++ b/src/rjm/config.py @@ -15,27 +15,38 @@ ".rjm", "rjm_config.ini" ) -CONFIG_OPTIONS_REQUIRED = [ + +CONFIG_OPTIONS = [ # default values must be strings or None + { + 
"section": "COMPONENTS", + "name": "runner", + "default": "globus_compute_slurm_runner", + "help": "Enter the runner implementation that should be used", + }, + { + "section": "COMPONENTS", + "name": "transferer", + "default": "globus_https_transferer", + "help": "Enter the runner implementation that should be used", + }, { - "section": "GLOBUS", + "section": "GLOBUS_TRANSFER", "name": "remote_endpoint", "default": None, "help": "Enter the endpoint id of the Globus guest collection on the remote machine", }, { - "section": "GLOBUS", + "section": "GLOBUS_TRANSFER", "name": "remote_path", "default": None, "help": "Enter the absolute path to the root of the Globus guest collection on the remote machine", }, { - "section": "FUNCX", + "section": "GLOBUS_COMPUTE", "name": "remote_endpoint", "default": None, "help": "Enter the endpoint id of the Globus Compute endpoint running on the remote machine", }, -] -CONFIG_OPTIONS_OPTIONAL = [ # default values must be strings { "section": "SLURM", "name": "slurm_script", @@ -43,19 +54,19 @@ "help": "Name of the Slurm script that will be included in the uploaded files", }, { - "section": "SLURM", + "section": "POLLING", "name": "warmup_poll_interval", "default": "10", "help": "Interval (in seconds) between checking whether the Slurm job has completed during the initial phase", }, { - "section": "SLURM", + "section": "POLLING", "name": "warmup_duration", "default": "120", "help": "Duration (in seconds) during which we apply the `warmup_poll_interval` before switching to `poll_interval`", }, { - "section": "SLURM", + "section": "POLLING", "name": "poll_interval", "default": "60", "help": "Interval (in seconds) between checking whether the Slurm job has completed", @@ -72,6 +83,36 @@ "default": "rjm_downloads.txt", "help": "Name of the file in the local directory that lists files to be downloaded", }, + { + "section": "PARAMIKO", + "name": "private_key_file", + "default": os.path.join(os.path.expanduser("~"), ".rjm", 
"paramiko_private_key"), + "help": "Path to the file containing the private key that paramiko should use to connect to the remote system", + }, + { + "section": "PARAMIKO", + "name": "remote_address", + "default": None, + "help": "Address of the remote machine", + }, + { + "section": "PARAMIKO", + "name": "remote_user", + "default": None, + "help": "User to connect to the remote machine as", + }, + { + "section": "PARAMIKO", + "name": "remote_base_path", + "default": None, + "help": "Base directory that paramiko should work under on the remote machine", + }, + { + "section": "PARAMIKO", + "name": "job_script", + "default": "run.sl", + "help": "Name of the script to execute on the remote machine when starting a job", + }, ] @@ -81,6 +122,41 @@ def load_config(config_file=CONFIG_FILE_LOCATION): if os.path.exists(config_file): config = configparser.ConfigParser() config.read(config_file) + + # check if the config file is old and raise error if so + old_format = False + if not "COMPONENTS" in config: + old_format = True + logger.debug("Old format config file detected -- no COMPONENTS section -- defaulting to Globus") + config["COMPONENTS"] = { + "runner": "globus_compute_slurm_runner", + "transferer": "globus_https_transferer", + } + + if not "GLOBUS_TRANSFER" in config: + old_format = True + logger.debug("Old format config file detected -- no GLOBUS_TRANSFER section -- attempting to fix") + if "GLOBUS" in config: + logger.debug("Using GLOBUS config for GLOBUS_TRANSFER") + config["GLOBUS_TRANSFER"] = config["GLOBUS"] + + if not "GLOBUS_COMPUTE" in config: + old_format = True + logger.debug("Old format config file detected -- no GLOBUS_COMPUTE section -- attempting to fix") + if "FUNCX" in config: + logger.debug("Using FUNCX config for GLOBUS_COMPUTE") + config["GLOBUS_COMPUTE"] = config["FUNCX"] + + if not "POLLING" in config: + old_format = True + logger.debug("Old format config file detected -- no POLLING section -- attempting to fix") + if "SLURM" in config: + 
logger.debug("Using SLURM config for POLLING") + config["POLLING"] = config["SLURM"] + + if old_format: + logger.warning("Attempted to automatically update your old config file -- rerun `rjm_config` to avoid this") + else: raise RemoteJobConfigError(f"Config file does not exist: {config_file}") @@ -104,13 +180,17 @@ def load_or_make_config(config_file=CONFIG_FILE_LOCATION): return config -def _process_option(config, optd, ask=True): +def _process_option(config, optd): section = optd["section"] name = optd["name"] default = optd["default"] text = optd["help"] logger.debug(f"Processing option: {section}:{name} ({text}) : {default}") + not_required = False + if default is None: + not_required = True + # current value if any try: value = config[section][name] @@ -133,68 +213,37 @@ def _process_option(config, optd, ask=True): override = True logger.debug(f"Overriding config value with: {value}") - # user input - if ask and not override: - print() - msg = f"{text} [{value if value is not None else ''}]: " - new_value = input(msg).strip() - while value is None and not len(new_value): - new_value = input(msg).strip() - if len(new_value): - value = new_value - logger.debug(f"Got new value from user input: {value}") - # check we got a value - if value is None: - raise RuntimeError(f"No value provided for '{section}:{name}' (ask={ask}; override={override})") + if value is None and not not_required: + raise RuntimeError(f"No value provided for '{section}:{name}' (override={override})") # store - if not config.has_section(section): - config.add_section(section) - config[section][name] = value + if value is not None: + if not config.has_section(section): + config.add_section(section) + config[section][name] = value -def do_configuration(required_options=CONFIG_OPTIONS_REQUIRED, - optional_options=CONFIG_OPTIONS_OPTIONAL, accept_defaults=False): +def do_configuration(config_options=CONFIG_OPTIONS): """ Run through configuration steps - :param required_options: optional, the list of required options, see config.CONFIG_OPTIONS_REQUIRED - for the expected format and
default - :param optional_options: optional, the list of required options, see config.CONFIG_OPTIONS_OPTIONAL - for the expected format and default - :param accept_defaults: optional, if True then accept the default values - without requesting confirmation, defaults to False + :param config_options: optional, the list of config options, see config.CONFIG_OPTIONS + for the expected format and defaults """ logger.debug("Configuring RJM...") print("Configuring RJM...") - if not accept_defaults: - print("Please enter configuration values below or accept the defaults (in square brackets)") # load config file if it already exists logger.debug("Load existing config or make a new config object") config = load_or_make_config() # loop over the options, asking user for input - logger.debug("Processing required options") - for optd in required_options: + logger.debug("Processing config options") + for optd in config_options: _process_option(config, optd) - # do they want to configure the rest? - print() - use_defaults = "y" if accept_defaults else "" - while use_defaults not in ("y", "n"): - use_defaults = input("Do you wish to use default values for the remaining options (y/n)? 
") - if use_defaults == "y": - ask = False - else: - ask = True - - logger.debug("Processing optional options") - for optd in optional_options: - _process_option(config, optd, ask=ask) - # store configuration with open(CONFIG_FILE_LOCATION, 'w') as cf: config.write(cf) diff --git a/src/rjm/remote_job.py b/src/rjm/remote_job.py index e6d107fd..c15bb237 100644 --- a/src/rjm/remote_job.py +++ b/src/rjm/remote_job.py @@ -10,7 +10,9 @@ from rjm import utils from rjm import config as config_helper from rjm.transferers import globus_https_transferer +from rjm.transferers import paramiko_sftp_transferer from rjm.runners.globus_compute_slurm_runner import GlobusComputeSlurmRunner +from rjm.runners.paramiko_ssh_runner import ParamikoSSHRunner from rjm.errors import RemoteJobRunnerError @@ -53,10 +55,18 @@ def __init__(self, timestamp=None): self._retry_tries, self._retry_backoff, self._retry_delay, self._retry_max_delay = utils.get_retry_values_from_config(config) # file transferer - self._transfer = globus_https_transferer.GlobusHttpsTransferer(config=config) + transferer_type = config.get("COMPONENTS", "transferer") + if transferer_type == "paramiko_sftp_transferer": + self._transfer = paramiko_sftp_transferer.ParamikoSftpTransferer(config=config) + else: + self._transfer = globus_https_transferer.GlobusHttpsTransferer(config=config) # remote runner - self._runner = GlobusComputeSlurmRunner(config=config) + runner_type = config.get("COMPONENTS", "runner") + if runner_type == "paramiko_ssh_runner": + self._runner = ParamikoSSHRunner(config=config) + else: + self._runner = GlobusComputeSlurmRunner(config=config) def files_uploaded(self): """Return whether files have been uploaded""" @@ -217,11 +227,11 @@ def do_globus_auth(self, runner=None, transfer=None): # setup runner self._log(logging.DEBUG, "Setting up globus auth for runner") - self._runner.setup_globus_auth(globus_cli, runner=runner) + self._runner.setup(globus_cli, runner=runner) # setup transferer 
self._log(logging.DEBUG, "Setting up globus auth for transferer") - self._transfer.setup_globus_auth(globus_cli, transfer=transfer) + self._transfer.setup(globus_cli, transfer=transfer) def cleanup(self): """ @@ -382,6 +392,13 @@ def run_wait(self, polling_interval=None, warmup_polling_interval=None, warmup_d self.set_run_completed(success=run_succeeded) self._save_state() + def check_job_status(self): + """ + Check whether the job has finished and return a string with the state of the job + + """ + return self._runner.check_job_status() + def run_cancel(self): """Cancel the run.""" if not self._run_started: diff --git a/src/rjm/remote_job_batch.py b/src/rjm/remote_job_batch.py index 5dfc92ae..1e6ad8a9 100644 --- a/src/rjm/remote_job_batch.py +++ b/src/rjm/remote_job_batch.py @@ -12,6 +12,9 @@ from rjm.remote_job import RemoteJob from rjm.runners.globus_compute_slurm_runner import GlobusComputeSlurmRunner from rjm.transferers.globus_https_transferer import GlobusHttpsTransferer +from rjm import config as config_helper +from rjm.runners.paramiko_ssh_runner import ParamikoSSHRunner +from rjm.transferers.paramiko_sftp_transferer import ParamikoSftpTransferer logger = logging.getLogger(__name__) @@ -24,8 +27,23 @@ class RemoteJobBatch: """ def __init__(self): self._remote_jobs = [] - self._runner = GlobusComputeSlurmRunner() - self._transfer = GlobusHttpsTransferer() + + # Load configuration to decide which components to use + config = config_helper.load_config() + + # Choose runner based on config COMPONENTS.runner + runner_type = config.get("COMPONENTS", "runner") + if runner_type == "paramiko_ssh_runner": + self._runner = ParamikoSSHRunner(config=config) + else: + self._runner = GlobusComputeSlurmRunner(config=config) + + # Choose transferer based on config COMPONENTS.transferer + transferer_type = config.get("COMPONENTS", "transferer") + if transferer_type == "paramiko_sftp_transferer": + self._transfer = ParamikoSftpTransferer(config=config) + else: + 
self._transfer = GlobusHttpsTransferer(config=config) def setup(self, remote_jobs_file: str, force: bool = False): """Setup the runner""" @@ -35,9 +53,13 @@ def setup(self, remote_jobs_file: str, force: bool = False): # Globus auth scopes = self._runner.get_globus_scopes() scopes.extend(self._transfer.get_globus_scopes()) - globus_cli = utils.handle_globus_auth(scopes) - self._runner.setup_globus_auth(globus_cli) - self._transfer.setup_globus_auth(globus_cli) + if len(scopes): + globus_cli = utils.handle_globus_auth(scopes) + else: + globus_cli = None + + self._runner.setup(globus_cli) + self._transfer.setup(globus_cli) # read the list of local directories and create RemoteJobs local_dirs = self._read_jobs_file(remote_jobs_file) diff --git a/src/rjm/runners/globus_compute_slurm_runner.py b/src/rjm/runners/globus_compute_slurm_runner.py index c1b1c878..df4fb54e 100644 --- a/src/rjm/runners/globus_compute_slurm_runner.py +++ b/src/rjm/runners/globus_compute_slurm_runner.py @@ -38,7 +38,7 @@ def __init__(self, config=None): self._setup_done = False # the Globus Compute endpoint on the remote machine - self._endpoint = self._config.get("FUNCX", "remote_endpoint") + self._endpoint = self._config.get("GLOBUS_COMPUTE", "remote_endpoint") # globus compute login manager self._login_manager = CustomLoginManager() @@ -52,9 +52,9 @@ def __init__(self, config=None): self._slurm_script = self._config.get("SLURM", "slurm_script") # how often to poll for Slurm job completion - self._poll_interval = self._config.getint("SLURM", "poll_interval") - self._warmup_poll_interval = self._config.getint("SLURM", "warmup_poll_interval") - self._warmup_duration = self._config.getint("SLURM", "warmup_duration") + self._poll_interval = self._config.getint("POLLING", "poll_interval") + self._warmup_poll_interval = self._config.getint("POLLING", "warmup_poll_interval") + self._warmup_duration = self._config.getint("POLLING", "warmup_duration") # Slurm job id self._jobid = None @@ -88,6 +88,10 @@ 
def get_globus_scopes(self): """If any Globus scopes are required, override this method and return them in a list""" return self._login_manager.get_scopes() + def setup(self, globus_cli, runner=None): + """Set up the transferer""" + return self.setup_globus_auth(globus_cli, runner=runner) + def setup_globus_auth(self, globus_cli, runner=None): """Do any Globus auth setup here, if required""" self._setup_done = True @@ -297,6 +301,14 @@ def start(self, working_directory): return started + def check_directory_exists(self, directory_path): + """Check the working directory exists""" + # sanity check the directory exists on the remote + dir_exists = self.run_function_with_retries(check_dir_exists, directory_path) + if not dir_exists: + self._log(logging.ERROR, f"The specified directory does not exist on remote: {directory_path}") + raise RemoteJobRunnerError(f"The specified directory does not exist on remote: {directory_path}") + def wait(self, polling_interval=None, warmup_polling_interval=None, warmup_duration=None): """ Wait for the Slurm job to finish @@ -418,7 +430,8 @@ def check_finished_jobs(self, remote_jobs): :param remote_jobs: list of remote jobs to check :returns: tuple of lists of RemoteJobs containing: - - finished jobs + - successful jobs + - failed jobs - unfinished jobs """ @@ -711,3 +724,9 @@ def _get_authorizer(self, resource_server): # return the selected authoriser return authorisers[resource_server] + + +# function for checking directory exists on funcx endpoint +def check_dir_exists(dirpath): + import os + return os.path.isdir(dirpath) diff --git a/src/rjm/runners/paramiko_ssh_runner.py b/src/rjm/runners/paramiko_ssh_runner.py new file mode 100644 index 00000000..de76c91c --- /dev/null +++ b/src/rjm/runners/paramiko_ssh_runner.py @@ -0,0 +1,435 @@ + +import uuid +import os +import time +import logging +import paramiko + +from retry.api import retry_call + +from rjm.runners.runner_base import RunnerBase +from rjm.errors import RemoteJobRunnerError 
+ + +MIN_POLLING_INTERVAL = 60 +MIN_WARMUP_POLLING_INTERVAL = 10 +MAX_WARMUP_DURATION = 300 + +logger = logging.getLogger(__name__) + + +class ParamikoSSHRunner(RunnerBase): + """ + Runner that uses the Paramiko SSH client to execute a command in a tmux + session on the remote machine and poll until the command completes. + + It is up to the calling program to manage the amount of work submitted + concurrently. + + """ + def __init__(self, config=None): + super(ParamikoSSHRunner, self).__init__(config=config) + + self._setup_done = False + self._ssh_client = None + + # config + self._ssh_private_key_file = self._config.get("PARAMIKO", "private_key_file") + self._remote_address = self._config.get("PARAMIKO", "remote_address") + self._remote_user = self._config.get("PARAMIKO", "remote_user") + self._job_script = self._config.get("PARAMIKO", "job_script") + + # how often to poll for job completion + self._poll_interval = self._config.getint("POLLING", "poll_interval") + self._warmup_poll_interval = self._config.getint("POLLING", "warmup_poll_interval") + self._warmup_duration = self._config.getint("POLLING", "warmup_duration") + + # tmux session name + self._tmux_session_name = None + self._working_directory = None + + def _log(self, level, message, *args, **kwargs): + """Add a label to log messages, identifying this specific RemoteJob""" + logger.log(level, self._label + message, *args, **kwargs) + + def __repr__(self): + return f"ParamikoSSHRunner({self._remote_user}@{self._remote_address})" + + def save_state(self): + """Append state to state_dict if required for restarting""" + state_dict = super(ParamikoSSHRunner, self).save_state() + if self._tmux_session_name is not None: + state_dict["tmux_session_name"] = self._tmux_session_name + if self._working_directory is not None: + state_dict["working_directory"] = self._working_directory + + return state_dict + + def __del__(self): + if self._ssh_client is not None: + self._ssh_client.close() + + def load_state(self, 
state_dict): + """Get saved state if required for restarting""" + super(ParamikoSSHRunner, self).load_state(state_dict) + if "tmux_session_name" in state_dict: + self._tmux_session_name = state_dict["tmux_session_name"] + if "working_directory" in state_dict: + self._working_directory = state_dict["working_directory"] + + def setup(self, *args, **kwargs): + """Setup the SFTP client""" + self._log(logging.DEBUG, "Setting up ParamikoSSHRunner...") + self._ssh_client = paramiko.SSHClient() + self._ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # load the private key + private_key = paramiko.RSAKey(filename=self._ssh_private_key_file) + + # Connect to server + self._ssh_client.connect( + hostname=self._remote_address, + port=22, + username=self._remote_user, + pkey=private_key, + timeout=30, + ) + self._log(logging.DEBUG, f"Connected to: {self._remote_user}@{self._remote_address} ({self._ssh_client})") + + self._setup_done = True + + def run_command(self, command, background=False, retries=False): + """ + Run the given command on the remote machine. + + Returns the output of the command if `background` is `False`; + returns the tmux session name if `background` is `True`. 
+ + :param command: The command to run on the remote machine + :type command: str + :param background: Whether to run the command in the background using tmux + :type background: bool + :param retries: Whether to retry the command if it fails + :type retries: bool + + """ + if self._ssh_client is None: + raise RuntimeError("Must call setup before run_command") + + if retries: + self._log(logging.ERROR, "Retries are not implemented for the paramiko runner yet") + + self._log(logging.DEBUG, f"Running {'background ' if background else ''}command: {command}") + + if background: + session_name = f"rjm-{uuid.uuid4()}" + + # Escape single quotes in command for safe shell use + escaped_command = command.replace("'", r"'\''") + + # Full tmux command to create a new detached session and run the command + command = ( + f"tmux new-session -d -s '{session_name}' '{escaped_command}'" + ) + self._log(logging.DEBUG, f"Full background command: {command}") + + stdin, stdout, stderr = self._ssh_client.exec_command(command) + + stdout_output = stdout.read().decode().strip() + stderr_output = stderr.read().decode().strip() + full_output_not_time_ordered = stdout_output + stderr_output + + exit_code = stdout.channel.recv_exit_status() + + if exit_code: + raise RemoteJobRunnerError(f"run_command failed (exit code {exit_code}): STDOUT: {stdout_output}; STDERR: {stderr_output}") + + if background: + retval = session_name + else: + retval = full_output_not_time_ordered + + return retval + + def make_remote_directory(self, remote_base_path, prefix, retries=True): + """ + Make one or more remote directories, using the given prefix(es). 
+ + :param remote_base_path: The base path on the remote machine to create + the directories in + :param prefix: Single prefix, or list of prefixes, for remote directories + + :return: Tuple, or list of tuples, containing the full remote path and + remote path relative to the base path + + """ + # remote function expects a list + if type(prefix) is not list: + single = True + prefix_list = [prefix] + else: + single = False + prefix_list = prefix + + # with or without retries + if retries: + self._log(logging.WARNING, "Retries not implemented yet (make_remote_directory)") + + # run the remote function + self._log(logging.DEBUG, f"Creating remote directories for: {prefix_list}") + self._log(logging.DEBUG, f"Creating remote directories in: {remote_base_path}") + + try: + remote_dirs = [] + for p in prefix_list: + # Construct the mktemp command to create a directory in base_path with the given prefix + cmd = f"mktemp -d -p '{remote_base_path}' -t '{p}-XXXXXX'" + + stdin, stdout, stderr = self._ssh_client.exec_command(cmd) + + # Wait for command to complete + exit_status = stdout.channel.recv_exit_status() + + if exit_status != 0: + error_msg = stderr.read().decode().strip() + raise RuntimeError(f"Failed to create temp directory with prefix '{p}': {error_msg}") + + # Read the stdout (the created directory path) + remote_full_path = stdout.read().decode().strip() + + # Ensure base_path is normalized for relative path calculation + base_path_clean = remote_base_path.rstrip('/') + if not remote_full_path.startswith(base_path_clean): + raise RuntimeError(f"Created directory {remote_full_path} is not under base path {remote_base_path}") + + # Compute relative path + rel_path = os.path.relpath(remote_full_path, start=remote_base_path) + + remote_dirs.append((remote_full_path, rel_path)) + + except Exception as exc: + raise RemoteJobRunnerError(f"Make remote directory failed: {exc}") + + if single: + remote_dirs = remote_dirs[0] + + return remote_dirs + + def 
check_directory_exists(self, directory_path): + """ + Check that the given directory exists on the remote machine + + :param directory_path + :raises RemoteJobRunnerError: if the directory does not exist + + """ + # Use SSH to test if the directory exists + command = f"test -d '{directory_path}'" + stdin, stdout, stderr = self._ssh_client.exec_command(command) + exit_status = stdout.channel.recv_exit_status() + if exit_status != 0: + raise RemoteJobRunnerError(f"Remote directory does not exist: {directory_path}") + # Directory exists; nothing to return + + def start(self, working_directory): + """ + Starts running the job script + + :param working_directory: the directory to run the script in + + """ + self._log(logging.DEBUG, f"Starting job for: {working_directory}") + try: + self._tmux_session_name = self.run_command( + f'cd "{working_directory}" && bash {self._job_script} > stdout.txt 2> stderr.txt && touch "{working_directory}/.rjm-succeeded"', + background=True + ) + + except RemoteJobRunnerError as exc: + self._log(logging.ERROR, f'Starting job failed in remote directory: "{working_directory}"') + self._log(logging.ERROR, f'{exc}') + started = False + raise exc + + self._log(logging.INFO, f'Started job with tmux session name: {self._tmux_session_name}') + self._working_directory = working_directory + started = True + + return started + + def check_job_status(self): + """ + Check the status of the job (SUCCEEDED, FAILED, UNFINISHED) + + """ + if self._tmux_session_name is None: + raise ValueError("Must call 'run_start' before 'check_job_status'") + + self._log(logging.DEBUG, f"Checking job status: {self._tmux_session_name}") + + cmd = f"tmux has-session -t \"{self._tmux_session_name}\"" + self._log(logging.DEBUG, f"Command: {cmd}") + stdin, stdout, stderr = self._ssh_client.exec_command(cmd) + + # Wait for command to complete + exit_status = stdout.channel.recv_exit_status() + stdout_output = stdout.read().decode().strip() + stderr_output = 
stderr.read().decode().strip() + self._log(logging.DEBUG, f'tmux has-session exit code: {exit_status}') + self._log(logging.DEBUG, f"STDOUT: {stdout_output}") + self._log(logging.DEBUG, f"STDERR: {stderr_output}") + + if exit_status: + job_finished = True + + # confirm whether or not it succeeded, i.e. does the file exist? + stdin, stdout, stderr = self._ssh_client.exec_command(f"test -f {self._working_directory}/.rjm-succeeded") + exit_status = stdout.channel.recv_exit_status() + if exit_status: + state = "FAILED" + else: + state = "SUCCEEDED" + + else: + state = "UNFINISHED" + + self._log(logging.DEBUG, f"Job status for {self._tmux_session_name}: {state}") + + return state + + def wait(self, polling_interval=None, warmup_polling_interval=None, warmup_duration=None): + """ + Wait for the job to finish + + Return True if the job succeeded, False if it failed + + """ + if self._tmux_session_name is None: + raise ValueError("Must call 'run_start' before 'run_wait'") + + # get the polling interval + polling_interval, warmup_polling_interval, warmup_duration = self.get_poll_interval( + polling_interval, warmup_polling_interval, warmup_duration + ) + + # loop until job has finished + self._log(logging.INFO, f"Waiting for job {self._tmux_session_name} to finish") + self._log(logging.DEBUG, f"Polling interval is: {polling_interval} seconds") + job_finished = False + job_succeeded = None + while not job_finished: + state = self.check_job_status() + + if state in ("SUCCEEDED", "FAILED"): + job_finished = True + + if state == "SUCCEEDED": + job_succeeded = True + else: + job_succeeded = False + + else: + self._log(logging.DEBUG, "Not finished yet") + time.sleep(polling_interval) + + assert job_succeeded is not None, "Unexpected error during wait" + if job_finished: + self._log(logging.INFO, f"Remote job {self._tmux_session_name} has finished (success: {job_succeeded})") + + return job_succeeded + + def cancel(self): + """Cancel the remote job""" + raise NotImplementedError() 
+ + def get_poll_interval( + self, + requested_interval: int | None, + requested_warmup_interval: int | None, + requested_warmup_duration: int | None, + ): + """Returns the poll interval from Slurm config""" + if requested_interval is None: + polling_interval = self._poll_interval + self._log(logging.DEBUG, f"Using polling interval from config file: {polling_interval}") + else: + polling_interval = requested_interval + self._log(logging.DEBUG, f"Using requested polling interval: {polling_interval}") + + if polling_interval < MIN_POLLING_INTERVAL: + polling_interval = MIN_POLLING_INTERVAL + self._log(logging.WARNING, f"Overriding polling interval with minimum value: {polling_interval}") + + if requested_warmup_interval is None: + warmup_polling_interval = self._warmup_poll_interval + self._log(logging.DEBUG, f"Using warmup polling interval from config file: {warmup_polling_interval}") + else: + warmup_polling_interval = requested_warmup_interval + self._log(logging.DEBUG, f"Using requested warmup polling interval: {warmup_polling_interval}") + + if warmup_polling_interval < MIN_WARMUP_POLLING_INTERVAL: + warmup_polling_interval = MIN_WARMUP_POLLING_INTERVAL + self._log(logging.WARNING, f"Overriding warmup polling interval with minimum value: {warmup_polling_interval}") + + + if requested_warmup_duration is None: + warmup_duration = self._warmup_duration + self._log(logging.DEBUG, f"Using warmup duration from config file: {warmup_duration}") + else: + warmup_duration = requested_warmup_duration + self._log(logging.DEBUG, f"Using requested warmup duration: {warmup_duration}") + + if warmup_duration > MAX_WARMUP_DURATION: + warmup_duration = MAX_WARMUP_DURATION + self._log(logging.WARNING, f"Overriding warmup duration with maximum value: {warmup_duration}") + + return polling_interval, warmup_polling_interval, warmup_duration + + def check_finished_jobs(self, remote_jobs): + """ + Check whether jobs have finished + + :param remote_jobs: list of remote jobs to check + + 
:returns: tuple of lists of RemoteJobs containing: + - successful jobs + - failed jobs + - unfinished jobs + + """ + self._log(logging.DEBUG, "Checking for finished jobs") + successful_jobs = [] + failed_jobs = [] + unfinished_jobs = [] + for rj in remote_jobs: + self._log(logging.DEBUG, f"Check job status for: {rj}") + status = rj.check_job_status() + + if status == "SUCCEEDED": + successful_jobs.append(rj) + self._log(logging.DEBUG, f"{rj} finished successfully") + + elif status == "FAILED": + failed_jobs.append(rj) + self._log(logging.DEBUG, f"{rj} finished unsuccessfully") + + elif status == "UNFINISHED": + unfinished_jobs.append(rj) + self._log(logging.DEBUG, f"{rj} unfinished") + + else: + raise ValueError(f"Unrecognised job status: \"{status}\"") + + return successful_jobs, failed_jobs, unfinished_jobs + + def get_checksums(self, working_directory, files): + """ + Return SHA256 checksums for the list of files + + :param files: list of files to calculate checksums of + :param working_directory: directory to switch to first + + :returns: dictionary with file names as keys and checksums as values + + """ + raise NotImplementedError diff --git a/src/rjm/runners/runner_base.py b/src/rjm/runners/runner_base.py index b0dc97b7..ef3b05f5 100644 --- a/src/rjm/runners/runner_base.py +++ b/src/rjm/runners/runner_base.py @@ -66,11 +66,7 @@ def make_remote_directory(self, prefix: list[str]): def check_directory_exists(self, directory_path): """Check the working directory exists""" - # sanity check the directory exists on the remote - dir_exists = self.run_function_with_retries(check_dir_exists, directory_path) - if not dir_exists: - self._log(logging.ERROR, f"The specified directory does not exist on remote: {directory_path}") - raise RemoteJobRunnerError(f"The specified directory does not exist on remote: {directory_path}") + raise NotImplementedError def run_function(self, function, *args, **kwargs): """Run the given function and pass back the return value""" @@ -99,9 
+95,3 @@ def cancel(self): def path_join(path1, path2): import os.path return os.path.join(path1, path2) - - -# function for checking directory exists on funcx endpoint -def check_dir_exists(dirpath): - import os - return os.path.isdir(dirpath) diff --git a/src/rjm/runners/tests/test_globus_compute_slurm_runner.py b/src/rjm/runners/tests/test_globus_compute_slurm_runner.py index cff54695..8188e1e3 100644 --- a/src/rjm/runners/tests/test_globus_compute_slurm_runner.py +++ b/src/rjm/runners/tests/test_globus_compute_slurm_runner.py @@ -12,11 +12,13 @@ @pytest.fixture def configobj(): config = configparser.ConfigParser() - config["FUNCX"] = { + config["GLOBUS_COMPUTE"] = { "remote_endpoint": "abcdefg", } config["SLURM"] = { "slurm_script": "run.sl", + } + config["POLLING"] = { "poll_interval": "2", "warmup_poll_interval": "1", "warmup_duration": "3", @@ -374,11 +376,13 @@ def test_check_slurm_job_statuses_missing(mocker): ]) def test_get_poll_interval(config_vals, user_vals, expected_vals, mocker): config = configparser.ConfigParser() - config["FUNCX"] = { + config["GLOBUS_COMPUTE"] = { "remote_endpoint": "abcdefg", } config["SLURM"] = { "slurm_script": "run.sl", + } + config["POLLING"] = { "poll_interval": str(config_vals[0]), "warmup_poll_interval": str(config_vals[1]), "warmup_duration": str(config_vals[2]), diff --git a/src/rjm/setup/nesi.py b/src/rjm/setup/nesi.py index 316f733f..1496232a 100644 --- a/src/rjm/setup/nesi.py +++ b/src/rjm/setup/nesi.py @@ -1,9 +1,12 @@ +""" +NeSI setup module for RJM. 
+""" import os -import time import uuid import logging import tempfile +import paramiko import globus_sdk from globus_sdk import GCSClient, TransferClient, DeleteData @@ -11,6 +14,7 @@ from globus_sdk.scopes import TransferScopes from rjm import utils +from rjm.runners.paramiko_ssh_runner import ParamikoSSHRunner logger = logging.getLogger(__name__) @@ -34,8 +38,14 @@ def __init__(self, username, account): self._account = account # initialise values we are setting up - self._globus_id = None # globus endpoint id - self._globus_path = None # path to globus share + self._globus_id = None # globus endpoint id + self._globus_path = None # path to globus share + self._remote_base_path = None # base path on the remote system for Paramiko + + # Paramiko key paths (filled by create_ssh_keypair) + self._private_key_path = None + self._public_key_path = None + self._remote_address = None def get_globus_compute_config(self): """Return globus compute config values""" @@ -51,7 +61,7 @@ def _handle_globus_auth(self, token_file, request_scopes, authoriser_scopes, by_ print("Authorising Globus - this should open a browser where you need to authenticate with Globus and approve access") print(" Globus is used by RJM to transfer files to and from NeSI") print("") - print("NOTE: If you are asked for a linked identity with NeSI Keycloak please do one of the following:") + print("NOTE: if you are asked for a linked identity with NeSI Keycloak please do one of the following:") print(f" - If you already have a linked identity it should appear in the list like: '{self._username}@iam.nesi.org.nz'") print(" If so, please select it and follow the instructions to authenticate with your NeSI credentials if required") print(" - Otherwise, choose the option to 'Link an identity from NeSI Keycloak'") @@ -205,7 +215,7 @@ def setup_globus_transfer(self): ) logger.debug(f"Collection document: {doc}") - # create Globus collection, report back endpoint id for config + # create the Globus collection, 
report back endpoint id for config response = client.create_collection(doc) endpoint_id = response.data["id"] logger.debug(f"Created Globus Guest Collection with Endpoint ID: {endpoint_id}") @@ -221,3 +231,171 @@ def setup_globus_transfer(self): # also store the endpoint id and path self._globus_id = endpoint_id self._globus_path = guest_collection_dir + + # --------------------------------------------------------------------- # + # SSH key‑pair handling + # --------------------------------------------------------------------- # + def create_ssh_keypair( + self, + private_key_path: str | None = None, + bits: int = 2048 + ) -> tuple[str, str]: + """Generate an SSH key pair with Paramiko and store it under ``~/.rjm``.""" + # Resolve the default location if the caller did not provide one + if private_key_path is None: + private_key_path = os.path.join( + os.path.expanduser("~"), ".rjm", "paramiko_private_key" + ) + public_key_path = private_key_path + ".pub" + + # ----------------------------------------------------------------- + # 1️⃣ Check for existing key pair + # ----------------------------------------------------------------- + if os.path.isfile(private_key_path) and os.path.isfile(public_key_path): + while True: + answer = input( + f"\nSSH key pair already exists at:\n" + f" private: {private_key_path}\n" + f" public : {public_key_path}\n" + f"Use the existing keys? 
(y/n): " + ).strip().lower() + if answer.startswith('y'): + logger.info("Re-using existing SSH key pair") + return private_key_path, public_key_path + if answer.startswith('n'): + logger.info("Generating a new SSH key pair (overwriting existing files)") + break + print("Please answer 'y' or 'n'.") + + # ----------------------------------------------------------------- + # 2️⃣ Generate a new RSA key pair + # ----------------------------------------------------------------- + # Ensure the target directory exists + os.makedirs(os.path.dirname(private_key_path), exist_ok=True) + + private_key = paramiko.RSAKey.generate(bits=bits) + + # Write private key + with open(private_key_path, 'w', encoding='utf-8') as private_file: + private_key.write_private_key(private_file) + + # Write public key (OpenSSH format) + public_key_str = f"{private_key.get_name()} {private_key.get_base64()}\n" + with open(public_key_path, 'w', encoding='utf-8') as public_file: + public_file.write(public_key_str) + + logger.info( + "Generated SSH key pair – private: %s, public: %s", + private_key_path, + public_key_path, + ) + + return private_key_path, public_key_path + + # --------------------------------------------------------------------- # + # Paramiko (SSH) setup + # --------------------------------------------------------------------- # + def setup_paramiko(self): + """ + Interactively set up the Paramiko runner: + + * Ask the user for a base directory on the remote system where RJM + should operate. + * Generate an SSH key‑pair (using :meth:`create_ssh_keypair`). + + The chosen base path is stored on the instance for later use and the + private/public key file paths are returned. 
+ """ + # ----------------------------------------------------------------- + # 1️⃣ Ask for the remote machine address (IP or DNS name) + # ----------------------------------------------------------------- + remote_addr = input( + "Enter remote address (IP or DNS name) for Paramiko: " + ).strip() + if not remote_addr: + raise ValueError("Remote address is required for Paramiko setup") + self._remote_address = remote_addr + logger.info("Paramiko remote address set to: %s", remote_addr) + + # ----------------------------------------------------------------- + # 2️⃣ Ask for the remote base path (default under /tmp) + # ----------------------------------------------------------------- + default_path = f"/home/{self._username}/.cache/rjm" + remote_base = input( + f"Enter remote base path for Paramiko (default [{default_path}]): " + ).strip() or default_path + + # Store for later use + self._remote_base_path = remote_base + logger.info("Paramiko remote base path set to: %s", remote_base) + + # ----------------------------------------------------------------- + # 3️⃣ Generate the SSH key pair (stores paths on the instance) + # ----------------------------------------------------------------- + private_key, public_key = self.create_ssh_keypair() + + # Store the paths on the instance for later retrieval + self._private_key_path = private_key + self._public_key_path = public_key + + # ----------------------------------------------------------------- + # Inform the user how to install the public key on the remote host + # ----------------------------------------------------------------- + try: + with open(public_key, "r", encoding="utf-8") as pk_f: + pub_key_contents = pk_f.read().strip() + except Exception as exc: # pragma: no cover – defensive + logger.error("Failed to read generated public key: %s", exc) + pub_key_contents = "" + + print("\n" + "=" * 80) + print("Public key generated for Paramiko access.") + print("Copy the following line and paste it into the") + 
print("~/.ssh/authorized_keys file on the remote machine:") + print("-" * 80) + print(pub_key_contents) + print("-" * 80) + print("Note: the key only needs to be copied once - if you") + print("are reusing an existing key and have already copied") + print("it across previously, you don't need to copy it again.") + print("-" * 80) + print("After you have added the key, press ENTER to continue.") + print("=" * 80 + "\n") + # Wait for user confirmation + input("Press ENTER when the public key has been added to the remote authorized_keys file...") + + # ----------------------------------------------------------------- + # Test access and make sure the remote base path directory exists + # ----------------------------------------------------------------- + print() + print("Opening connection to the remote machine...") + runner = ParamikoSSHRunner() + runner.setup() + print(f"Creating remote directory if needed ({self._remote_base_path})...") + runner.run_command(f"mkdir -p {self._remote_base_path}") + print(f"Testing write access to remote directory...") + runner.run_command(f"test -d {self._remote_base_path} && test -w {self._remote_base_path}") + print("Finished test") + + # --------------------------------------------------------------------- # + # Helper to expose Paramiko configuration + # --------------------------------------------------------------------- # + def get_paramiko_config(self): + """ + Return the three Paramiko configuration values that RJM needs. 
+ + Returns + ------- + dict + { + "private_key_file": , + "remote_user": , + "remote_base_path": + } + """ + return { + "private_key_file": self._private_key_path, + "remote_user": self._username, + "remote_base_path": self._remote_base_path, + "remote_address": self._remote_address, + } diff --git a/src/rjm/tests/test_config.py b/src/rjm/tests/test_config.py index a34be1ea..d33f982b 100644 --- a/src/rjm/tests/test_config.py +++ b/src/rjm/tests/test_config.py @@ -5,20 +5,26 @@ from rjm.errors import RemoteJobConfigError -CONFIG_FILE_TEST = """[GLOBUS] +CONFIG_FILE_TEST = """[GLOBUS_TRANSFER] remote_endpoint = abcdefg remote_path = /remote/path -[FUNCX] +[GLOBUS_COMPUTE] remote_endpoint = abcdefg [SLURM] slurm_script = run.sl + +[POLLING] poll_interval = 10 [FILES] uploads_file = rjm_uploads.txt downloads_file = rjm_downloads.txt + +[COMPONENTS] +runner = globus_compute_slurm_runner +transferer = globus_compute_https_transferer """ @@ -31,11 +37,11 @@ def config_file(tmp_path): def test_load_config(config_file): config = config_helper.load_config(config_file=str(config_file)) - assert config.get("GLOBUS", "remote_endpoint") == "abcdefg" - assert config.get("GLOBUS", "remote_path") == "/remote/path" - assert config.get("FUNCX", "remote_endpoint") == "abcdefg" + assert config.get("GLOBUS_TRANSFER", "remote_endpoint") == "abcdefg" + assert config.get("GLOBUS_TRANSFER", "remote_path") == "/remote/path" + assert config.get("GLOBUS_COMPUTE", "remote_endpoint") == "abcdefg" assert config.get("SLURM", "slurm_script") == "run.sl" - assert config.getint("SLURM", "poll_interval") == 10 + assert config.getint("POLLING", "poll_interval") == 10 assert config.get("FILES", "uploads_file") == "rjm_uploads.txt" assert config.get("FILES", "downloads_file") == "rjm_downloads.txt" diff --git a/src/rjm/tests/test_remote_job.py b/src/rjm/tests/test_remote_job.py index 08e12284..62a5df60 100644 --- a/src/rjm/tests/test_remote_job.py +++ b/src/rjm/tests/test_remote_job.py @@ -12,15 +12,17 
@@ @pytest.fixture def configobj(): config = configparser.ConfigParser() - config["GLOBUS"] = { + config["GLOBUS_TRANSFER"] = { "remote_endpoint": "qwerty", "remote_path": "asdfg", } - config["FUNCX"] = { + config["GLOBUS_COMPUTE"] = { "remote_endpoint": "abcdefg", } config["SLURM"] = { "slurm_script": "run.sl", + } + config["POLLING"] = { "poll_interval": "2", "warmup_poll_interval": "1", "warmup_duration": "3", @@ -35,6 +37,10 @@ def configobj(): "uploads_file": "uploads.txt", "downloads_file": "downloads.txt", } + config["COMPONENTS"] = { + "runner": "globus_compute_slurm_runner", + "transferer": "globus_https_transferer", + } return config diff --git a/src/rjm/tests/test_remote_job_batch.py b/src/rjm/tests/test_remote_job_batch.py index 45c7e030..b2394408 100644 --- a/src/rjm/tests/test_remote_job_batch.py +++ b/src/rjm/tests/test_remote_job_batch.py @@ -13,15 +13,17 @@ @pytest.fixture def configobj(): config = configparser.ConfigParser() - config["GLOBUS"] = { + config["GLOBUS_TRANSFER"] = { "remote_endpoint": "qwerty", "remote_path": "asdfg", } - config["FUNCX"] = { + config["GLOBUS_COMPUTE"] = { "remote_endpoint": "abcdefg", } config["SLURM"] = { "slurm_script": "run.sl", + } + config["POLLING"] = { "poll_interval": "2", "warmup_poll_interval": "1", "warmup_duration": "3", @@ -35,6 +37,10 @@ def configobj(): "uploads_file": "uploads.txt", "downloads_file": "downloads.txt", } + config["COMPONENTS"] = { + "runner": "globus_compute_slurm_runner", + "transferer": "globus_https_transferer", + } return config diff --git a/src/rjm/transferers/globus_https_transferer.py b/src/rjm/transferers/globus_https_transferer.py index 09bc0543..a1cbc49f 100644 --- a/src/rjm/transferers/globus_https_transferer.py +++ b/src/rjm/transferers/globus_https_transferer.py @@ -4,7 +4,6 @@ import time import concurrent.futures import urllib.parse -import hashlib import platform import globus_sdk @@ -18,7 +17,6 @@ DOWNLOAD_CHUNK_SIZE = 8000000 DOWNLOAD_SUFFIX = '.rjm' -FILE_CHUNK_SIZE = 
8000000 REQUESTS_TIMEOUT = 30 logger = logging.getLogger(__name__) @@ -34,8 +32,8 @@ def __init__(self, config=None): super(GlobusHttpsTransferer, self).__init__(config=config) # the Globus endpoint for the remote guest collection - self._remote_endpoint = self._config.get("GLOBUS", "remote_endpoint") - self._remote_base_path = self._config.get("GLOBUS", "remote_path") + self._remote_endpoint = self._config.get("GLOBUS_TRANSFER", "remote_endpoint") + self._remote_base_path = self._config.get("GLOBUS_TRANSFER", "remote_path") self._https_scope = utils.HTTPS_SCOPE.format(endpoint_id=self._remote_endpoint) # retry params @@ -64,6 +62,10 @@ def get_globus_scopes(self): return required_scopes + def setup(self, globus_cli, transfer=None): + """Set up the transferer""" + return self.setup_globus_auth(globus_cli, transfer=transfer) + def setup_globus_auth(self, globus_cli, transfer=None): """Setting up Globus authentication.""" if transfer is None: @@ -280,18 +282,6 @@ def _download_file_with_retries(self, filename: str, checksum: str): tries=self._retry_tries, backoff=self._retry_backoff, delay=self._retry_delay, max_delay=self._retry_max_delay) - def _calculate_checksum(self, filename): - """ - Calculate the checksum of the given file - - """ - with open(filename, 'rb') as fh: - checksum = hashlib.sha256() - while chunk := fh.read(FILE_CHUNK_SIZE): - checksum.update(chunk) - - return checksum.hexdigest() - def _download_file(self, filename: str, checksum: str): """ Download a file from remote. 
diff --git a/src/rjm/transferers/paramiko_sftp_transferer.py b/src/rjm/transferers/paramiko_sftp_transferer.py new file mode 100644 index 00000000..bf8bff55 --- /dev/null +++ b/src/rjm/transferers/paramiko_sftp_transferer.py @@ -0,0 +1,207 @@ + +import os +import stat +import time +import platform +import logging +import paramiko + +from rjm.transferers.transferer_base import TransfererBase +from rjm import utils +from rjm.errors import RemoteJobTransfererError + + +DOWNLOAD_SUFFIX = '.rjm' + + +logger = logging.getLogger(__name__) + + +class ParamikoSftpTransferer(TransfererBase): + """ + Upload and download files to a remote machine using SFTP via + the Paramiko library. + + """ + def __init__(self, config=None): + super(ParamikoSftpTransferer, self).__init__(config=config) + + # config + self._ssh_private_key_file = self._config.get("PARAMIKO", "private_key_file") + self._remote_address = self._config.get("PARAMIKO", "remote_address") + self._remote_user = self._config.get("PARAMIKO", "remote_user") + self._remote_base_path = self._config.get("PARAMIKO", "remote_base_path") + + # retry params + self._retry_tries, self._retry_backoff, self._retry_delay, self._retry_max_delay = utils.get_retry_values_from_config(self._config) + + # TODO: move to a setup function?? 
+ self._private_key = None + self._ssh_client = None + self._sftp_client = None + +# def __del__(self): +# if self._sftp_client is not None: +# self._sftp_client.close() +# if self._ssh_client is not None: +# self._ssh_client.close() + + def _log(self, level, message, *args, **kwargs): + """Add a label to log messages, identifying this specific RemoteJob""" + logger.log(level, self._label + message, *args, **kwargs) + + def setup(self, *args, **kwargs): + """Setup the SFTP client""" + self._log(logging.DEBUG, "Setting up ParamikoSftpTransferer...") + self._ssh_client = paramiko.SSHClient() + self._ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + self._log(logging.DEBUG, f"Loading SSH key from {self._ssh_private_key_file}") + self._private_key = paramiko.RSAKey(filename=self._ssh_private_key_file) + + # Connect to server + self._ssh_client.connect( + hostname=self._remote_address, + port=22, + username=self._remote_user, + pkey=self._private_key, + timeout=30, + look_for_keys=False, + ) + + # Create SFTP client + self._sftp_client = self._ssh_client.open_sftp() + self._log(logging.DEBUG, f"Connected to: {self._remote_address} ({self._sftp_client})") + + # Ensure remote base path exists + # Use ssh_client to run mkdir -p + stdin, stdout, stderr = self._ssh_client.exec_command(f'mkdir -p "{self._remote_base_path}"') + exit_status = stdout.channel.recv_exit_status() + if exit_status != 0: + err = stderr.read().decode() + self._log(logging.ERROR, f'Failed to create remote base path: {self._remote_base_path}. Error: {err}') + else: + self._log(logging.DEBUG, f'Ensured remote base path exists: {self._remote_base_path}') + + def upload_files(self, filenames: list[str]): + """ + Upload the given files to the remote directory. + + :param filenames: List of files to upload to the + remote directory. 
+ :type filenames: iterable of str + + """ + self._log(logging.DEBUG, "Uploading files...") + self._log(logging.DEBUG, f"Remote base path is: {self._remote_base_path}") + self._log(logging.DEBUG, f"Remote path is: {self._remote_path}") + for filename in filenames: + # use basename for remote file name + basename = os.path.basename(filename) + remote_filename = f"{self._remote_base_path}/{self._remote_path}/{basename}" + self._log(logging.DEBUG, f"Uploading: {filename} -> {remote_filename}") + + # upload + start_time = time.perf_counter() + self._sftp_client.put(filename, remote_filename) + upload_time = time.perf_counter() - start_time + self.log_transfer_time("Uploaded", filename, upload_time) + + def download_files(self, filenames, checksums, retries=True): + """ + Download the given files (which should be relative to `remote_path`) to + the local directory. + + :param filenames: list of file names relative to the `remote_path` + directory to download to the local directory. + :param checksums: dictionary with filenames as keys and checksums as + values + :param retries: optional, retry downloads if they fail (default is True) + + """ + errors = 0 + + # check local path exists + if not os.path.exists(self._local_path): + self._log(logging.WARNING, f"Download directory does not exist - creating it ({self._local_path})") + os.makedirs(self._local_path, exist_ok=True) + + # loop over the files to download + self._log(logging.DEBUG, "Downloading files...") + downloaded_tmp_files = [] + for fn in filenames: + self._log(logging.DEBUG, f"Downloading: {fn}") + remote_fn = f"{self._remote_base_path}/{self._remote_path}/{fn}" + + # check it exists first + + + # download to temporary file + local_file_tmp = os.path.join(self._local_path, fn + DOWNLOAD_SUFFIX) + self._log(logging.DEBUG, f"Downloading {fn} to temporary file first: {local_file_tmp}") + if len(local_file_tmp) > 255 and platform.system() == "Windows": + self._log(logging.WARNING, f"Temporary filename is long 
({len(local_file_tmp)} characters), may cause problems on Windows") + + # run the download + start_time = time.perf_counter() + try: + self._sftp_client.get(remote_fn, local_file_tmp) + except FileNotFoundError as exc: + errors += 1 + self._log(logging.ERROR, f"File to download is missing: '{fn}' ({exc})") + else: + download_time = time.perf_counter() - start_time + self.log_transfer_time("Downloaded", local_file_tmp, download_time) + + # validate the checksum of the downloaded file + if fn in checksums: + checksum = checksums[fn] + self._log(logging.DEBUG, f"Verifying checksum of \"{local_file_tmp}\"...") + checksum_local = self._calculate_checksum(local_file_tmp) + if checksum != checksum_local: + msg = f"Checksum of downloaded \"{local_file_tmp}\" doesn't match ({checksum_local} vs {checksum})" + self._log(logging.ERROR, msg) + errors += 1 + + downloaded_tmp_files.append(local_file_tmp) + + # at this point we have downloaded to temporary files, now we need to rename them to the actual files + self._log(logging.DEBUG, f"Renaming {len(downloaded_tmp_files)} downloaded temporary files") + start_time = time.perf_counter() + for tmp_file in downloaded_tmp_files: + save_file = tmp_file.removesuffix(DOWNLOAD_SUFFIX) + self._log(logging.DEBUG, f'Renaming "{tmp_file}" -> "{save_file}"') + os.replace(tmp_file, save_file) + rename_time = time.perf_counter() - start_time + self._log(logging.DEBUG, f"Finished renaming files in {rename_time:.1f} s") + + # if there were any errors downloading files, raise an exception now + if errors > 0: + raise RemoteJobTransfererError(f"Failed to download files in '{self._local_path}'") + + self._log(logging.DEBUG, "Finished downloading files") + + def list_directory(self, path: str): + """ + Return a listing of the given directory. 
+ + :param path: Path to the directory + + """ + self._log(logging.DEBUG, f"Listing remote directory: {path}") + + raw_listing = self._sftp_client.listdir_attr(path=path) + + listing = {} + for entry in raw_listing: + mode = entry.st_mode + listing[entry.filename] = { + "permissions": oct(mode)[-3:], + "directory": stat.S_ISDIR(mode), + "size": entry.st_size, + "user": entry.st_uid, + } + + self._log(logging.DEBUG, f"Listing: {listing}") + + return listing diff --git a/src/rjm/transferers/tests/test_globus_https_transferer.py b/src/rjm/transferers/tests/test_globus_https_transferer.py index bbadff40..1070d380 100644 --- a/src/rjm/transferers/tests/test_globus_https_transferer.py +++ b/src/rjm/transferers/tests/test_globus_https_transferer.py @@ -19,15 +19,17 @@ def get_authorization_header(self): @pytest.fixture def configobj(): config = configparser.ConfigParser() - config["GLOBUS"] = { + config["GLOBUS_TRANSFER"] = { "remote_endpoint": "qwerty", "remote_path": "asdfg", } - config["FUNCX"] = { + config["GLOBUS_COMPUTE"] = { "remote_endpoint": "abcdefg", } config["SLURM"] = { "slurm_script": "run.sl", + } + config["POLLING"] = { "poll_interval": "1", } config["RETRY"] = { diff --git a/src/rjm/transferers/transferer_base.py b/src/rjm/transferers/transferer_base.py index 441880cf..05d84f5c 100644 --- a/src/rjm/transferers/transferer_base.py +++ b/src/rjm/transferers/transferer_base.py @@ -1,5 +1,6 @@ import os +import hashlib import logging from typing import List @@ -7,6 +8,8 @@ from rjm import config as config_helper +FILE_CHUNK_SIZE = 8000000 + logger = logging.getLogger(__name__) @@ -46,6 +49,10 @@ def load_state(self, state_dict): if "remote_path" in state_dict: self._remote_path = state_dict["remote_path"] + def setup(self, *args, **kwargs): + """Do any setup required by the specific transferer implementation""" + pass + def get_globus_scopes(self): """If any Globus scopes are required, override this method and return them in a list""" return [] @@ -82,6 +89,18 
@@ def log_transfer_time(self, text: str, local_file: str, elapsed_time: float, log self._log(log_level, f"{text} {local_file}: {file_size:.1f} {file_size_units} in {elapsed_time:.1f} s " f"({file_size / elapsed_time:.1f} {file_size_units}/s)") + def _calculate_checksum(self, filename): + """ + Calculate the checksum of the given file + + """ + with open(filename, 'rb') as fh: + checksum = hashlib.sha256() + while chunk := fh.read(FILE_CHUNK_SIZE): + checksum.update(chunk) + + return checksum.hexdigest() + def upload_files(self, filenames: List[str]): """ Upload the given files (which should be relative to `local_path`) to @@ -94,7 +113,7 @@ def upload_files(self, filenames: List[str]): """ raise NotImplementedError - def download_files(self, filenames: List[str]): + def download_files(self, filenames: List[str], *args, **kwargs): """ Download the given files (which should be relative to `remote_path`) to the local directory. diff --git a/src/rjm/utils.py b/src/rjm/utils.py index 3ddb7cb7..86ddf507 100644 --- a/src/rjm/utils.py +++ b/src/rjm/utils.py @@ -58,12 +58,16 @@ def setup_logging(log_name=None, log_file=None, log_level=None, cli_extra=False) # check if specific levels are set in log file if os.path.exists(config_helper.CONFIG_FILE_LOCATION): - config = config_helper.load_config() - if "LOGGING" in config: - for logger_name, level_name in config.items("LOGGING"): - level = getattr(logging, level_name, None) - if level is not None: - logging.getLogger(logger_name).setLevel(level) + try: + config = config_helper.load_config() + except config_helper.RemoteJobConfigError: + pass + else: + if "LOGGING" in config: + for logger_name, level_name in config.items("LOGGING"): + level = getattr(logging, level_name, None) + if level is not None: + logging.getLogger(logger_name).setLevel(level) # command line overrides rjm log level if log_level is not None: @@ -73,9 +77,10 @@ def setup_logging(log_name=None, log_file=None, log_level=None, cli_extra=False) if level is not 
None: logging.getLogger("rjm").setLevel(level) if cli_extra: - # same level for globus + # same level for globus and paramiko logging.getLogger("globus").setLevel(level) logging.getLogger("globus_compute_sdk").setLevel(level) + logging.getLogger("paramiko").setLevel(level) def handle_globus_auth(scopes, token_file=TOKEN_FILE_LOCATION,