diff --git a/comms/torchcomms/BackendWrapper.cpp b/comms/torchcomms/BackendWrapper.cpp index 6c8a4f2e..6465f8df 100644 --- a/comms/torchcomms/BackendWrapper.cpp +++ b/comms/torchcomms/BackendWrapper.cpp @@ -51,7 +51,7 @@ std::vector toVecUint64(const std::vector& vec) { } // namespace -WorkWrapper::WorkWrapper(std::shared_ptr work) +WorkWrapper::WorkWrapper(c10::intrusive_ptr work) : work_(std::move(work)) {} bool WorkWrapper::isCompleted() { diff --git a/comms/torchcomms/BackendWrapper.hpp b/comms/torchcomms/BackendWrapper.hpp index 30bd3429..061c83df 100644 --- a/comms/torchcomms/BackendWrapper.hpp +++ b/comms/torchcomms/BackendWrapper.hpp @@ -12,7 +12,7 @@ namespace comms { class WorkWrapper : public c10d::Work { public: - explicit WorkWrapper(std::shared_ptr work); + explicit WorkWrapper(c10::intrusive_ptr work); ~WorkWrapper() override = default; bool isCompleted() override; @@ -23,7 +23,7 @@ class WorkWrapper : public c10d::Work { std::vector result() override; private: - std::shared_ptr work_; + c10::intrusive_ptr work_; }; using c10d::kUnsetTimeout; diff --git a/comms/torchcomms/TorchComm.cpp b/comms/torchcomms/TorchComm.cpp index 7a886be4..53e54bdb 100644 --- a/comms/torchcomms/TorchComm.cpp +++ b/comms/torchcomms/TorchComm.cpp @@ -39,7 +39,7 @@ std::string_view TorchComm::getCommName() const { } // Point-to-Point Operations -std::shared_ptr TorchComm::send( +c10::intrusive_ptr TorchComm::send( const at::Tensor& tensor, int dst, bool async_op, @@ -47,7 +47,7 @@ std::shared_ptr TorchComm::send( return impl_->send(tensor, dst, async_op, options); } -std::shared_ptr TorchComm::recv( +c10::intrusive_ptr TorchComm::recv( at::Tensor& tensor, int src, bool async_op, @@ -56,7 +56,7 @@ std::shared_ptr TorchComm::recv( } // Collective Operations -std::shared_ptr TorchComm::broadcast( +c10::intrusive_ptr TorchComm::broadcast( at::Tensor& tensor, int root, bool async_op, @@ -64,7 +64,7 @@ std::shared_ptr TorchComm::broadcast( return impl_->broadcast(tensor, root, async_op, options); } -std::shared_ptr TorchComm::all_reduce( +c10::intrusive_ptr TorchComm::all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, @@ -72,7 +72,7 @@ std::shared_ptr TorchComm::all_reduce( return impl_->all_reduce(tensor, op, async_op, options); } -std::shared_ptr TorchComm::reduce( +c10::intrusive_ptr TorchComm::reduce( const at::Tensor& tensor, int root, ReduceOp op, @@ -81,7 +81,7 @@ std::shared_ptr TorchComm::reduce( return impl_->reduce(tensor, root, op, async_op, options); } -std::shared_ptr TorchComm::all_gather( +c10::intrusive_ptr TorchComm::all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -89,7 +89,7 @@ std::shared_ptr TorchComm::all_gather( return impl_->all_gather(tensor_list, tensor, async_op, options); } -std::shared_ptr TorchComm::all_gather_v( +c10::intrusive_ptr TorchComm::all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -97,7 +97,7 @@ std::shared_ptr TorchComm::all_gather_v( return impl_->all_gather_v(tensor_list, tensor, async_op, options); } -std::shared_ptr TorchComm::all_gather_single( +c10::intrusive_ptr TorchComm::all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -105,7 +105,7 @@ std::shared_ptr TorchComm::all_gather_single( return impl_->all_gather_single(output, input, async_op, options); } -std::shared_ptr TorchComm::reduce_scatter( +c10::intrusive_ptr TorchComm::reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -114,7 +114,7 @@ std::shared_ptr 
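// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the diff above): the minimal contract that
// the std::shared_ptr<TorchWork> -> c10::intrusive_ptr<TorchWork> migration
// relies on. "DemoWork" is a hypothetical stand-in for TorchWork; the c10 APIs
// shown (intrusive_ptr_target, make_intrusive) are the real ones from
// <c10/util/intrusive_ptr.h>.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>

namespace demo_refcount {

// The pointee must derive from c10::intrusive_ptr_target so the reference
// count is embedded in the object itself (one allocation, pointer-sized
// handle) instead of living in a separate shared_ptr control block.
class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
  bool isCompleted() const {
    return completed_;
  }
  void markCompleted() {
    completed_ = true;
  }

 private:
  bool completed_ = false;
};

// make_intrusive is the intrusive_ptr analogue of std::make_shared.
inline c10::intrusive_ptr<DemoWork> makeDemoWork() {
  return c10::make_intrusive<DemoWork>();
}

} // namespace demo_refcount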
TorchComm::reduce_scatter( return impl_->reduce_scatter(output, input_list, op, async_op, options); } -std::shared_ptr TorchComm::reduce_scatter_v( +c10::intrusive_ptr TorchComm::reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -123,7 +123,7 @@ std::shared_ptr TorchComm::reduce_scatter_v( return impl_->reduce_scatter_v(output, input_list, op, async_op, options); } -std::shared_ptr TorchComm::reduce_scatter_single( +c10::intrusive_ptr TorchComm::reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, @@ -132,7 +132,7 @@ std::shared_ptr TorchComm::reduce_scatter_single( return impl_->reduce_scatter_single(output, input, op, async_op, options); } -std::shared_ptr TorchComm::all_to_all_single( +c10::intrusive_ptr TorchComm::all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -140,7 +140,7 @@ std::shared_ptr TorchComm::all_to_all_single( return impl_->all_to_all_single(output, input, async_op, options); } -std::shared_ptr TorchComm::all_to_all_v_single( +c10::intrusive_ptr TorchComm::all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, @@ -151,7 +151,7 @@ std::shared_ptr TorchComm::all_to_all_v_single( output, input, output_split_sizes, input_split_sizes, async_op, options); } -std::shared_ptr TorchComm::all_to_all( +c10::intrusive_ptr TorchComm::all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, @@ -160,14 +160,14 @@ std::shared_ptr TorchComm::all_to_all( output_tensor_list, input_tensor_list, async_op, options); } -std::shared_ptr TorchComm::barrier( +c10::intrusive_ptr TorchComm::barrier( bool async_op, const BarrierOptions& options) { return impl_->barrier(async_op, options); } // Scatter and Gather Operations -std::shared_ptr TorchComm::scatter( +c10::intrusive_ptr TorchComm::scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, @@ -177,7 +177,7 @@ std::shared_ptr TorchComm::scatter( output_tensor, input_tensor_list, root, async_op, options); } -std::shared_ptr TorchComm::gather( +c10::intrusive_ptr TorchComm::gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, @@ -239,7 +239,7 @@ void BatchSendRecv::recv(at::Tensor& tensor, int src) { ops.push_back(op); } -std::shared_ptr BatchSendRecv::issue( +c10::intrusive_ptr BatchSendRecv::issue( bool async_op, const BatchP2POptions& options) { return parent_->getBackendImpl()->batch_op_issue(ops, async_op, options); diff --git a/comms/torchcomms/TorchComm.hpp b/comms/torchcomms/TorchComm.hpp index e68e8369..ee368371 100644 --- a/comms/torchcomms/TorchComm.hpp +++ b/comms/torchcomms/TorchComm.hpp @@ -32,96 +32,96 @@ class TorchComm { std::string_view getCommName() const; // Point-to-Point Operations - std::shared_ptr send( + c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}); - std::shared_ptr recv( + c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}); // Collective Operations - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}); - std::shared_ptr all_reduce( + c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}); - std::shared_ptr reduce( + c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, 
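// ---------------------------------------------------------------------------
// Illustrative sketch of the BatchSendRecv flow touched above: individual
// send()/recv() calls only record operation descriptors, and issue() hands the
// whole batch to the backend in one call, returning a single work handle. The
// descriptor fields and the backend hook are hypothetical; only the
// record-then-issue structure and the intrusive_ptr return mirror the diff.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>
#include <vector>

namespace demo_batch {

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
};

struct DemoP2POp {
  enum class Kind { kSend, kRecv };
  Kind kind;
  int peer;
};

class DemoBatch {
 public:
  void send(int dst) {
    ops_.push_back({DemoP2POp::Kind::kSend, dst});
  }
  void recv(int src) {
    ops_.push_back({DemoP2POp::Kind::kRecv, src});
  }
  c10::intrusive_ptr<DemoWork> issue(bool /*async_op*/) {
    // A real implementation forwards ops_ to the backend's batch_op_issue().
    return c10::make_intrusive<DemoWork>();
  }

 private:
  std::vector<DemoP2POp> ops_;
};

} // namespace demo_batch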
const ReduceOptions& options = {}); - std::shared_ptr all_gather( + c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}); - std::shared_ptr all_gather_v( + c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}); - std::shared_ptr all_gather_single( + c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}); - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}); - std::shared_ptr reduce_scatter_v( + c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}); - std::shared_ptr reduce_scatter_single( + c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options = {}); - std::shared_ptr all_to_all_single( + c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}); - std::shared_ptr all_to_all_v_single( + c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}); - std::shared_ptr all_to_all( + c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}); - std::shared_ptr barrier( + c10::intrusive_ptr barrier( bool async_op, const BarrierOptions& options = {}); // Scatter and Gather Operations - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}); - std::shared_ptr gather( + c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, diff --git a/comms/torchcomms/TorchCommBackend.hpp b/comms/torchcomms/TorchCommBackend.hpp index 896631fb..25f3f85d 100644 --- a/comms/torchcomms/TorchCommBackend.hpp +++ b/comms/torchcomms/TorchCommBackend.hpp @@ -40,101 +40,101 @@ class TorchCommBackend { virtual std::string_view getCommName() const = 0; // Point-to-Point Operations - virtual std::shared_ptr send( + virtual c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}) = 0; - virtual std::shared_ptr recv( + virtual c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}) = 0; - virtual std::shared_ptr batch_op_issue( + virtual c10::intrusive_ptr batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options = {}) = 0; // Collective Operations - virtual std::shared_ptr broadcast( + virtual c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}) = 0; - virtual std::shared_ptr all_reduce( + virtual c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}) = 0; - virtual std::shared_ptr reduce( + virtual c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp 
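// ---------------------------------------------------------------------------
// Illustrative sketch (hypothetical names, not the real torchcomms classes):
// the shape of the backend interface after the migration. Every operation on
// the abstract base returns a c10::intrusive_ptr work handle, and each
// concrete backend overrides it with the same signature and default options.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>

namespace demo_backend {

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
};

struct DemoAllReduceOptions {
  int timeout_ms = 0; // hypothetical field, only to show the "options = {}" default
};

class DemoBackend {
 public:
  virtual ~DemoBackend() = default;
  // Default-constructed options let callers omit the last argument, exactly as
  // in the TorchCommBackend declarations above.
  virtual c10::intrusive_ptr<DemoWork> all_reduce(
      bool async_op,
      const DemoAllReduceOptions& options = {}) = 0;
};

class DemoGlooLikeBackend : public DemoBackend {
 public:
  c10::intrusive_ptr<DemoWork> all_reduce(
      bool async_op,
      const DemoAllReduceOptions& options = {}) override {
    (void)async_op;
    (void)options;
    // A real backend would launch the collective before returning the handle.
    return c10::make_intrusive<DemoWork>();
  }
};

} // namespace demo_backend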
op, bool async_op, const ReduceOptions& options = {}) = 0; - virtual std::shared_ptr all_gather( + virtual c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) = 0; - virtual std::shared_ptr all_gather_v( + virtual c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) = 0; - virtual std::shared_ptr all_gather_single( + virtual c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}) = 0; - virtual std::shared_ptr reduce_scatter( + virtual c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) = 0; - virtual std::shared_ptr reduce_scatter_v( + virtual c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) = 0; - virtual std::shared_ptr reduce_scatter_single( + virtual c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options = {}) = 0; - virtual std::shared_ptr all_to_all_single( + virtual c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}) = 0; - virtual std::shared_ptr all_to_all_v_single( + virtual c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}) = 0; - virtual std::shared_ptr all_to_all( + virtual c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}) = 0; - virtual std::shared_ptr barrier( + virtual c10::intrusive_ptr barrier( bool async_op, const BarrierOptions& options = {}) = 0; // Scatter and Gather Operations - virtual std::shared_ptr scatter( + virtual c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}) = 0; - virtual std::shared_ptr gather( + virtual c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, diff --git a/comms/torchcomms/TorchCommBatch.hpp b/comms/torchcomms/TorchCommBatch.hpp index 094b38fa..ff6c006e 100644 --- a/comms/torchcomms/TorchCommBatch.hpp +++ b/comms/torchcomms/TorchCommBatch.hpp @@ -27,7 +27,7 @@ class BatchSendRecv { void send(const at::Tensor& tensor, int dst); void recv(at::Tensor& tensor, int src); - std::shared_ptr issue( + c10::intrusive_ptr issue( bool async_op, const BatchP2POptions& options = {}); diff --git a/comms/torchcomms/TorchCommPy.cpp b/comms/torchcomms/TorchCommPy.cpp index a4b7d452..bb5e656b 100644 --- a/comms/torchcomms/TorchCommPy.cpp +++ b/comms/torchcomms/TorchCommPy.cpp @@ -97,7 +97,7 @@ PYBIND11_MODULE(_comms, m) { .def_readwrite("timeout", &BatchP2POptions::timeout, "Timeout"); // Bind TorchWork class - py::class_>( + intrusive_ptr_class_( m, "TorchWork", R"( diff --git a/comms/torchcomms/TorchCommWindow.hpp b/comms/torchcomms/TorchCommWindow.hpp index 0d645117..a397b9b3 100644 --- a/comms/torchcomms/TorchCommWindow.hpp +++ b/comms/torchcomms/TorchCommWindow.hpp 
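// ---------------------------------------------------------------------------
// Illustrative sketch of what the TorchCommPy.cpp change implies. In the
// PyTorch ecosystem, c10::intrusive_ptr is registered as a pybind11 holder
// type by torch's pybind helper headers, and "intrusive_ptr_class_" is
// conventionally an alias along the lines of
// py::class_<T, c10::intrusive_ptr<T>> (the pattern the c10d bindings use).
// The alias definition and "DemoWork" below are assumptions for illustration,
// and this sketch assumes the torch header that registers the holder type is
// already included, as TorchCommPy.cpp presumably does.
// ---------------------------------------------------------------------------
#include <pybind11/pybind11.h>
#include <c10/util/intrusive_ptr.h>

namespace py = pybind11;

namespace demo_py {

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
  bool isCompleted() const { return true; }
};

// Presumed shape of the intrusive_ptr_class_ helper used in the diff.
template <typename T>
using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>;

// With an intrusive_ptr holder, Python shares the same reference count as C++
// callers, so a work handle returned from C++ stays alive while Python still
// references it.
inline void bindDemoWork(py::module_& m) {
  intrusive_ptr_class_<DemoWork>(m, "DemoWork")
      .def("is_completed", &DemoWork::isCompleted);
}

} // namespace demo_py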
@@ -26,16 +26,16 @@ class TorchCommWindow { const size_t window_size, bool cpu_buf = false, const size_t signal_size = 256) = 0; - virtual std::shared_ptr + virtual c10::intrusive_ptr put(const at::Tensor& data, int dstRank, size_t targetDisp, bool asyncOp) = 0; virtual at::Tensor getTensor( int rank, at::IntArrayRef sizes, at::ScalarType dtype, int64_t storageOffset) = 0; - virtual std::shared_ptr + virtual c10::intrusive_ptr signal(size_t signalDisp, uint64_t signalVal, int dstRank, bool asyncOp) = 0; - virtual std::shared_ptr waitSignal( + virtual c10::intrusive_ptr waitSignal( size_t signalDisp, uint64_t cmpVal, SignalCmpOp cmpOp, diff --git a/comms/torchcomms/TorchWork.hpp b/comms/torchcomms/TorchWork.hpp index 3566710d..7550ec75 100644 --- a/comms/torchcomms/TorchWork.hpp +++ b/comms/torchcomms/TorchWork.hpp @@ -2,13 +2,14 @@ #pragma once +#include #include #include namespace torch { namespace comms { -class TorchWork { +class TorchWork : public c10::intrusive_ptr_target { public: TorchWork() = default; virtual ~TorchWork() = default; diff --git a/comms/torchcomms/examples/SendRecvAsync.cpp b/comms/torchcomms/examples/SendRecvAsync.cpp index 05c8d5ec..999d3df2 100644 --- a/comms/torchcomms/examples/SendRecvAsync.cpp +++ b/comms/torchcomms/examples/SendRecvAsync.cpp @@ -54,8 +54,8 @@ int main() { // Perform asynchronous send/recv operations // Use alternating pattern to avoid deadlock - std::shared_ptr send_work = nullptr; - std::shared_ptr recv_work = nullptr; + c10::intrusive_ptr send_work = nullptr; + c10::intrusive_ptr recv_work = nullptr; if (rank % 2 == 0) { // Even ranks: send first, then receive diff --git a/comms/torchcomms/gloo/TorchCommGloo.cpp b/comms/torchcomms/gloo/TorchCommGloo.cpp index ca40bc3f..6c712373 100644 --- a/comms/torchcomms/gloo/TorchCommGloo.cpp +++ b/comms/torchcomms/gloo/TorchCommGloo.cpp @@ -373,19 +373,19 @@ std::string_view TorchCommGloo::getBackendName() const { std::string_view TorchCommGloo::getCommName() const { return name_; } -std::shared_ptr TorchCommGloo::createWork( +c10::intrusive_ptr TorchCommGloo::createWork( std::function fn, bool async_op) { if (async_op) { - return std::make_shared(std::move(fn)); + return c10::make_intrusive(std::move(fn)); } fn(); - return std::make_shared(); + return c10::make_intrusive(); } // Point-to-Point Operations -std::shared_ptr TorchCommGloo::send( +c10::intrusive_ptr TorchCommGloo::send( const at::Tensor& tensor, int dst, bool async_op, @@ -424,7 +424,7 @@ std::shared_ptr TorchCommGloo::send( async_op); } -std::shared_ptr TorchCommGloo::recv( +c10::intrusive_ptr TorchCommGloo::recv( at::Tensor& tensor, int src, bool async_op, @@ -474,7 +474,7 @@ std::shared_ptr TorchCommGloo::recv( } // Batch P2P Operations -std::shared_ptr TorchCommGloo::batch_op_issue( +c10::intrusive_ptr TorchCommGloo::batch_op_issue( const std::vector& ops, bool /*async_op*/, const BatchP2POptions& /*options*/) { @@ -510,7 +510,7 @@ std::shared_ptr TorchCommGloo::batch_op_issue( } // Collective Operations -std::shared_ptr TorchCommGloo::broadcast( +c10::intrusive_ptr TorchCommGloo::broadcast( at::Tensor& tensor, int root, bool async_op, @@ -551,7 +551,7 @@ std::shared_ptr TorchCommGloo::broadcast( async_op); } -std::shared_ptr TorchCommGloo::all_reduce( +c10::intrusive_ptr TorchCommGloo::all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, @@ -594,7 +594,7 @@ std::shared_ptr TorchCommGloo::all_reduce( async_op); } -std::shared_ptr TorchCommGloo::reduce( +c10::intrusive_ptr TorchCommGloo::reduce( const at::Tensor& tensor, int root, 
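// ---------------------------------------------------------------------------
// Illustrative sketch of the SendRecvAsync.cpp pattern touched above: the work
// handles are now c10::intrusive_ptr<TorchWork>, and even/odd ranks order
// their send/recv calls differently to avoid a deadlock. "comm", "wait()" and
// the peer handling are assumptions for illustration; only the handle type and
// the alternating structure come from the diff.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>

template <typename Comm, typename Tensor, typename Work>
void exchangeWithPeer(
    Comm& comm,
    Tensor& send_buf,
    Tensor& recv_buf,
    int rank,
    int peer) {
  c10::intrusive_ptr<Work> send_work = nullptr;
  c10::intrusive_ptr<Work> recv_work = nullptr;

  if (rank % 2 == 0) {
    // Even ranks: send first, then receive.
    send_work = comm.send(send_buf, peer, /*async_op=*/true);
    recv_work = comm.recv(recv_buf, peer, /*async_op=*/true);
  } else {
    // Odd ranks: receive first, then send.
    recv_work = comm.recv(recv_buf, peer, /*async_op=*/true);
    send_work = comm.send(send_buf, peer, /*async_op=*/true);
  }

  // Assumed completion API; the concrete TorchWork interface may differ.
  send_work->wait();
  recv_work->wait();
}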
ReduceOp op, @@ -640,7 +640,7 @@ std::shared_ptr TorchCommGloo::reduce( async_op); } -std::shared_ptr TorchCommGloo::all_gather( +c10::intrusive_ptr TorchCommGloo::all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -719,7 +719,7 @@ std::shared_ptr TorchCommGloo::all_gather( async_op); } -std::shared_ptr TorchCommGloo::all_gather_v( +c10::intrusive_ptr TorchCommGloo::all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -727,7 +727,7 @@ std::shared_ptr TorchCommGloo::all_gather_v( throw std::runtime_error("all_gather_v is not supported in GLOO backend yet"); } -std::shared_ptr TorchCommGloo::all_gather_single( +c10::intrusive_ptr TorchCommGloo::all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -779,7 +779,7 @@ std::shared_ptr TorchCommGloo::all_gather_single( async_op); } -std::shared_ptr TorchCommGloo::reduce_scatter( +c10::intrusive_ptr TorchCommGloo::reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -815,7 +815,7 @@ std::shared_ptr TorchCommGloo::reduce_scatter( return reduce_scatter_single(output, input, op, async_op, singleOptions); } -std::shared_ptr TorchCommGloo::reduce_scatter_v( +c10::intrusive_ptr TorchCommGloo::reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -825,7 +825,7 @@ std::shared_ptr TorchCommGloo::reduce_scatter_v( "reduce_scatter_v is not supported in GLOO backend yet"); } -std::shared_ptr TorchCommGloo::reduce_scatter_single( +c10::intrusive_ptr TorchCommGloo::reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, @@ -885,7 +885,7 @@ std::shared_ptr TorchCommGloo::reduce_scatter_single( async_op); } -std::shared_ptr TorchCommGloo::all_to_all_single( +c10::intrusive_ptr TorchCommGloo::all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -942,7 +942,7 @@ std::shared_ptr TorchCommGloo::all_to_all_single( async_op); } -std::shared_ptr TorchCommGloo::all_to_all_v_single( +c10::intrusive_ptr TorchCommGloo::all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, @@ -1039,7 +1039,7 @@ std::shared_ptr TorchCommGloo::all_to_all_v_single( async_op); } -std::shared_ptr TorchCommGloo::all_to_all( +c10::intrusive_ptr TorchCommGloo::all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, @@ -1122,7 +1122,7 @@ std::shared_ptr TorchCommGloo::all_to_all( async_op); } -std::shared_ptr TorchCommGloo::barrier( +c10::intrusive_ptr TorchCommGloo::barrier( bool async_op, const BarrierOptions& options) { checkInitialized(); @@ -1142,7 +1142,7 @@ std::shared_ptr TorchCommGloo::barrier( async_op); } -std::shared_ptr TorchCommGloo::scatter( +c10::intrusive_ptr TorchCommGloo::scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, @@ -1219,7 +1219,7 @@ std::shared_ptr TorchCommGloo::scatter( async_op); } -std::shared_ptr TorchCommGloo::gather( +c10::intrusive_ptr TorchCommGloo::gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, diff --git a/comms/torchcomms/gloo/TorchCommGloo.hpp b/comms/torchcomms/gloo/TorchCommGloo.hpp index 1a5d5cd9..f3cc1366 100644 --- a/comms/torchcomms/gloo/TorchCommGloo.hpp +++ b/comms/torchcomms/gloo/TorchCommGloo.hpp @@ -51,102 +51,102 @@ class TorchCommGloo : public TorchCommBackend, std::string_view getCommName() const override; // Point-to-Point Operations - std::shared_ptr send( + 
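// ---------------------------------------------------------------------------
// Illustrative sketch of the TorchCommGloo::createWork pattern shown above:
// async operations wrap the closure in a work object, while synchronous ones
// run the closure inline and hand back an already-completed work. The two work
// types here are hypothetical stand-ins for the Gloo work classes.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>
#include <functional>
#include <utility>

namespace demo_gloo {

class DemoWorkBase : public c10::intrusive_ptr_target {
 public:
  ~DemoWorkBase() override = default;
  virtual void wait() {}
};

// Completed immediately; wait() is a no-op.
class DemoCompletedWork : public DemoWorkBase {};

// Owns the deferred function; a real implementation would run it on a worker
// thread or event loop and make wait() block until it finishes.
class DemoAsyncWork : public DemoWorkBase {
 public:
  explicit DemoAsyncWork(std::function<void()> fn) : fn_(std::move(fn)) {}
  void wait() override {
    if (fn_) {
      fn_();
      fn_ = nullptr;
    }
  }

 private:
  std::function<void()> fn_;
};

inline c10::intrusive_ptr<DemoWorkBase> createWork(
    std::function<void()> fn,
    bool async_op) {
  if (async_op) {
    return c10::make_intrusive<DemoAsyncWork>(std::move(fn));
  }
  fn();
  return c10::make_intrusive<DemoCompletedWork>();
}

} // namespace demo_gloo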
c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}) override; // Batch P2P Operations - std::shared_ptr batch_op_issue( + c10::intrusive_ptr batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options = {}) override; // Collective Operations - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}) override; - std::shared_ptr all_reduce( + c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, const ReduceOptions& options = {}) override; - std::shared_ptr all_gather( + c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_v( + c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_single( + c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_v( + c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_single( + c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options = {}) override; - std::shared_ptr all_to_all_single( + c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}) override; - std::shared_ptr all_to_all_v_single( + c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}) override; - std::shared_ptr all_to_all( + c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( bool async_op, const BarrierOptions& options = {}) override; // Scatter and Gather Operations - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, @@ -182,7 +182,7 @@ class TorchCommGloo : public TorchCommBackend, CommState::NORMAL}; // State of the communicator private: - std::shared_ptr createWork( + c10::intrusive_ptr createWork( std::function fn, bool async_op); diff --git a/comms/torchcomms/nccl/TorchCommNCCL.cpp 
b/comms/torchcomms/nccl/TorchCommNCCL.cpp index 40155647..da1df5d1 100644 --- a/comms/torchcomms/nccl/TorchCommNCCL.cpp +++ b/comms/torchcomms/nccl/TorchCommNCCL.cpp @@ -331,7 +331,7 @@ static inline std::chrono::milliseconds getOperationTimeout( } // Point-to-Point Operations -std::shared_ptr TorchCommNCCL::send( +c10::intrusive_ptr TorchCommNCCL::send( const at::Tensor& tensor, int dst, bool async_op, @@ -370,7 +370,7 @@ std::shared_ptr TorchCommNCCL::send( return work; } -std::shared_ptr TorchCommNCCL::recv( +c10::intrusive_ptr TorchCommNCCL::recv( at::Tensor& tensor, int src, bool async_op, @@ -410,7 +410,7 @@ std::shared_ptr TorchCommNCCL::recv( } // Batch P2P Operations -std::shared_ptr TorchCommNCCL::batch_op_issue( +c10::intrusive_ptr TorchCommNCCL::batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options) { @@ -506,7 +506,7 @@ std::shared_ptr TorchCommNCCL::batch_op_issue( } // Collective Operations -std::shared_ptr TorchCommNCCL::broadcast( +c10::intrusive_ptr TorchCommNCCL::broadcast( at::Tensor& tensor, int root, bool async_op, @@ -546,7 +546,7 @@ std::shared_ptr TorchCommNCCL::broadcast( return work; } -std::shared_ptr TorchCommNCCL::all_reduce( +c10::intrusive_ptr TorchCommNCCL::all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, @@ -587,7 +587,7 @@ std::shared_ptr TorchCommNCCL::all_reduce( return work; } -std::shared_ptr TorchCommNCCL::reduce( +c10::intrusive_ptr TorchCommNCCL::reduce( const at::Tensor& tensor, int root, ReduceOp op, @@ -634,7 +634,7 @@ std::shared_ptr TorchCommNCCL::reduce( return work; } -std::shared_ptr TorchCommNCCL::all_gather( +c10::intrusive_ptr TorchCommNCCL::all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -691,7 +691,7 @@ std::shared_ptr TorchCommNCCL::all_gather( return work; } -std::shared_ptr TorchCommNCCL::all_gather_v( +c10::intrusive_ptr TorchCommNCCL::all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -699,7 +699,7 @@ std::shared_ptr TorchCommNCCL::all_gather_v( throw std::runtime_error("all_gather_v is not supported in NCCL backend"); } -std::shared_ptr TorchCommNCCL::all_gather_single( +c10::intrusive_ptr TorchCommNCCL::all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -743,7 +743,7 @@ std::shared_ptr TorchCommNCCL::all_gather_single( return work; } -std::shared_ptr TorchCommNCCL::reduce_scatter( +c10::intrusive_ptr TorchCommNCCL::reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -818,7 +818,7 @@ std::shared_ptr TorchCommNCCL::reduce_scatter( return work; } -std::shared_ptr TorchCommNCCL::reduce_scatter_v( +c10::intrusive_ptr TorchCommNCCL::reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -827,7 +827,7 @@ std::shared_ptr TorchCommNCCL::reduce_scatter_v( throw std::runtime_error("reduce_scatter_v is not supported in NCCL backend"); } -std::shared_ptr TorchCommNCCL::reduce_scatter_single( +c10::intrusive_ptr TorchCommNCCL::reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, @@ -876,7 +876,7 @@ std::shared_ptr TorchCommNCCL::reduce_scatter_single( return work; } -std::shared_ptr TorchCommNCCL::all_to_all_single( +c10::intrusive_ptr TorchCommNCCL::all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -948,7 +948,7 @@ std::shared_ptr TorchCommNCCL::all_to_all_single( return work; } -std::shared_ptr TorchCommNCCL::all_to_all_v_single( +c10::intrusive_ptr 
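// ---------------------------------------------------------------------------
// Illustrative sketch of the structure each TorchCommNCCL collective above
// follows after the change: create the work object, launch the communication
// on a stream, register the work in the per-stream queue, and return the
// handle. All names here are hypothetical; only the createWork/enqueueWork
// split and the intrusive_ptr return type mirror the diff.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>
#include <chrono>

namespace demo_nccl {

using StreamHandle = void*; // stand-in for cudaStream_t

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
};

class DemoComm {
 public:
  c10::intrusive_ptr<DemoWork> all_reduce_like(bool async_op) {
    StreamHandle stream = getOperationStream(async_op);
    // 1. Create the work object up front (events not yet recorded).
    auto work = createWork(stream, std::chrono::milliseconds(60000));
    // 2. Launch the actual communication kernel on `stream` here.
    // 3. Only then hand the work to the tracking queue; the queue keeps its
    //    own reference, so the caller may drop theirs immediately.
    enqueueWork(work, stream);
    return work;
  }

 private:
  StreamHandle getOperationStream(bool /*async_op*/) {
    return nullptr;
  }
  c10::intrusive_ptr<DemoWork> createWork(
      StreamHandle /*stream*/,
      std::chrono::milliseconds /*timeout*/) {
    return c10::make_intrusive<DemoWork>();
  }
  void enqueueWork(
      c10::intrusive_ptr<DemoWork> /*work*/,
      StreamHandle /*stream*/) {}
};

} // namespace demo_nccl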
TorchCommNCCL::all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, @@ -1038,7 +1038,7 @@ std::shared_ptr TorchCommNCCL::all_to_all_v_single( return work; } -std::shared_ptr TorchCommNCCL::all_to_all( +c10::intrusive_ptr TorchCommNCCL::all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, @@ -1102,7 +1102,7 @@ std::shared_ptr TorchCommNCCL::all_to_all( return work; } -std::shared_ptr TorchCommNCCL::barrier( +c10::intrusive_ptr TorchCommNCCL::barrier( bool async_op, const BarrierOptions& options) { checkInitialized(); @@ -1139,7 +1139,7 @@ std::shared_ptr TorchCommNCCL::barrier( return work; } -std::shared_ptr TorchCommNCCL::scatter( +c10::intrusive_ptr TorchCommNCCL::scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, @@ -1229,7 +1229,7 @@ std::shared_ptr TorchCommNCCL::scatter( return work; } -std::shared_ptr TorchCommNCCL::gather( +c10::intrusive_ptr TorchCommNCCL::gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, diff --git a/comms/torchcomms/nccl/TorchCommNCCL.hpp b/comms/torchcomms/nccl/TorchCommNCCL.hpp index 8b8b5140..4e8d15bc 100644 --- a/comms/torchcomms/nccl/TorchCommNCCL.hpp +++ b/comms/torchcomms/nccl/TorchCommNCCL.hpp @@ -68,102 +68,102 @@ class TorchCommNCCL : public TorchCommBackend, std::string_view getCommName() const override; // Point-to-Point Operations - std::shared_ptr send( + c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}) override; // Batch P2P Operations - std::shared_ptr batch_op_issue( + c10::intrusive_ptr batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options = {}) override; // Collective Operations - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}) override; - std::shared_ptr all_reduce( + c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, const ReduceOptions& options = {}) override; - std::shared_ptr all_gather( + c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_v( + c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_single( + c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_v( + c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_single( + c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool 
async_op, const ReduceScatterSingleOptions& options = {}) override; - std::shared_ptr all_to_all_single( + c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}) override; - std::shared_ptr all_to_all_v_single( + c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}) override; - std::shared_ptr all_to_all( + c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( bool async_op, const BarrierOptions& options = {}) override; // Scatter and Gather Operations - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, @@ -236,7 +236,7 @@ class TorchCommNCCL : public TorchCommBackend, void register_address(const AddressWithLen& addr); void deregister_address(const Address& addr); ncclDataType_t getNcclDataType(const at::Tensor& tensor); - std::shared_ptr createWork( + c10::intrusive_ptr createWork( cudaStream_t stream, std::chrono::milliseconds timeout, const std::vector& inputTensors); @@ -300,7 +300,7 @@ class TorchCommNCCL : public TorchCommBackend, void checkInitialized() const; void checkAndAbortIfTimedOutOrError(); void checkWorkQueue(); - void enqueueWork(std::shared_ptr work, cudaStream_t stream); + void enqueueWork(c10::intrusive_ptr work, cudaStream_t stream); bool getGraphCaptureMode(); cudaStream_t getOperationStream(bool async_op); void ensureTensorContiguous(const at::Tensor& tensor); @@ -357,7 +357,7 @@ class TorchCommNCCL : public TorchCommBackend, // destruction, organized per graph using capture ID std::unordered_map< unsigned long long, - std::vector>> + std::vector>> graph_capture_work_refs_; std::mutex graph_capture_work_mutex_; diff --git a/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp b/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp index 36356127..a2e894cc 100644 --- a/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp +++ b/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp @@ -284,18 +284,18 @@ bool TorchCommNCCL::getGraphCaptureMode() { std::string(cuda_api_->getErrorString(err))); } -std::shared_ptr TorchCommNCCL::createWork( +c10::intrusive_ptr TorchCommNCCL::createWork( cudaStream_t stream, std::chrono::milliseconds timeout, const std::vector& inputTensors) { // Only create the work object without enqueuing it - auto work = std::make_shared( + auto work = c10::make_intrusive( shared_from_this(), stream, timeout, inputTensors, tracing_); return work; } void TorchCommNCCL::enqueueWork( - std::shared_ptr work, + c10::intrusive_ptr work, cudaStream_t stream) { // In graph capture mode, keep a reference to the work object to prevent // premature destruction until the graph gets destroyed, organized per graph diff --git a/comms/torchcomms/nccl/TorchWorkNCCL.hpp b/comms/torchcomms/nccl/TorchWorkNCCL.hpp index baf6af0f..65c5fdb3 100644 --- a/comms/torchcomms/nccl/TorchWorkNCCL.hpp +++ b/comms/torchcomms/nccl/TorchWorkNCCL.hpp @@ -101,12 +101,13 @@ class TorchWorkNCCLQueue { TorchWorkNCCL::WorkStatus 
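// ---------------------------------------------------------------------------
// Illustrative sketch of the graph_capture_work_refs_ bookkeeping above: while
// a CUDA graph is being captured, the backend parks an extra reference to each
// work object in a map keyed by capture id, so the works outlive the capture
// even if callers drop their handles. The container shape and the capture-id
// type follow the declaration in the diff; everything else is a stand-in.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace demo_graph {

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
};

class DemoCaptureRefs {
 public:
  // Called for every work created while capture `id` is active.
  void retainDuringCapture(
      unsigned long long id,
      c10::intrusive_ptr<DemoWork> work) {
    std::lock_guard<std::mutex> lock(mutex_);
    refs_[id].push_back(std::move(work));
  }

  // Called when the captured graph is destroyed; dropping the vector releases
  // the extra references and lets the works be garbage collected.
  void releaseCapture(unsigned long long id) {
    std::lock_guard<std::mutex> lock(mutex_);
    refs_.erase(id);
  }

 private:
  std::unordered_map<
      unsigned long long,
      std::vector<c10::intrusive_ptr<DemoWork>>>
      refs_;
  std::mutex mutex_;
};

} // namespace demo_graph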
garbageCollect(); // Finalize function can only be called from the main thread TorchWorkNCCL::WorkStatus finalize(); - void enqueueWork(std::shared_ptr work, cudaStream_t stream); + void enqueueWork(c10::intrusive_ptr work, cudaStream_t stream); private: TorchWorkNCCL::WorkStatus garbageCollectLocked(); - std::unordered_map>> - stream_work_queues_; + std:: + unordered_map>> + stream_work_queues_; std::mutex work_queues_mutex_; }; diff --git a/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp b/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp index 632bef19..60bde1c4 100644 --- a/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp +++ b/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp @@ -84,7 +84,7 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() { } void TorchWorkNCCLQueue::enqueueWork( - std::shared_ptr work, + c10::intrusive_ptr work, cudaStream_t stream) { // Add work to stream's queue after events have been recorded std::lock_guard lock(work_queues_mutex_); diff --git a/comms/torchcomms/ncclx/TorchCommNCCLX.cpp b/comms/torchcomms/ncclx/TorchCommNCCLX.cpp index bb13a5e0..e30b3588 100644 --- a/comms/torchcomms/ncclx/TorchCommNCCLX.cpp +++ b/comms/torchcomms/ncclx/TorchCommNCCLX.cpp @@ -349,7 +349,7 @@ static inline std::chrono::milliseconds getOperationTimeout( } // Point-to-Point Operations -std::shared_ptr TorchCommNCCLX::send( +c10::intrusive_ptr TorchCommNCCLX::send( const at::Tensor& tensor, int dst, bool async_op, @@ -389,7 +389,7 @@ std::shared_ptr TorchCommNCCLX::send( return work; } -std::shared_ptr TorchCommNCCLX::recv( +c10::intrusive_ptr TorchCommNCCLX::recv( at::Tensor& tensor, int src, bool async_op, @@ -430,7 +430,7 @@ std::shared_ptr TorchCommNCCLX::recv( } // Batch P2P Operations -std::shared_ptr TorchCommNCCLX::batch_op_issue( +c10::intrusive_ptr TorchCommNCCLX::batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options) { @@ -531,7 +531,7 @@ std::shared_ptr TorchCommNCCLX::batch_op_issue( } // Collective Operations -std::shared_ptr TorchCommNCCLX::broadcast( +c10::intrusive_ptr TorchCommNCCLX::broadcast( at::Tensor& tensor, int root, bool async_op, @@ -572,7 +572,7 @@ std::shared_ptr TorchCommNCCLX::broadcast( return work; } -std::shared_ptr TorchCommNCCLX::all_reduce( +c10::intrusive_ptr TorchCommNCCLX::all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, @@ -614,7 +614,7 @@ std::shared_ptr TorchCommNCCLX::all_reduce( return work; } -std::shared_ptr TorchCommNCCLX::reduce( +c10::intrusive_ptr TorchCommNCCLX::reduce( const at::Tensor& tensor, int root, ReduceOp op, @@ -662,7 +662,7 @@ std::shared_ptr TorchCommNCCLX::reduce( return work; } -std::shared_ptr TorchCommNCCLX::all_gather( +c10::intrusive_ptr TorchCommNCCLX::all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -719,7 +719,7 @@ std::shared_ptr TorchCommNCCLX::all_gather( return work; } -std::shared_ptr TorchCommNCCLX::all_gather_v( +c10::intrusive_ptr TorchCommNCCLX::all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -778,7 +778,7 @@ std::shared_ptr TorchCommNCCLX::all_gather_v( return work; } -std::shared_ptr TorchCommNCCLX::all_gather_single( +c10::intrusive_ptr TorchCommNCCLX::all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -822,7 +822,7 @@ std::shared_ptr TorchCommNCCLX::all_gather_single( return work; } -std::shared_ptr TorchCommNCCLX::reduce_scatter( +c10::intrusive_ptr TorchCommNCCLX::reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ 
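// ---------------------------------------------------------------------------
// Illustrative sketch of the per-stream queue updated above: enqueueWork
// appends the intrusive_ptr handle to its stream's queue under a mutex.
// StreamHandle stands in for cudaStream_t so the sketch builds without CUDA
// headers; the container shape matches the stream_work_queues_ declaration.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>
#include <mutex>
#include <queue>
#include <unordered_map>
#include <utility>

namespace demo_queue {

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
};

using StreamHandle = void*; // stand-in for cudaStream_t

class DemoWorkQueue {
 public:
  void enqueueWork(c10::intrusive_ptr<DemoWork> work, StreamHandle stream) {
    // Add work to the stream's queue after its events have been recorded.
    std::lock_guard<std::mutex> lock(work_queues_mutex_);
    stream_work_queues_[stream].push(std::move(work));
  }

 private:
  std::unordered_map<StreamHandle, std::queue<c10::intrusive_ptr<DemoWork>>>
      stream_work_queues_;
  std::mutex work_queues_mutex_;
};

} // namespace demo_queue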
-897,7 +897,7 @@ std::shared_ptr TorchCommNCCLX::reduce_scatter( return work; } -std::shared_ptr TorchCommNCCLX::reduce_scatter_v( +c10::intrusive_ptr TorchCommNCCLX::reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -977,7 +977,7 @@ std::shared_ptr TorchCommNCCLX::reduce_scatter_v( return work; } -std::shared_ptr TorchCommNCCLX::reduce_scatter_single( +c10::intrusive_ptr TorchCommNCCLX::reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, @@ -1026,7 +1026,7 @@ std::shared_ptr TorchCommNCCLX::reduce_scatter_single( return work; } -std::shared_ptr TorchCommNCCLX::all_to_all_single( +c10::intrusive_ptr TorchCommNCCLX::all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -1079,7 +1079,7 @@ std::shared_ptr TorchCommNCCLX::all_to_all_single( return work; } -std::shared_ptr TorchCommNCCLX::all_to_all_v_single( +c10::intrusive_ptr TorchCommNCCLX::all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, @@ -1158,7 +1158,7 @@ std::shared_ptr TorchCommNCCLX::all_to_all_v_single( return work; } -std::shared_ptr TorchCommNCCLX::all_to_all( +c10::intrusive_ptr TorchCommNCCLX::all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, @@ -1227,7 +1227,7 @@ std::shared_ptr TorchCommNCCLX::all_to_all( return work; } -std::shared_ptr TorchCommNCCLX::barrier( +c10::intrusive_ptr TorchCommNCCLX::barrier( bool async_op, const BarrierOptions& options) { checkInitialized(); @@ -1265,7 +1265,7 @@ std::shared_ptr TorchCommNCCLX::barrier( return work; } -std::shared_ptr TorchCommNCCLX::scatter( +c10::intrusive_ptr TorchCommNCCLX::scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, @@ -1355,7 +1355,7 @@ std::shared_ptr TorchCommNCCLX::scatter( return work; } -std::shared_ptr TorchCommNCCLX::gather( +c10::intrusive_ptr TorchCommNCCLX::gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, diff --git a/comms/torchcomms/ncclx/TorchCommNCCLX.hpp b/comms/torchcomms/ncclx/TorchCommNCCLX.hpp index 0e9f7a33..72847d60 100644 --- a/comms/torchcomms/ncclx/TorchCommNCCLX.hpp +++ b/comms/torchcomms/ncclx/TorchCommNCCLX.hpp @@ -70,102 +70,102 @@ class TorchCommNCCLX : public TorchCommBackend, std::string_view getCommName() const override; // Point-to-Point Operations - std::shared_ptr send( + c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}) override; // Batch P2P Operations - std::shared_ptr batch_op_issue( + c10::intrusive_ptr batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options = {}) override; // Collective Operations - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}) override; - std::shared_ptr all_reduce( + c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, const ReduceOptions& options = {}) override; - std::shared_ptr all_gather( + c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = 
{}) override; - std::shared_ptr all_gather_v( + c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_single( + c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_v( + c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_single( + c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options = {}) override; - std::shared_ptr all_to_all_single( + c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}) override; - std::shared_ptr all_to_all_v_single( + c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}) override; - std::shared_ptr all_to_all( + c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( bool async_op, const BarrierOptions& options = {}) override; // Scatter and Gather Operations - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, @@ -248,12 +248,12 @@ class TorchCommNCCLX : public TorchCommBackend, void deregister_address(const Address& addr); ncclDataType_t getNcclDataType(const at::Tensor& tensor); - std::shared_ptr createWork( + c10::intrusive_ptr createWork( cudaStream_t stream, std::chrono::milliseconds timeout, const std::vector& inputTensors = {}); - std::shared_ptr createWork( + c10::intrusive_ptr createWork( cudaStream_t stream, std::chrono::milliseconds timeout, const at::Tensor& inputTensor); @@ -318,7 +318,9 @@ class TorchCommNCCLX : public TorchCommBackend, void checkInitialized() const; void checkAndAbortIfTimedOutOrError(); void checkWorkQueue(); - void enqueueWork(std::shared_ptr work, cudaStream_t stream); + void enqueueWork( + c10::intrusive_ptr work, + cudaStream_t stream); bool getGraphCaptureMode(); cudaStream_t getOperationStream(bool async_op); void ensureTensorContiguous(const at::Tensor& tensor); @@ -379,7 +381,7 @@ class TorchCommNCCLX : public TorchCommBackend, // destruction, organized per graph using capture ID std::unordered_map< unsigned long long, - std::vector>> + std::vector>> graph_capture_work_refs_; std::mutex graph_capture_work_mutex_; diff --git a/comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp b/comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp index 0ebe5ea9..453cbb8d 100644 --- a/comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp +++ 
b/comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp @@ -273,28 +273,28 @@ bool TorchCommNCCLX::getGraphCaptureMode() { std::string(cuda_api_->getErrorString(err))); } -std::shared_ptr TorchCommNCCLX::createWork( +c10::intrusive_ptr TorchCommNCCLX::createWork( cudaStream_t stream, std::chrono::milliseconds timeout, const std::vector& inputTensors) { // Only create the work object without enqueuing it - auto work = std::make_shared( + auto work = c10::make_intrusive( shared_from_this(), stream, timeout, inputTensors); return work; } -std::shared_ptr TorchCommNCCLX::createWork( +c10::intrusive_ptr TorchCommNCCLX::createWork( cudaStream_t stream, std::chrono::milliseconds timeout, const at::Tensor& inputTensor) { // Only create the work object without enqueuing it - auto work = std::make_shared( + auto work = c10::make_intrusive( shared_from_this(), stream, timeout, inputTensor); return work; } void TorchCommNCCLX::enqueueWork( - std::shared_ptr work, + c10::intrusive_ptr work, cudaStream_t stream) { // In graph capture mode, keep a reference to the work object to prevent // premature destruction until the graph gets destroyed, organized per graph diff --git a/comms/torchcomms/ncclx/TorchCommWindowNCCLX.cpp b/comms/torchcomms/ncclx/TorchCommWindowNCCLX.cpp index d4c85292..64e94850 100644 --- a/comms/torchcomms/ncclx/TorchCommWindowNCCLX.cpp +++ b/comms/torchcomms/ncclx/TorchCommWindowNCCLX.cpp @@ -79,7 +79,7 @@ void TorchCommWindowNCCLX::allocate( << "[TorchCommWindowNCCLX]: NCCLX window allocation failed."; } -std::shared_ptr TorchCommWindowNCCLX::put( +c10::intrusive_ptr TorchCommWindowNCCLX::put( const at::Tensor& data, int dstRank, size_t targetDisp, @@ -150,7 +150,7 @@ at::Tensor TorchCommWindowNCCLX::getTensor( return t; } -std::shared_ptr TorchCommWindowNCCLX::signal( +c10::intrusive_ptr TorchCommWindowNCCLX::signal( size_t signalDisp, uint64_t signalVal, int dstRank, @@ -176,7 +176,7 @@ std::shared_ptr TorchCommWindowNCCLX::signal( return work; } -std::shared_ptr TorchCommWindowNCCLX::waitSignal( +c10::intrusive_ptr TorchCommWindowNCCLX::waitSignal( size_t signalDisp, uint64_t cmpVal, SignalCmpOp cmpOp, diff --git a/comms/torchcomms/ncclx/TorchCommWindowNCCLX.hpp b/comms/torchcomms/ncclx/TorchCommWindowNCCLX.hpp index c0f65484..fd642319 100644 --- a/comms/torchcomms/ncclx/TorchCommWindowNCCLX.hpp +++ b/comms/torchcomms/ncclx/TorchCommWindowNCCLX.hpp @@ -33,17 +33,17 @@ class TorchCommWindowNCCLX : public TorchCommWindow { TorchCommWindowNCCLX& operator=(TorchCommWindowNCCLX&& other) noexcept = delete; - std::shared_ptr put( + c10::intrusive_ptr put( const at::Tensor& data, int dstRank, size_t targetDisp, bool asyncOp) override; - std::shared_ptr signal( + c10::intrusive_ptr signal( size_t signalDisp, uint64_t signalVal, int dstRank, bool asyncOp) override; - std::shared_ptr waitSignal( + c10::intrusive_ptr waitSignal( size_t signalDisp, uint64_t cmpVal, SignalCmpOp cmpOp, diff --git a/comms/torchcomms/ncclx/TorchWorkNCCLX.hpp b/comms/torchcomms/ncclx/TorchWorkNCCLX.hpp index 1da092b1..5b5b0724 100644 --- a/comms/torchcomms/ncclx/TorchWorkNCCLX.hpp +++ b/comms/torchcomms/ncclx/TorchWorkNCCLX.hpp @@ -111,12 +111,16 @@ class TorchWorkNCCLXQueue { TorchWorkNCCLX::WorkStatus garbageCollect(); // Finalize function can only be called from the main thread TorchWorkNCCLX::WorkStatus finalize(); - void enqueueWork(std::shared_ptr work, cudaStream_t stream); + void enqueueWork( + c10::intrusive_ptr work, + cudaStream_t stream); private: TorchWorkNCCLX::WorkStatus garbageCollectLocked(); - 
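// ---------------------------------------------------------------------------
// Illustrative sketch of the createWork calls above: only the *work* objects
// moved to intrusive_ptr; the communicator still hands each work a
// std::shared_ptr to itself via enable_shared_from_this. The two reference
// counting schemes coexist because each type uses the handle its own ownership
// model requires. Class names are hypothetical.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>
#include <memory>
#include <utility>

namespace demo_mixed {

class DemoComm; // forward declaration

class DemoWork : public c10::intrusive_ptr_target {
 public:
  explicit DemoWork(std::shared_ptr<DemoComm> comm) : comm_(std::move(comm)) {}
  ~DemoWork() override = default;

 private:
  // Keeps the communicator alive for as long as the work is outstanding.
  std::shared_ptr<DemoComm> comm_;
};

class DemoComm : public std::enable_shared_from_this<DemoComm> {
 public:
  c10::intrusive_ptr<DemoWork> createWork() {
    // shared_from_this() is valid only when DemoComm is owned by a shared_ptr.
    return c10::make_intrusive<DemoWork>(shared_from_this());
  }
};

} // namespace demo_mixed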
std::unordered_map>> + std::unordered_map< + cudaStream_t, + std::queue>> stream_work_queues_; std::mutex work_queues_mutex_; diff --git a/comms/torchcomms/ncclx/TorchWorkNCCLXQueue.cpp b/comms/torchcomms/ncclx/TorchWorkNCCLXQueue.cpp index fde9e5a9..358fd5c0 100644 --- a/comms/torchcomms/ncclx/TorchWorkNCCLXQueue.cpp +++ b/comms/torchcomms/ncclx/TorchWorkNCCLXQueue.cpp @@ -85,7 +85,7 @@ TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::finalize() { } void TorchWorkNCCLXQueue::enqueueWork( - std::shared_ptr work, + c10::intrusive_ptr work, cudaStream_t stream) { // Add work to stream's queue after events have been recorded std::lock_guard lock(work_queues_mutex_); diff --git a/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.cpp b/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.cpp index 744bc74e..8fbab0f1 100644 --- a/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.cpp +++ b/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.cpp @@ -65,7 +65,7 @@ void ProfilerTest::sanityCheckProfilerMeta( } } -std::shared_ptr +c10::intrusive_ptr ProfilerTest::runAllCollectiveOperations() { auto options = at::TensorOptions().dtype(kTensorDtype).device(device_type_); diff --git a/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.hpp b/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.hpp index 008cf306..44fe892b 100644 --- a/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.hpp +++ b/comms/torchcomms/ncclx/tests/integration/cpp/ProfilerTest.hpp @@ -64,7 +64,7 @@ class ProfilerTest : public ::testing::Test { Json::Value& json_value, std::map>& events); - std::shared_ptr runAllCollectiveOperations(); + c10::intrusive_ptr runAllCollectiveOperations(); protected: void SetUp() override; diff --git a/comms/torchcomms/ncclx/tests/unit/cpp/TorchWorkNCCLXQueueTest.cpp b/comms/torchcomms/ncclx/tests/unit/cpp/TorchWorkNCCLXQueueTest.cpp index 563782bc..a3dad5d1 100644 --- a/comms/torchcomms/ncclx/tests/unit/cpp/TorchWorkNCCLXQueueTest.cpp +++ b/comms/torchcomms/ncclx/tests/unit/cpp/TorchWorkNCCLXQueueTest.cpp @@ -252,7 +252,7 @@ TEST_F(TorchWorkNCCLXQueueTest, EnqueueNullWorkDoesNotCrash) { // Test that enqueueing null work doesn't crash during enqueue // Note: We don't call garbageCollect after this because that would // cause a segfault when trying to call checkStatus() on null work - std::shared_ptr null_work = nullptr; + c10::intrusive_ptr null_work = nullptr; // This should not crash the queue during enqueue EXPECT_NO_THROW(queue_->enqueueWork(null_work, stream1_)); diff --git a/comms/torchcomms/rccl/TorchCommRCCL.cpp b/comms/torchcomms/rccl/TorchCommRCCL.cpp index 5e5dfdf9..9b1cec51 100644 --- a/comms/torchcomms/rccl/TorchCommRCCL.cpp +++ b/comms/torchcomms/rccl/TorchCommRCCL.cpp @@ -241,7 +241,7 @@ void TorchCommRCCL::finalize() { } // Clear the completed works queue - std::queue> empty; + std::queue> empty; std::swap(completed_works_, empty); // Clean up event pool @@ -333,7 +333,7 @@ static inline std::chrono::milliseconds getOperationTimeout( } // Point-to-Point Operations -std::shared_ptr TorchCommRCCL::send( +c10::intrusive_ptr TorchCommRCCL::send( const at::Tensor& tensor, int dst, bool async_op, @@ -372,7 +372,7 @@ std::shared_ptr TorchCommRCCL::send( return work; } -std::shared_ptr TorchCommRCCL::recv( +c10::intrusive_ptr TorchCommRCCL::recv( at::Tensor& tensor, int src, bool async_op, @@ -412,7 +412,7 @@ std::shared_ptr TorchCommRCCL::recv( } // Batch P2P Operations -std::shared_ptr TorchCommRCCL::batch_op_issue( +c10::intrusive_ptr 
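// ---------------------------------------------------------------------------
// Illustrative sketch related to the EnqueueNullWorkDoesNotCrash test above: a
// default-constructed or nullptr-assigned c10::intrusive_ptr is empty, so code
// that dereferences a work handle unconditionally (e.g. calling checkStatus())
// must still guard against it. DemoWork is a stand-in type.
// ---------------------------------------------------------------------------
#include <c10/util/intrusive_ptr.h>

namespace demo_null {

class DemoWork : public c10::intrusive_ptr_target {
 public:
  ~DemoWork() override = default;
  bool isCompleted() const { return true; }
};

inline bool safeIsCompleted(const c10::intrusive_ptr<DemoWork>& work) {
  // Both forms below leave the handle empty; operator bool distinguishes that
  // from a live handle before any dereference:
  //   c10::intrusive_ptr<DemoWork> a;            // empty
  //   c10::intrusive_ptr<DemoWork> b = nullptr;  // also empty
  if (!work) {
    return false;
  }
  return work->isCompleted();
}

} // namespace demo_null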
TorchCommRCCL::batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options) { @@ -508,7 +508,7 @@ std::shared_ptr TorchCommRCCL::batch_op_issue( } // Collective Operations -std::shared_ptr TorchCommRCCL::broadcast( +c10::intrusive_ptr TorchCommRCCL::broadcast( at::Tensor& tensor, int root, bool async_op, @@ -548,7 +548,7 @@ std::shared_ptr TorchCommRCCL::broadcast( return work; } -std::shared_ptr TorchCommRCCL::all_reduce( +c10::intrusive_ptr TorchCommRCCL::all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, @@ -589,7 +589,7 @@ std::shared_ptr TorchCommRCCL::all_reduce( return work; } -std::shared_ptr TorchCommRCCL::reduce( +c10::intrusive_ptr TorchCommRCCL::reduce( const at::Tensor& tensor, int root, ReduceOp op, @@ -636,7 +636,7 @@ std::shared_ptr TorchCommRCCL::reduce( return work; } -std::shared_ptr TorchCommRCCL::all_gather( +c10::intrusive_ptr TorchCommRCCL::all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -693,7 +693,7 @@ std::shared_ptr TorchCommRCCL::all_gather( return work; } -std::shared_ptr TorchCommRCCL::all_gather_v( +c10::intrusive_ptr TorchCommRCCL::all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, @@ -701,7 +701,7 @@ std::shared_ptr TorchCommRCCL::all_gather_v( throw std::runtime_error("all_gather_v not implemented"); } -std::shared_ptr TorchCommRCCL::all_gather_single( +c10::intrusive_ptr TorchCommRCCL::all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -745,7 +745,7 @@ std::shared_ptr TorchCommRCCL::all_gather_single( return work; } -std::shared_ptr TorchCommRCCL::reduce_scatter( +c10::intrusive_ptr TorchCommRCCL::reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -821,7 +821,7 @@ std::shared_ptr TorchCommRCCL::reduce_scatter( return work; } -std::shared_ptr TorchCommRCCL::reduce_scatter_v( +c10::intrusive_ptr TorchCommRCCL::reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, @@ -830,7 +830,7 @@ std::shared_ptr TorchCommRCCL::reduce_scatter_v( throw std::runtime_error("reduce_scatter_v not implemented"); } -std::shared_ptr TorchCommRCCL::reduce_scatter_single( +c10::intrusive_ptr TorchCommRCCL::reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, @@ -879,7 +879,7 @@ std::shared_ptr TorchCommRCCL::reduce_scatter_single( return work; } -std::shared_ptr TorchCommRCCL::all_to_all_single( +c10::intrusive_ptr TorchCommRCCL::all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, @@ -932,7 +932,7 @@ std::shared_ptr TorchCommRCCL::all_to_all_single( return work; } -std::shared_ptr TorchCommRCCL::all_to_all_v_single( +c10::intrusive_ptr TorchCommRCCL::all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, @@ -1023,7 +1023,7 @@ std::shared_ptr TorchCommRCCL::all_to_all_v_single( return work; } -std::shared_ptr TorchCommRCCL::all_to_all( +c10::intrusive_ptr TorchCommRCCL::all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, @@ -1087,7 +1087,7 @@ std::shared_ptr TorchCommRCCL::all_to_all( return work; } -std::shared_ptr TorchCommRCCL::barrier( +c10::intrusive_ptr TorchCommRCCL::barrier( bool async_op, const BarrierOptions& options) { checkInitialized(); @@ -1124,7 +1124,7 @@ std::shared_ptr TorchCommRCCL::barrier( return work; } -std::shared_ptr TorchCommRCCL::scatter( +c10::intrusive_ptr TorchCommRCCL::scatter( at::Tensor& 
output_tensor, const std::vector& input_tensor_list, int root, @@ -1214,7 +1214,7 @@ std::shared_ptr TorchCommRCCL::scatter( return work; } -std::shared_ptr TorchCommRCCL::gather( +c10::intrusive_ptr TorchCommRCCL::gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, diff --git a/comms/torchcomms/rccl/TorchCommRCCL.hpp b/comms/torchcomms/rccl/TorchCommRCCL.hpp index 2a9707e4..72024295 100644 --- a/comms/torchcomms/rccl/TorchCommRCCL.hpp +++ b/comms/torchcomms/rccl/TorchCommRCCL.hpp @@ -64,102 +64,102 @@ class TorchCommRCCL : public TorchCommBackend, int getSize() const override; // Point-to-Point Operations - std::shared_ptr send( + c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}) override; // Batch P2P Operations - std::shared_ptr batch_op_issue( + c10::intrusive_ptr batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options = {}) override; // Collective Operations - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}) override; - std::shared_ptr all_reduce( + c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, const ReduceOptions& options = {}) override; - std::shared_ptr all_gather( + c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_v( + c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_single( + c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_v( + c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_single( + c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options = {}) override; - std::shared_ptr all_to_all_single( + c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}) override; - std::shared_ptr all_to_all_v_single( + c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}) override; - std::shared_ptr all_to_all( + c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( bool async_op, const 
BarrierOptions& options = {}) override; // Scatter and Gather Operations - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, @@ -288,11 +288,11 @@ class TorchCommRCCL : public TorchCommBackend, void checkInitialized() const; void checkAndAbortIfTimedOutOrError(); void garbageCollectWorkQueues(); - std::shared_ptr createWork( + c10::intrusive_ptr createWork( hipStream_t stream, std::chrono::milliseconds timeout, const std::vector& inputTensors); - void enqueueWork(std::shared_ptr work, hipStream_t stream); + void enqueueWork(c10::intrusive_ptr work, hipStream_t stream); hipStream_t getOperationStream(bool async_op); void ensureTensorContiguous(const at::Tensor& tensor); @@ -330,9 +330,9 @@ class TorchCommRCCL : public TorchCommBackend, std::mutex event_pool_mutex_; // Work tracking per stream - std::unordered_map>> + std::unordered_map>> stream_work_queues_; - std::queue> completed_works_; + std::queue> completed_works_; std::mutex work_queues_mutex_; // Timeout monitoring diff --git a/comms/torchcomms/rccl/TorchCommRCCLUtils.cpp b/comms/torchcomms/rccl/TorchCommRCCLUtils.cpp index 74827462..5b724a28 100644 --- a/comms/torchcomms/rccl/TorchCommRCCLUtils.cpp +++ b/comms/torchcomms/rccl/TorchCommRCCLUtils.cpp @@ -245,7 +245,7 @@ void TorchCommRCCL::checkAndAbortIfTimedOutOrError() { // Create an empty queue and swap with the completed_works_ queue // This is more efficient than calling clear() as it deallocates memory - std::queue> empty; + std::queue> empty; std::swap(completed_works_, empty); // The old queue will be destroyed when this scope exits } @@ -262,18 +262,18 @@ void TorchCommRCCL::checkAndAbortIfTimedOutOrError() { } } -std::shared_ptr TorchCommRCCL::createWork( +c10::intrusive_ptr TorchCommRCCL::createWork( hipStream_t stream, std::chrono::milliseconds timeout, const std::vector& inputTensors) { // Only create the work object without enqueuing it - auto work = std::make_shared( + auto work = c10::make_intrusive( shared_from_this(), stream, timeout, inputTensors, tracing_); return work; } void TorchCommRCCL::enqueueWork( - std::shared_ptr work, + c10::intrusive_ptr work, hipStream_t stream) { // Add work to stream's queue after events have been recorded std::lock_guard lock(work_queues_mutex_); diff --git a/comms/torchcomms/tests/integration/cpp/BroadcastTest.cpp b/comms/torchcomms/tests/integration/cpp/BroadcastTest.cpp index 48b23a63..8d76ec15 100644 --- a/comms/torchcomms/tests/integration/cpp/BroadcastTest.cpp +++ b/comms/torchcomms/tests/integration/cpp/BroadcastTest.cpp @@ -130,7 +130,7 @@ void BroadcastTest::testBroadcastInputDeleted(int count, at::ScalarType dtype) { const int root_value = 42; // Create work object to hold the async operation - std::shared_ptr work; + c10::intrusive_ptr work; { // Create tensor in a limited scope diff --git a/comms/torchcomms/tests/integration/cpp/ReduceTest.cpp b/comms/torchcomms/tests/integration/cpp/ReduceTest.cpp index fd5e8cc5..5d23a485 100644 --- a/comms/torchcomms/tests/integration/cpp/ReduceTest.cpp +++ b/comms/torchcomms/tests/integration/cpp/ReduceTest.cpp @@ -141,7 +141,7 @@ void ReduceTest::testReduceInputDeleted( const int root_rank = 0; // Create work object to hold the async operation - std::shared_ptr work; + c10::intrusive_ptr work; { // Create input 
tensor in a limited scope diff --git a/comms/torchcomms/tests/integration/cpp/SendRecvTest.cpp b/comms/torchcomms/tests/integration/cpp/SendRecvTest.cpp index 7c68af42..707f6c39 100644 --- a/comms/torchcomms/tests/integration/cpp/SendRecvTest.cpp +++ b/comms/torchcomms/tests/integration/cpp/SendRecvTest.cpp @@ -35,8 +35,8 @@ void SendRecvTest::testSyncSendRecv(int count, at::ScalarType dtype) { // Alternate send/recv order based on rank to avoid deadlock // Even ranks send first, then receive // Odd ranks receive first, then send - std::shared_ptr send_work; - std::shared_ptr recv_work; + c10::intrusive_ptr send_work; + c10::intrusive_ptr recv_work; if (rank_ % 2 == 0) { // Even ranks: send first, then receive @@ -90,8 +90,8 @@ void SendRecvTest::testAsyncSendRecv(int count, at::ScalarType dtype) { // Alternate send/recv order based on rank to avoid deadlock // Even ranks send first, then receive // Odd ranks receive first, then send - std::shared_ptr send_work; - std::shared_ptr recv_work; + c10::intrusive_ptr send_work; + c10::intrusive_ptr recv_work; if (rank_ % 2 == 0) { // Even ranks: send first, then receive @@ -126,8 +126,8 @@ void SendRecvTest::testAsyncSendRecvEarlyReset( // Alternate send/recv order based on rank to avoid deadlock // Even ranks send first, then receive // Odd ranks receive first, then send - std::shared_ptr send_work; - std::shared_ptr recv_work; + c10::intrusive_ptr send_work; + c10::intrusive_ptr recv_work; if (rank_ % 2 == 0) { // Even ranks: send first, then receive @@ -165,8 +165,8 @@ void SendRecvTest::testSendRecvInputDeleted(int count, at::ScalarType dtype) { int recv_rank = (rank_ + num_ranks_ - 1) % num_ranks_; // Create work objects to hold the async operations - std::shared_ptr send_work; - std::shared_ptr recv_work; + c10::intrusive_ptr send_work; + c10::intrusive_ptr recv_work; { // Create send tensor in a limited scope diff --git a/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.cpp b/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.cpp index 2e39b4af..9c273dd3 100644 --- a/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.cpp +++ b/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.cpp @@ -26,7 +26,7 @@ class DummyTorchCommWindow : public TorchCommWindow { (void)signal_size; win_size_ = window_size; } - std::shared_ptr put( + c10::intrusive_ptr put( const at::Tensor& data, int dstRank, size_t targetDisp, @@ -35,7 +35,7 @@ class DummyTorchCommWindow : public TorchCommWindow { (void)dstRank; (void)targetDisp; (void)asyncOp; - return std::make_shared(); + return c10::make_intrusive(); } at::Tensor getTensor( int rank, @@ -48,7 +48,7 @@ class DummyTorchCommWindow : public TorchCommWindow { (void)storageOffset; return at::Tensor(); } - std::shared_ptr signal( + c10::intrusive_ptr signal( size_t signalDisp, uint64_t signalVal, int dstRank, @@ -57,9 +57,9 @@ class DummyTorchCommWindow : public TorchCommWindow { (void)signalVal; (void)dstRank; (void)asyncOp; - return std::make_shared(); + return c10::make_intrusive(); } - virtual std::shared_ptr waitSignal( + virtual c10::intrusive_ptr waitSignal( size_t signalDisp, uint64_t cmpVal, SignalCmpOp cmpOp, @@ -68,7 +68,7 @@ class DummyTorchCommWindow : public TorchCommWindow { (void)cmpVal; (void)cmpOp; (void)asyncOp; - return std::make_shared(); + return c10::make_intrusive(); } }; @@ -105,153 +105,153 @@ std::string_view DummyTorchCommBackend::getBackendName() const { return kBackendName; } -std::shared_ptr DummyTorchCommBackend::send( +c10::intrusive_ptr DummyTorchCommBackend::send( const 
at::Tensor& tensor, int dst, bool async_op, const SendOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::recv( +c10::intrusive_ptr DummyTorchCommBackend::recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::batch_op_issue( +c10::intrusive_ptr DummyTorchCommBackend::batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::broadcast( +c10::intrusive_ptr DummyTorchCommBackend::broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_reduce( +c10::intrusive_ptr DummyTorchCommBackend::all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::reduce( +c10::intrusive_ptr DummyTorchCommBackend::reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, const ReduceOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_gather( +c10::intrusive_ptr DummyTorchCommBackend::all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_gather_v( +c10::intrusive_ptr DummyTorchCommBackend::all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_gather_single( +c10::intrusive_ptr DummyTorchCommBackend::all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::reduce_scatter( +c10::intrusive_ptr DummyTorchCommBackend::reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::reduce_scatter_v( +c10::intrusive_ptr DummyTorchCommBackend::reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::reduce_scatter_single( +c10::intrusive_ptr DummyTorchCommBackend::reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_to_all_single( +c10::intrusive_ptr DummyTorchCommBackend::all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_to_all_v_single( +c10::intrusive_ptr DummyTorchCommBackend::all_to_all_v_single( at::Tensor& 
output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::all_to_all( +c10::intrusive_ptr DummyTorchCommBackend::all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::barrier( +c10::intrusive_ptr DummyTorchCommBackend::barrier( bool async_op, const BarrierOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::scatter( +c10::intrusive_ptr DummyTorchCommBackend::scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr DummyTorchCommBackend::gather( +c10::intrusive_ptr DummyTorchCommBackend::gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root, bool async_op, const GatherOptions& options) { - return std::make_shared(); + return c10::make_intrusive(); } std::shared_ptr DummyTorchCommBackend::window_allocate( diff --git a/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.hpp b/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.hpp index 44f91cd5..ff3db323 100644 --- a/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.hpp +++ b/comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.hpp @@ -28,101 +28,101 @@ class DummyTorchCommBackend : public TorchCommBackend { std::string_view getBackendName() const override; // Point-to-Point Operations - std::shared_ptr send( + c10::intrusive_ptr send( const at::Tensor& tensor, int dst, bool async_op, const SendOptions& options = {}) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( at::Tensor& tensor, int src, bool async_op, const RecvOptions& options = {}) override; - std::shared_ptr batch_op_issue( + c10::intrusive_ptr batch_op_issue( const std::vector& ops, bool async_op, const BatchP2POptions& options = {}) override; // Collective Operations - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( at::Tensor& tensor, int root, bool async_op, const BroadcastOptions& options = {}) override; - std::shared_ptr all_reduce( + c10::intrusive_ptr all_reduce( at::Tensor& tensor, ReduceOp op, bool async_op, const AllReduceOptions& options = {}) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( const at::Tensor& tensor, int root, ReduceOp op, bool async_op, const ReduceOptions& options = {}) override; - std::shared_ptr all_gather( + c10::intrusive_ptr all_gather( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_v( + c10::intrusive_ptr all_gather_v( const std::vector& tensor_list, const at::Tensor& tensor, bool async_op, const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_single( + c10::intrusive_ptr all_gather_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllGatherSingleOptions& options = {}) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr 
reduce_scatter_v( + c10::intrusive_ptr reduce_scatter_v( at::Tensor& output, const std::vector& input_list, ReduceOp op, bool async_op, const ReduceScatterOptions& options = {}) override; - std::shared_ptr reduce_scatter_single( + c10::intrusive_ptr reduce_scatter_single( at::Tensor& output, const at::Tensor& input, ReduceOp op, bool async_op, const ReduceScatterSingleOptions& options = {}) override; - std::shared_ptr all_to_all_single( + c10::intrusive_ptr all_to_all_single( at::Tensor& output, const at::Tensor& input, bool async_op, const AllToAllSingleOptions& options = {}) override; - std::shared_ptr all_to_all_v_single( + c10::intrusive_ptr all_to_all_v_single( at::Tensor& output, const at::Tensor& input, const std::vector& output_split_sizes, const std::vector& input_split_sizes, bool async_op, const AllToAllvSingleOptions& options = {}) override; - std::shared_ptr all_to_all( + c10::intrusive_ptr all_to_all( const std::vector& output_tensor_list, const std::vector& input_tensor_list, bool async_op, const AllToAllOptions& options = {}) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( bool async_op, const BarrierOptions& options = {}) override; // Scatter and Gather Operations - std::shared_ptr scatter( + c10::intrusive_ptr scatter( at::Tensor& output_tensor, const std::vector& input_tensor_list, int root, bool async_op, const ScatterOptions& options = {}) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( const std::vector& output_tensor_list, const at::Tensor& input_tensor, int root,
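
Note (not part of the patch): every hunk above applies the same ownership change, replacing std::shared_ptr work handles with c10::intrusive_ptr ones created via c10::make_intrusive. The sketch below is a minimal, hedged illustration of that pattern and of the swap-with-empty-queue idiom used in checkAndAbortIfTimedOutOrError(); "MyWork" is a stand-in name, not the real torchcomms work class, and only the standard c10 intrusive_ptr API (intrusive_ptr_target, make_intrusive) is assumed.

// Illustrative sketch only; "MyWork" is a placeholder for the work type
// whose handles this diff migrates to c10::intrusive_ptr.
#include <c10/util/intrusive_ptr.h>
#include <queue>
#include <utility>

// Intrusive ref counting: the count lives inside the object itself, so a
// c10::intrusive_ptr carries no separate control block (unlike std::shared_ptr).
struct MyWork : c10::intrusive_ptr_target {
  bool completed = false;
  ~MyWork() override = default;
};

int main() {
  // std::make_shared<MyWork>()  becomes  c10::make_intrusive<MyWork>()
  c10::intrusive_ptr<MyWork> work = c10::make_intrusive<MyWork>();

  // Handles copy and dereference like shared_ptr, so containers of work
  // objects (e.g. per-stream queues) keep the same shape after the migration.
  std::queue<c10::intrusive_ptr<MyWork>> completed_works;
  completed_works.push(work);

  // Swap-with-empty idiom from checkAndAbortIfTimedOutOrError(): swapping in a
  // fresh queue releases the old queue's storage rather than just clearing it.
  std::queue<c10::intrusive_ptr<MyWork>> empty;
  std::swap(completed_works, empty);

  return work->completed ? 1 : 0;
}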