diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index c820a1c48..ba7c8561b 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -322,7 +322,9 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { return true; } -ProcessGroupXCCL::Options::Options() : Backend::Options(XCCL_BACKEND_NAME) {} +ProcessGroupXCCL::Options::Options(bool is_high_priority_stream) + : Backend::Options(XCCL_BACKEND_NAME), + is_high_priority_stream(is_high_priority_stream) {} static std::atomic process_group_id = 0; @@ -351,7 +353,7 @@ const std::string& ProcessGroupXCCL::logPrefix() const { } ProcessGroupXCCL::ProcessGroupXCCL( - const c10::intrusive_ptr& store, + c10::intrusive_ptr store, int rank, int size, c10::intrusive_ptr options) @@ -377,7 +379,10 @@ ProcessGroupXCCL::ProcessGroupXCCL( std::string torch_distributed_debug = getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: " - << "size: " << size << ", global rank: " << rank_; + << "size: " << size << ", global rank: " << rank_ + << ", USE_HIGH_PRIORITY_STREAM: " + << options_->is_high_priority_stream + << ", PG Name: " << options_->group_name; LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: " << "XCCL version: " << XcclVersion @@ -534,9 +539,9 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( rank = p2pRank; } - c10::impl::VirtualGuardImpl impl(device.type()); - c10::Stream stream = - impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); + bool force_high = getCvarBool(TORCH_XCCL_HIGH_PRIORITY, false); + c10::Stream stream = at::xpu::getStreamFromPool( + options_->is_high_priority_stream || force_high); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); auto ctx = ccl::create_context(q.get_context()); diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index e7aa39c82..2516c0b91 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -24,6 +24,9 @@ #include namespace c10d { +static std::vector TORCH_XCCL_HIGH_PRIORITY = { + "TORCH_XCCL_HIGH_PRIORITY"}; + static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -118,18 +121,19 @@ class TORCH_API ProcessGroupXCCL : public Backend { }; struct Options : public Backend::Options { - explicit Options(); + explicit Options(bool is_high_priority_stream = false); - static c10::intrusive_ptr create() { - return c10::make_intrusive(); + static c10::intrusive_ptr create( + bool is_high_priority_stream = false) { + return c10::make_intrusive(is_high_priority_stream); } - + bool is_high_priority_stream; std::vector global_ranks_in_group; std::string group_name; }; ProcessGroupXCCL( - const c10::intrusive_ptr& store, + c10::intrusive_ptr store, int rank, int size, c10::intrusive_ptr options = Options::create()); @@ -138,11 +142,16 @@ class TORCH_API ProcessGroupXCCL : public Backend { const c10::intrusive_ptr& store, int rank, int size, - const std::string& groupName) - : ProcessGroupXCCL(store, rank, size) {} + const std::string& groupName, + c10::intrusive_ptr options = Options::create()) + : ProcessGroupXCCL(store, rank, size, std::move(options)) {} ~ProcessGroupXCCL() override; + c10::intrusive_ptr getOptions() { + return options_; + } + const std::string getBackendName() const override { return std::string(XCCL_BACKEND_NAME); }