Skip to content

Commit

Permalink
[Fix] Support multi-node distributed training with NPU backend (#1459)
Browse files Browse the repository at this point in the history
  • Loading branch information
shun001 authored Dec 26, 2023
1 parent 671f3bc commit 8e6fb12
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
rank = int(os.environ['RANK'])
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
if is_mlu_available():
import torch_mlu # noqa: F401
local_rank = int(os.environ['LOCAL_RANK'])
torch.mlu.set_device(local_rank)
torch_dist.init_process_group(
backend='cncl',
Expand All @@ -110,15 +111,13 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
**kwargs)
elif is_npu_available():
import torch_npu # noqa: F401
torch.npu.set_device(rank)
torch.npu.set_device(local_rank)
torch_dist.init_process_group(
backend='hccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

if init_backend == 'torch':
Expand Down

0 comments on commit 8e6fb12

Please sign in to comment.