Skip to content

Commit a3b2b0a

Browse files
committed
refine code
1 parent 69c22f9 commit a3b2b0a

File tree

2 files changed

+20
-45
lines changed

2 files changed

+20
-45
lines changed

torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,6 @@ void ProcessGroupXCCL::groupEnd() {
416416
--xcclActiveGroupCounter_;
417417
}
418418

419-
// TODO: wait p2p enable
420419
static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
421420
void ProcessGroupXCCL::startCoalescing() {
422421
coalescedDevice_.set_index(-1);
@@ -525,14 +524,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
525524
return work;
526525
}
527526

528-
template <typename Fn, typename PreProcess, typename PostProcess>
527+
template <typename Fn>
529528
c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
530529
at::Tensor& tensor,
531530
Fn fn,
532531
int peer,
533-
OpType opType,
534-
PreProcess pre,
535-
PostProcess post) {
532+
OpType opType) {
536533
using traits = function_traits<Fn>;
537534
using attr_t = typename traits::template arg<1>::type;
538535
attr_t attr = ccl::create_operation_attr<attr_t>();
@@ -576,40 +573,36 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
576573
auto stream = xcclStreams_.at(key);
577574
syncStream(device, xcclEvents_[key], stream);
578575

579-
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
580576
if (!coalescing_state_) {
577+
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
581578
work = initWork(device, rank_, opType);
582579
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
583580
work->outputs_->push_back(tensor);
584-
}
585581

586-
at::xpu::OptionalXPUGuard gpuGuard(device);
582+
at::xpu::OptionalXPUGuard gpuGuard(device);
587583

588-
if (!coalescing_state_) {
589-
pre(stream, work);
590-
}
591-
592-
c10::xpu::XPUCachingAllocator::recordStream(
584+
c10::xpu::XPUCachingAllocator::recordStream(
593585
tensor.storage().data_ptr(), stream);
594586

595-
fn(tensor, attr, *comm, stream, p2pTargetRank);
596-
597-
if (!coalescing_state_) {
598-
post(stream);
587+
fn(tensor, attr, *comm, stream, p2pTargetRank);
599588

600589
work->xcclEndEvent_->record(stream);
601590
work->blockingWait_ = blockingWait_;
602-
603-
{
604-
std::vector<c10::Stream> streams = {stream.unwrap()};
605-
c10::MultiStreamGuard streamGuard(streams);
606-
std::vector<at::Device> devices{device};
607-
work->future_ = c10::make_intrusive<at::ivalue::Future>(
608-
c10::ListType::create(c10::TensorType::get()), devices);
609-
work->future_->markCompleted(at::IValue(*work->outputs_));
610-
}
591+
std::vector<c10::Stream> streams = {stream.unwrap()};
592+
c10::MultiStreamGuard streamGuard(streams);
593+
std::vector<at::Device> devices{device};
594+
work->future_ = c10::make_intrusive<at::ivalue::Future>(
595+
c10::ListType::create(c10::TensorType::get()), devices);
596+
work->future_->markCompleted(at::IValue(*work->outputs_));
611597
return work;
612598
} else {
599+
at::xpu::OptionalXPUGuard gpuGuard(device);
600+
601+
c10::xpu::XPUCachingAllocator::recordStream(
602+
tensor.storage().data_ptr(), stream);
603+
604+
fn(tensor, attr, *comm, stream, p2pTargetRank);
605+
613606
return nullptr;
614607
}
615608
}

torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -207,25 +207,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
207207
at::Tensor& tensor,
208208
Fn fn,
209209
int peer,
210-
OpType opType) {
211-
return pointToPoint(
212-
tensor,
213-
fn,
214-
peer,
215-
opType,
216-
[](at::xpu::XPUStream&, c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {
217-
},
218-
[](at::xpu::XPUStream&) {});
219-
}
220-
221-
template <typename Fn, typename PreProcess, typename PostProcess>
222-
c10::intrusive_ptr<Work> pointToPoint(
223-
at::Tensor& tensor,
224-
Fn fn,
225-
int peer,
226-
OpType opType,
227-
PreProcess pre,
228-
PostProcess post);
210+
OpType opType);
229211

230212
c10::intrusive_ptr<Work> allreduce_impl(
231213
at::Tensor& tensor,

0 commit comments

Comments
 (0)