@@ -416,7 +416,6 @@ void ProcessGroupXCCL::groupEnd() {
416416 --xcclActiveGroupCounter_;
417417}
418418
419- // TODO: wait p2p enable
420419static constexpr int CoalActive = 0x01 , CoalColl = 0x02 , CoalP2P = 0x04 ;
421420void ProcessGroupXCCL::startCoalescing () {
422421 coalescedDevice_.set_index (-1 );
@@ -525,14 +524,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
525524 return work;
526525}
527526
528- template <typename Fn, typename PreProcess, typename PostProcess >
527+ template <typename Fn>
529528c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint (
530529 at::Tensor& tensor,
531530 Fn fn,
532531 int peer,
533- OpType opType,
534- PreProcess pre ,
535- PostProcess post ) {
532+ OpType opType) {
536533 using traits = function_traits<Fn>;
537534 using attr_t = typename traits::template arg<1 >::type;
538535 attr_t attr = ccl::create_operation_attr<attr_t >();
@@ -576,40 +573,36 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
576573 auto stream = xcclStreams_.at (key);
577574 syncStream (device, xcclEvents_[key], stream);
578575
579- c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
580576 if (!coalescing_state_) {
577+ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
581578 work = initWork (device, rank_, opType);
582579 work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
583580 work->outputs_ ->push_back (tensor);
584- }
585581
586- at::xpu::OptionalXPUGuard gpuGuard (device);
582+ at::xpu::OptionalXPUGuard gpuGuard (device);
587583
588- if (!coalescing_state_) {
589- pre (stream, work);
590- }
591-
592- c10::xpu::XPUCachingAllocator::recordStream (
584+ c10::xpu::XPUCachingAllocator::recordStream (
593585 tensor.storage ().data_ptr (), stream);
594586
595- fn (tensor, attr, *comm, stream, p2pTargetRank);
596-
597- if (!coalescing_state_) {
598- post (stream);
587+ fn (tensor, attr, *comm, stream, p2pTargetRank);
599588
600589 work->xcclEndEvent_ ->record (stream);
601590 work->blockingWait_ = blockingWait_;
602-
603- {
604- std::vector<c10::Stream> streams = {stream.unwrap ()};
605- c10::MultiStreamGuard streamGuard (streams);
606- std::vector<at::Device> devices{device};
607- work->future_ = c10::make_intrusive<at::ivalue::Future>(
608- c10::ListType::create (c10::TensorType::get ()), devices);
609- work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
610- }
591+ std::vector<c10::Stream> streams = {stream.unwrap ()};
592+ c10::MultiStreamGuard streamGuard (streams);
593+ std::vector<at::Device> devices{device};
594+ work->future_ = c10::make_intrusive<at::ivalue::Future>(
595+ c10::ListType::create (c10::TensorType::get ()), devices);
596+ work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
611597 return work;
612598 } else {
599+ at::xpu::OptionalXPUGuard gpuGuard (device);
600+
601+ c10::xpu::XPUCachingAllocator::recordStream (
602+ tensor.storage ().data_ptr (), stream);
603+
604+ fn (tensor, attr, *comm, stream, p2pTargetRank);
605+
613606 return nullptr ;
614607 }
615608}
0 commit comments