diff --git a/tests/integration-tests/configs/tmp_test.yaml b/tests/integration-tests/configs/tmp_test.yaml new file mode 100644 index 0000000000..647f9b0bbd --- /dev/null +++ b/tests/integration-tests/configs/tmp_test.yaml @@ -0,0 +1,13 @@ +{%- import 'common.jinja2' as common with context -%} +{{- common.OSS_COMMERCIAL_ARM.append("centos7") or "" -}} +{{- common.OSS_COMMERCIAL_X86.append("rocky8") or "" -}} +{{- common.OSS_COMMERCIAL_X86.append("rocky9") or "" -}} +--- +test-suites: + proxy: + test_proxy.py::test_proxy: + dimensions: + - regions: ["us-east-1"] + instances: {{ common.INSTANCES_DEFAULT_X86 }} + oss: ["ubuntu2004"] + schedulers: ["slurm"] diff --git a/tests/integration-tests/remote_command_executor.py b/tests/integration-tests/remote_command_executor.py index 8c314fc3a8..e58e8d5529 100644 --- a/tests/integration-tests/remote_command_executor.py +++ b/tests/integration-tests/remote_command_executor.py @@ -29,7 +29,14 @@ class RemoteCommandExecutor: """Execute remote commands on the cluster head node.""" def __init__( - self, cluster, compute_node_ip=None, username=None, bastion=None, alternate_ssh_key=None, use_login_node=False + self, + cluster, + compute_node_ip=None, + username=None, + bastion=None, + alternate_ssh_key=None, + use_login_node=False, + connection_timeout=None, ): """ Initiate SSH connection @@ -61,6 +68,7 @@ def __init__( "host": node_ip, "user": username, "forward_agent": False, + "inline_ssh_env": True, "connect_kwargs": { "key_filename": [alternate_ssh_key if alternate_ssh_key else cluster.ssh_key], "look_for_keys": False, @@ -68,12 +76,18 @@ def __init__( } if bastion: # Need to execute simple ssh command before using Connection to avoid Paramiko _check_banner error - run_command( - f"ssh -i {cluster.ssh_key} -o StrictHostKeyChecking=no {bastion} hostname", timeout=30, shell=True + ssh_command_result = run_command( + f"ssh -i {cluster.ssh_key} -o StrictHostKeyChecking=no {bastion} hostname", + timeout=30, + shell=True, ) + logging.info(f"Command output: {ssh_command_result}") connection_kwargs["gateway"] = f"ssh -W %h:%p -A {bastion}" connection_kwargs["forward_agent"] = True - connection_kwargs["connect_kwargs"]["banner_timeout"] = 60 + connection_kwargs["connect_kwargs"]["banner_timeout"] = 1800 + if connection_timeout: + connection_kwargs["connect_kwargs"]["timeout"] = connection_timeout + logging.info(f"set timeout to {connection_timeout}") logging.info( f"Connecting to {connection_kwargs['host']} as {connection_kwargs['user']} with " f"{connection_kwargs['connect_kwargs']['key_filename']}" diff --git a/tests/integration-tests/tests/proxy/test_proxy.py b/tests/integration-tests/tests/proxy/test_proxy.py index 7e310dab5d..45b998ba1b 100644 --- a/tests/integration-tests/tests/proxy/test_proxy.py +++ b/tests/integration-tests/tests/proxy/test_proxy.py @@ -10,13 +10,14 @@ # This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. # See the License for the specific language governing permissions and limitations under the License. import logging +import os import boto3 import pytest from assertpy import assert_that from cfn_stacks_factory import CfnStack from remote_command_executor import RemoteCommandExecutor -from utils import generate_stack_name +from utils import generate_stack_name, run_command from tests.common.schedulers_common import SlurmCommands @@ -74,7 +75,7 @@ def get_instance_public_ip(instance_id, region): @pytest.mark.usefixtures("region", "os", "instance", "scheduler") -def test_proxy(pcluster_config_reader, clusters_factory, proxy_stack_factory, scheduler_commands_factory): +def test_proxy(pcluster_config_reader, request, proxy_stack_factory, scheduler_commands_factory, clusters_factory): """ Test the creation and functionality of a Cluster using a proxy environment. @@ -84,6 +85,33 @@ def test_proxy(pcluster_config_reader, clusters_factory, proxy_stack_factory, sc 3. Submit a sleep job to the cluster and verify it completes successfully. 4. Check Internet access by trying to access google.com """ + # Start ssh-agent and capture the output + ssh_agent_result = run_command("ssh-agent -s", shell=True) + logging.info(f"SSH agent started with output: {ssh_agent_result.stdout}") + + # Parse the ssh-agent output to set environment variables + for line in ssh_agent_result.stdout.splitlines(): + if line.startswith("SSH_AUTH_SOCK"): + key, value = line.split(";")[0].split("=") + os.environ[key] = value + elif line.startswith("SSH_AGENT_PID"): + key, value = line.split(";")[0].split("=") + os.environ[key] = value + + logging.info("Environment variables are: %s", dict(os.environ)) + + # Verify that the environment variables are set correctly + logging.info(f"SSH_AUTH_SOCK: {os.environ.get('SSH_AUTH_SOCK')}") + logging.info(f"SSH_AGENT_PID: {os.environ.get('SSH_AGENT_PID')}") + + # Add the SSH key using the ssh-add command, passing the environment variables + ssh_add_result = run_command(f'ssh-add {request.config.getoption("key_path")}', shell=True) + logging.info(f"SSH key add result: {ssh_add_result.stderr}") + + # Confirm that the key has been added + added_keys = run_command("ssh-add -l", shell=True) + logging.info(f"SSH keys added: {added_keys.stdout}") + proxy_address = proxy_stack_factory.cfn_outputs["ProxyAddress"] subnet_with_proxy = proxy_stack_factory.cfn_outputs["PrivateSubnet"] proxy_instance_id = proxy_stack_factory.cfn_resources.get("Proxy") @@ -96,21 +124,39 @@ def test_proxy(pcluster_config_reader, clusters_factory, proxy_stack_factory, sc bastion = f"ubuntu@{proxy_public_ip}" - remote_command_executor = RemoteCommandExecutor(cluster=cluster, bastion=bastion) - slurm_commands = SlurmCommands(remote_command_executor) + env_vars = { + "SSH_AUTH_SOCK": os.environ.get("SSH_AUTH_SOCK"), + "SSH_AGENT_PID": os.environ.get("SSH_AGENT_PID"), + } + env_prefix = " && ".join([f"export {key}={value}" for key, value in env_vars.items()]) - _check_internet_access(remote_command_executor) + headnode_instance_ip = cluster.head_node_ip - job_id = slurm_commands.submit_command_and_assert_job_accepted( - submit_command_args={"command": "srun sleep 1", "nodes": 1} + ssh_command_result = run_command( + f"ssh -i {cluster.ssh_key} -o StrictHostKeyChecking=no {bastion} hostname", + timeout=30, + shell=True, ) - slurm_commands.wait_job_completed(job_id) - slurm_commands.assert_job_succeeded(job_id) + logging.info(f"Command output: {ssh_command_result}") + + ssh_gateway_result = run_command(f"ssh -W {headnode_instance_ip}:22 -A {bastion} -vvv", shell=True, raise_on_error=False) + logging.info(f"SSH command output: {ssh_gateway_result}") + + remote_command_executor = RemoteCommandExecutor(cluster=cluster, bastion=bastion, connection_timeout=300) + # slurm_commands = SlurmCommands(remote_command_executor) + + _check_internet_access(remote_command_executor, env_prefix) + + # job_id = slurm_commands.submit_command_and_assert_job_accepted( + # submit_command_args={"command": "srun sleep 1", "nodes": 1} + # ) + # slurm_commands.wait_job_completed(job_id) + # slurm_commands.assert_job_succeeded(job_id) -def _check_internet_access(remote_command_executor): +def _check_internet_access(remote_command_executor, env_prefix): logging.info("Checking cluster has Internet access by trying to access google.com") internet_result = remote_command_executor.run_remote_command( - "curl --connect-timeout 10 -I https://google.com", raise_on_error=False + f"{env_prefix} && curl --connect-timeout 10 -I https://google.com", raise_on_error=False ) assert_that(internet_result.failed).is_false() diff --git a/tests/integration-tests/tests/proxy/test_proxy/test_proxy/pcluster.config.yaml b/tests/integration-tests/tests/proxy/test_proxy/test_proxy/pcluster.config.yaml index 4ab0a63a59..0e6e722ab2 100644 --- a/tests/integration-tests/tests/proxy/test_proxy/test_proxy/pcluster.config.yaml +++ b/tests/integration-tests/tests/proxy/test_proxy/test_proxy/pcluster.config.yaml @@ -4,6 +4,7 @@ HeadNode: InstanceType: {{ instance }} Ssh: KeyName: {{ key_name }} + AllowedIps: 0.0.0.0/0 Networking: SubnetId: {{ subnet_with_proxy }} Proxy: diff --git a/tests/integration-tests/utils.py b/tests/integration-tests/utils.py index 9a92998c85..46159ec8fb 100644 --- a/tests/integration-tests/utils.py +++ b/tests/integration-tests/utils.py @@ -187,6 +187,9 @@ def run_command( command = shlex.split(command) log_command = command if isinstance(command, str) else " ".join(str(arg) for arg in command) logging.info("Executing command: {}".format(log_command)) + + env = env if env is not None else os.environ.copy() + try: result = subprocess.run( command,