// Patch summary: adds the accessor `const int& ProcessGroupXCCL::globalRank() const`
// (defined in ProcessGroupXCCL.cpp, declared in ProcessGroupXCCL.hpp next to
// groupRanks()) and switches four call sites over to it:
//   * the "ProcessGroupXCCL initialization options" LOG(INFO) line in the constructor
//     (previously streamed `rank_`),
//   * DebugInfoWriter::getWriter(...) in dumpDebuggingInfo() (previously `rank_`),
//   * the `% at::detail::getXPUHooks().getNumGPUs()` device guess in guessDeviceId()
//     (previously `rank_`),
//   * dumpPipe.emplace(...) in HeartbeatMonitorXCCL::runLoop() (previously
//     `pg_->getRank()`).
//
// NOTE(review): globalRank() returns a reference to a function-local `static int`
// initialized from `rank_`. A function-local static is initialized only once, on the
// first call, so every later call -- from any ProcessGroupXCCL instance in the
// process -- returns the `rank_` of whichever instance called first. Presumably that
// is the intent (the first-created group acting as the "global" one) -- TODO confirm
// which instance is guaranteed to make the first call.
//
// NOTE(review): this patch text appears to have lost its line breaks, and template
// argument lists look stripped (e.g. `static_cast(`, `c10::intrusive_ptr store`), so
// it presumably cannot be applied as-is -- verify against the original change.
diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 514cdaf3a7..aca5a928aa 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -352,6 +352,11 @@ const std::string& ProcessGroupXCCL::logPrefix() const { return logPrefix_; } +const int& ProcessGroupXCCL::globalRank() const { + static int globalRank = rank_; + return globalRank; +} + ProcessGroupXCCL::ProcessGroupXCCL( c10::intrusive_ptr store, int rank, @@ -379,7 +384,7 @@ ProcessGroupXCCL::ProcessGroupXCCL( std::string torch_distributed_debug = getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: " - << "size: " << size << ", global rank: " << rank_ + << "size: " << size << ", global rank: " << globalRank() << ", USE_HIGH_PRIORITY_STREAM: " << options_->is_high_priority_stream << ", PG Name: " << options_->group_name; @@ -410,7 +415,7 @@ bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) { if (traceBufferSize_ > 0) { // TODO: dump_xccl_trace auto xcclTrace = dump_xccl_trace(true, includeStackTrace, false); - DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_); + DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to " << writer.getWriterTarget(); writer.write(xcclTrace); @@ -2021,7 +2026,7 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const { return *usedDeviceIdxs_.begin(); } int devIdx = - static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); + static_cast(globalRank() % at::detail::getXPUHooks().getNumGPUs()); LOG(WARNING) << logPrefix() << c10::str( diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 8e4604fbfb..7074354c64 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -423,6 +423,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::DeviceIndex guessDeviceId() const; const 
std::vector& groupRanks() const; + const int& globalRank() const; void setEnqueuedPgStatus(c10::intrusive_ptr work); bool dumpDebuggingInfo(bool includeStackTrace = true); diff --git a/src/xccl/ProcessGroupXCCLMonitor.cpp b/src/xccl/ProcessGroupXCCLMonitor.cpp index 68fdd402c1..246bdf3a59 100644 --- a/src/xccl/ProcessGroupXCCLMonitor.cpp +++ b/src/xccl/ProcessGroupXCCLMonitor.cpp @@ -39,7 +39,7 @@ void HeartbeatMonitorXCCL::runLoop() { // We only need to dump once per PG, so we use local_id_ == 0 for the first PG if (pg_->local_id_ == 0) { // DumpPipe is one per-trainer process - dumpPipe.emplace(pg_->getRank()); + dumpPipe.emplace(pg_->globalRank()); while (true) { std::unique_lock lock(monitorMutex_); if (monitorWakeUpCV_.wait_for(