Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comms/torchcomms/BackendWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ std::vector<uint64_t> toVecUint64(const std::vector<int64_t>& vec) {

} // namespace

WorkWrapper::WorkWrapper(std::shared_ptr<TorchWork> work)
WorkWrapper::WorkWrapper(c10::intrusive_ptr<TorchWork> work)
: work_(std::move(work)) {}

bool WorkWrapper::isCompleted() {
Expand Down
4 changes: 2 additions & 2 deletions comms/torchcomms/BackendWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace comms {

class WorkWrapper : public c10d::Work {
public:
explicit WorkWrapper(std::shared_ptr<TorchWork> work);
explicit WorkWrapper(c10::intrusive_ptr<TorchWork> work);
~WorkWrapper() override = default;

bool isCompleted() override;
Expand All @@ -23,7 +23,7 @@ class WorkWrapper : public c10d::Work {
std::vector<at::Tensor> result() override;

private:
std::shared_ptr<TorchWork> work_;
c10::intrusive_ptr<TorchWork> work_;
};

using c10d::kUnsetTimeout;
Expand Down
36 changes: 18 additions & 18 deletions comms/torchcomms/TorchComm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ std::string_view TorchComm::getCommName() const {
}

// Point-to-Point Operations
std::shared_ptr<TorchWork> TorchComm::send(
c10::intrusive_ptr<TorchWork> TorchComm::send(
const at::Tensor& tensor,
int dst,
bool async_op,
const SendOptions& options) {
return impl_->send(tensor, dst, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::recv(
c10::intrusive_ptr<TorchWork> TorchComm::recv(
at::Tensor& tensor,
int src,
bool async_op,
Expand All @@ -56,23 +56,23 @@ std::shared_ptr<TorchWork> TorchComm::recv(
}

// Collective Operations
std::shared_ptr<TorchWork> TorchComm::broadcast(
c10::intrusive_ptr<TorchWork> TorchComm::broadcast(
at::Tensor& tensor,
int root,
bool async_op,
const BroadcastOptions& options) {
return impl_->broadcast(tensor, root, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_reduce(
c10::intrusive_ptr<TorchWork> TorchComm::all_reduce(
at::Tensor& tensor,
ReduceOp op,
bool async_op,
const AllReduceOptions& options) {
return impl_->all_reduce(tensor, op, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::reduce(
c10::intrusive_ptr<TorchWork> TorchComm::reduce(
const at::Tensor& tensor,
int root,
ReduceOp op,
Expand All @@ -81,31 +81,31 @@ std::shared_ptr<TorchWork> TorchComm::reduce(
return impl_->reduce(tensor, root, op, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_gather(
c10::intrusive_ptr<TorchWork> TorchComm::all_gather(
const std::vector<at::Tensor>& tensor_list,
const at::Tensor& tensor,
bool async_op,
const AllGatherOptions& options) {
return impl_->all_gather(tensor_list, tensor, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_gather_v(
c10::intrusive_ptr<TorchWork> TorchComm::all_gather_v(
const std::vector<at::Tensor>& tensor_list,
const at::Tensor& tensor,
bool async_op,
const AllGatherOptions& options) {
return impl_->all_gather_v(tensor_list, tensor, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_gather_single(
c10::intrusive_ptr<TorchWork> TorchComm::all_gather_single(
at::Tensor& output,
const at::Tensor& input,
bool async_op,
const AllGatherSingleOptions& options) {
return impl_->all_gather_single(output, input, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter(
at::Tensor& output,
const std::vector<at::Tensor>& input_list,
ReduceOp op,
Expand All @@ -114,7 +114,7 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
return impl_->reduce_scatter(output, input_list, op, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter_v(
at::Tensor& output,
const std::vector<at::Tensor>& input_list,
ReduceOp op,
Expand All @@ -123,7 +123,7 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
return impl_->reduce_scatter_v(output, input_list, op, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
c10::intrusive_ptr<TorchWork> TorchComm::reduce_scatter_single(
at::Tensor& output,
const at::Tensor& input,
ReduceOp op,
Expand All @@ -132,15 +132,15 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
return impl_->reduce_scatter_single(output, input, op, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_to_all_single(
c10::intrusive_ptr<TorchWork> TorchComm::all_to_all_single(
at::Tensor& output,
const at::Tensor& input,
bool async_op,
const AllToAllSingleOptions& options) {
return impl_->all_to_all_single(output, input, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_to_all_v_single(
c10::intrusive_ptr<TorchWork> TorchComm::all_to_all_v_single(
at::Tensor& output,
const at::Tensor& input,
const std::vector<uint64_t>& output_split_sizes,
Expand All @@ -151,7 +151,7 @@ std::shared_ptr<TorchWork> TorchComm::all_to_all_v_single(
output, input, output_split_sizes, input_split_sizes, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::all_to_all(
c10::intrusive_ptr<TorchWork> TorchComm::all_to_all(
const std::vector<at::Tensor>& output_tensor_list,
const std::vector<at::Tensor>& input_tensor_list,
bool async_op,
Expand All @@ -160,14 +160,14 @@ std::shared_ptr<TorchWork> TorchComm::all_to_all(
output_tensor_list, input_tensor_list, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::barrier(
c10::intrusive_ptr<TorchWork> TorchComm::barrier(
bool async_op,
const BarrierOptions& options) {
return impl_->barrier(async_op, options);
}

// Scatter and Gather Operations
std::shared_ptr<TorchWork> TorchComm::scatter(
c10::intrusive_ptr<TorchWork> TorchComm::scatter(
at::Tensor& output_tensor,
const std::vector<at::Tensor>& input_tensor_list,
int root,
Expand All @@ -177,7 +177,7 @@ std::shared_ptr<TorchWork> TorchComm::scatter(
output_tensor, input_tensor_list, root, async_op, options);
}

std::shared_ptr<TorchWork> TorchComm::gather(
c10::intrusive_ptr<TorchWork> TorchComm::gather(
const std::vector<at::Tensor>& output_tensor_list,
const at::Tensor& input_tensor,
int root,
Expand Down Expand Up @@ -239,7 +239,7 @@ void BatchSendRecv::recv(at::Tensor& tensor, int src) {
ops.push_back(op);
}

std::shared_ptr<TorchWork> BatchSendRecv::issue(
c10::intrusive_ptr<TorchWork> BatchSendRecv::issue(
bool async_op,
const BatchP2POptions& options) {
return parent_->getBackendImpl()->batch_op_issue(ops, async_op, options);
Expand Down
34 changes: 17 additions & 17 deletions comms/torchcomms/TorchComm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,96 +32,96 @@ class TorchComm {
std::string_view getCommName() const;

// Point-to-Point Operations
std::shared_ptr<TorchWork> send(
c10::intrusive_ptr<TorchWork> send(
const at::Tensor& tensor,
int dst,
bool async_op,
const SendOptions& options = {});
std::shared_ptr<TorchWork> recv(
c10::intrusive_ptr<TorchWork> recv(
at::Tensor& tensor,
int src,
bool async_op,
const RecvOptions& options = {});

// Collective Operations
std::shared_ptr<TorchWork> broadcast(
c10::intrusive_ptr<TorchWork> broadcast(
at::Tensor& tensor,
int root,
bool async_op,
const BroadcastOptions& options = {});
std::shared_ptr<TorchWork> all_reduce(
c10::intrusive_ptr<TorchWork> all_reduce(
at::Tensor& tensor,
ReduceOp op,
bool async_op,
const AllReduceOptions& options = {});
std::shared_ptr<TorchWork> reduce(
c10::intrusive_ptr<TorchWork> reduce(
const at::Tensor& tensor,
int root,
ReduceOp op,
bool async_op,
const ReduceOptions& options = {});
std::shared_ptr<TorchWork> all_gather(
c10::intrusive_ptr<TorchWork> all_gather(
const std::vector<at::Tensor>& tensor_list,
const at::Tensor& tensor,
bool async_op,
const AllGatherOptions& options = {});
std::shared_ptr<TorchWork> all_gather_v(
c10::intrusive_ptr<TorchWork> all_gather_v(
const std::vector<at::Tensor>& tensor_list,
const at::Tensor& tensor,
bool async_op,
const AllGatherOptions& options = {});
std::shared_ptr<TorchWork> all_gather_single(
c10::intrusive_ptr<TorchWork> all_gather_single(
at::Tensor& output,
const at::Tensor& input,
bool async_op,
const AllGatherSingleOptions& options = {});
std::shared_ptr<TorchWork> reduce_scatter(
c10::intrusive_ptr<TorchWork> reduce_scatter(
at::Tensor& output,
const std::vector<at::Tensor>& input_list,
ReduceOp op,
bool async_op,
const ReduceScatterOptions& options = {});
std::shared_ptr<TorchWork> reduce_scatter_v(
c10::intrusive_ptr<TorchWork> reduce_scatter_v(
at::Tensor& output,
const std::vector<at::Tensor>& input_list,
ReduceOp op,
bool async_op,
const ReduceScatterOptions& options = {});
std::shared_ptr<TorchWork> reduce_scatter_single(
c10::intrusive_ptr<TorchWork> reduce_scatter_single(
at::Tensor& output,
const at::Tensor& input,
ReduceOp op,
bool async_op,
const ReduceScatterSingleOptions& options = {});
std::shared_ptr<TorchWork> all_to_all_single(
c10::intrusive_ptr<TorchWork> all_to_all_single(
at::Tensor& output,
const at::Tensor& input,
bool async_op,
const AllToAllSingleOptions& options = {});
std::shared_ptr<TorchWork> all_to_all_v_single(
c10::intrusive_ptr<TorchWork> all_to_all_v_single(
at::Tensor& output,
const at::Tensor& input,
const std::vector<uint64_t>& output_split_sizes,
const std::vector<uint64_t>& input_split_sizes,
bool async_op,
const AllToAllvSingleOptions& options = {});
std::shared_ptr<TorchWork> all_to_all(
c10::intrusive_ptr<TorchWork> all_to_all(
const std::vector<at::Tensor>& output_tensor_list,
const std::vector<at::Tensor>& input_tensor_list,
bool async_op,
const AllToAllOptions& options = {});
std::shared_ptr<TorchWork> barrier(
c10::intrusive_ptr<TorchWork> barrier(
bool async_op,
const BarrierOptions& options = {});

// Scatter and Gather Operations
std::shared_ptr<TorchWork> scatter(
c10::intrusive_ptr<TorchWork> scatter(
at::Tensor& output_tensor,
const std::vector<at::Tensor>& input_tensor_list,
int root,
bool async_op,
const ScatterOptions& options = {});
std::shared_ptr<TorchWork> gather(
c10::intrusive_ptr<TorchWork> gather(
const std::vector<at::Tensor>& output_tensor_list,
const at::Tensor& input_tensor,
int root,
Expand Down
Loading
Loading