Skip to content

Commit 5621cec

Browse files
Merged upstream torch-compatibility changes in elasticity/elastic_agent.py (#62)
1 parent 6d097be commit 5621cec

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

deepspeed/elasticity/elastic_agent.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
77
from typing import Any, Dict, Optional, Tuple
88
from datetime import datetime
9-
from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
9+
from torch.distributed.elastic.utils.distributed import get_free_port
1010
from torch.distributed.elastic.metrics import put_metric
1111
from torch.distributed.elastic.agent.server.api import (
1212
RunResult,
@@ -24,6 +24,10 @@
2424
from contextlib import closing
2525
import subprocess
2626

27+
from torch.distributed.elastic.utils.logging import get_logger
28+
29+
log = get_logger(__name__)
30+
2731

2832
class DSElasticAgent(LocalElasticAgent):
2933

@@ -39,9 +43,12 @@ def __init__(
3943
self.ds_env = env
4044

4145
@staticmethod
42-
def _set_master_addr_port(store: Store, master_addr: Optional[str], master_port: Optional[int]):
46+
def _set_master_addr_port(store: Store,
47+
master_addr: Optional[str],
48+
master_port: Optional[int],
49+
local_addr: Optional[str] = None):
4350
if master_port is None:
44-
sock = _get_socket_with_port()
51+
sock = get_free_port()
4552
with closing(sock):
4653
master_port = sock.getsockname()[1]
4754

0 commit comments

Comments
 (0)