Fix ssh host policy #4966

Open: 13 commits to be merged into base: master
7 changes: 3 additions & 4 deletions src/sagemaker/modules/train/container_drivers/mpi_utils.py
@@ -14,12 +14,11 @@
from __future__ import absolute_import

import os
-import time
import subprocess
-
+import time

Codecov / codecov/patch warning: added line #L18 in src/sagemaker/modules/train/container_drivers/mpi_utils.py was not covered by tests.
from typing import List

-from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
+from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger

Codecov / codecov/patch warning: added line #L21 in src/sagemaker/modules/train/container_drivers/mpi_utils.py was not covered by tests.

FINISHED_STATUS_FILE = "/tmp/done.algo-1"
READY_FILE = "/tmp/ready.%s"
@@ -83,7 +82,7 @@
logger.debug("Testing connection to host %s", host)
client = paramiko.SSHClient()
client.load_system_host_keys()
-client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+client.set_missing_host_key_policy(paramiko.RejectPolicy())

Codecov / codecov/patch warning: added line #L85 in src/sagemaker/modules/train/container_drivers/mpi_utils.py was not covered by tests.
Review comment from @benieric (Contributor), Jan 11, 2025:

Pretty sure for this to work we will need to create some custom policy like:

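class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    def missing_host_key(self, client, hostname, key):
        # Auto-add keys for hosts named like `algo-*`; reject everything else.
        if hostname.startswith("algo-"):
            return paramiko.AutoAddPolicy().missing_host_key(client, hostname, key)
        return paramiko.RejectPolicy().missing_host_key(client, hostname, key)

client = paramiko.SSHClient()
client.set_missing_host_key_policy(CustomHostKeyPolicy())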
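
The `algo-*` check suggested in the comment could be pulled into a small helper so it can be unit tested without an SSH connection. This is only a sketch: the helper name `is_algo_host` is hypothetical, and the assumption that training hosts are always named `algo-<n>` is inferred from the `/tmp/done.algo-1` path in this module, not confirmed by the PR.

```python
import re

# Assumption: training hosts are named "algo-<n>" (for example "algo-1"),
# as the FINISHED_STATUS_FILE path in mpi_utils.py suggests.
_ALGO_HOST_PATTERN = re.compile(r"algo-[0-9]+")


def is_algo_host(hostname: str) -> bool:
    """Return True only if the entire hostname has the form algo-<n>."""
    # fullmatch() requires the whole string to match, so trailing
    # domain parts or extra characters are rejected.
    return _ALGO_HOST_PATTERN.fullmatch(hostname) is not None
```

A full match is stricter than a prefix test: a prefix test would also accept a name such as `algo-1.attacker.example`. One caveat to verify before relying on this: I believe paramiko passes the name as `[host]:port` when connecting on a non-default port, in which case the pattern would need adjusting.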

client.connect(host, port=port)
client.close()
logger.info("Can connect to host %s", host)