diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index d1d19d8f68..4e77700141 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -99,9 +99,10 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
         **kwargs: keyword arguments are passed to ``init_process_group``.
     """
     rank = int(os.environ['RANK'])
+    # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
+    local_rank = int(os.environ['LOCAL_RANK'])
     if is_mlu_available():
         import torch_mlu  # noqa: F401
-        local_rank = int(os.environ['LOCAL_RANK'])
         torch.mlu.set_device(local_rank)
         torch_dist.init_process_group(
             backend='cncl',
@@ -110,15 +111,13 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
             **kwargs)
     elif is_npu_available():
         import torch_npu  # noqa: F401
-        torch.npu.set_device(rank)
+        torch.npu.set_device(local_rank)
         torch_dist.init_process_group(
             backend='hccl',
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
     else:
-        # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
-        local_rank = int(os.environ['LOCAL_RANK'])
         torch.cuda.set_device(local_rank)
         if init_backend == 'torch':