Intro async flag and use current stream avoid stream sync #1546
}

bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
  synchronize();
Do we still need sync if compute stream is used for communication?
Not needed. Using the current stream means async=false, so line 632 returns null rather than a work object,
and the frontend checks whether it needs to call wait():
https://github.com/pytorch/pytorch/blob/4273e5d15cfcb282b2795684874ea439d8620999/torch/distributed/distributed_c10d.py#L2882-L2887
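A minimal sketch of the check described in the reply above, written as a self-contained C++ mock rather than the actual Python frontend. `MockWork`, `mockCollective`, and `mockAllReduce` are hypothetical names for illustration only; the assumption is that the backend returns a null handle when the collective ran on the current stream, and that the frontend only calls wait() when it gets a non-null handle on the synchronous path.

```cpp
#include <memory>

// Hypothetical stand-in for a Work handle; not the real c10d class.
struct MockWork {
  bool waited = false;
  void wait() { waited = true; }
};

// Hypothetical backend call: returns a handle only when the collective was
// issued on a separate communication stream. On the current stream, ordering
// is already guaranteed by the stream itself, so there is nothing to wait on.
std::shared_ptr<MockWork> mockCollective(bool usesCurrentStream) {
  if (usesCurrentStream) {
    return nullptr;
  }
  return std::make_shared<MockWork>();
}

// Hypothetical frontend wrapper mirroring the check described in the reply.
std::shared_ptr<MockWork> mockAllReduce(bool asyncOp) {
  auto work = mockCollective(/*usesCurrentStream=*/!asyncOp);
  if (asyncOp) {
    return work;  // the caller decides when to wait()
  }
  if (work != nullptr) {
    work->wait();  // only reached if the backend still returned a handle
  }
  return nullptr;
}
```

With this shape, WorkXCCL::wait() is never reached on the current-stream path, which is why the synchronize() call inside it does not add a sync there.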
src/xccl/ProcessGroupXCCL.cpp
Outdated
// asyncOp=false will always use current stream; getStrem will return current
// stream
c10::Stream stream = asyncOp
In some special cases, such as FSDP2, the same communicator will see a different current stream
on each device. How do we make sure the correct compute stream is returned for communication?
src/xccl/ProcessGroupXCCL.cpp
Outdated
  cclstream =
      std::make_unique<ccl::stream>(xcclStreamsMap_.at(StreamKey).second);
} catch (...) {
  LOG(WARNING) << "Current stream id changed, create new ccl stream";
Why is this a warning? I think info
would be more suitable. A warning usually means something is unsafe and may cause unpredictable results.
src/xccl/ProcessGroupXCCL.cpp
Outdated
auto stream = xcclStreamsMap_.at(key).first;
auto cclstream = xcclStreamsMap_.at(key).second;
auto StreamKey = asyncOp ? key
                         : key + "_" +
If you put this code at 63823e4#diff-29271b6f1608f7ad940c9cd242ce24dcc68bba932348e39dc0b524604cc78c6aR568, you don't need to construct StreamKey
twice.
src/xccl/ProcessGroupXCCL.cpp
Outdated
auto StreamKey = asyncOp ? key
                         : key + "_" +
        std::to_string(at::xpu::getCurrentXPUStream(device.index()).id());
auto stream = asyncOp ? xcclStreamsMap_.at(StreamKey).first
auto stream = asyncOp ? xcclStreamsMap_.at(StreamKey).first
Suggested change:
if (asyncOp) {
  stream = xcclStreamsMap_.at(key).first;
  cclstream = xcclStreamsMap_.at(key).second;
  syncStream(device, xcclEventsMap_[key], stream);
} else {
  auto current_stream = at::xpu::getCurrentXPUStream(device.index());
  // Key the map on the current stream id so each compute stream gets its own entry.
  auto streamKey = key + "_" + std::to_string(current_stream.id());
  if (xcclStreamsMap_.find(streamKey) != xcclStreamsMap_.end()) {
    stream = xcclStreamsMap_.at(streamKey).first;
    cclstream = xcclStreamsMap_.at(streamKey).second;
  } else {
    // Update xcclStreamsMap_ with the current stream key.
    cclstream = std::make_unique<ccl::stream>(ccl::create_stream(current_stream.queue()));
    std::lock_guard<std::mutex> lock(mutex_);
    xcclStreamsMap_.emplace(
        streamKey, std::make_pair(at::xpu::XPUStream(current_stream), *cclstream));
  }
}
change done
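For readers following this thread, here is a minimal, self-contained sketch of the lookup-or-create pattern the suggestion is aiming at for the async=false path. `MockStream`, `MockCclStream`, and `MockStreamCache` are hypothetical stand-ins, not the XPU or oneCCL types, and holding the lock across both the lookup and the insert is a choice made in this sketch rather than something taken from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for at::xpu::XPUStream and ccl::stream.
struct MockStream { int id; };
struct MockCclStream { int wrappedId; };

class MockStreamCache {
 public:
  // Builds the per-stream key in the same form as the diff: key + "_" + stream id.
  static std::string makeKey(const std::string& key, const MockStream& current) {
    return key + "_" + std::to_string(current.id);
  }

  // Returns the ccl stream cached for (key, current stream), creating it on
  // first use. The key is constructed once and reused for find and emplace.
  MockCclStream getOrCreate(const std::string& key, const MockStream& current) {
    assert(!key.empty());
    const std::string streamKey = makeKey(key, current);
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = streams_.find(streamKey);
    if (it == streams_.end()) {
      it = streams_.emplace(streamKey, MockCclStream{current.id}).first;
    }
    return it->second;
  }

  // Number of distinct (key, stream) entries created so far.
  std::size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return streams_.size();
  }

 private:
  mutable std::mutex mutex_;
  std::unordered_map<std::string, MockCclStream> streams_;
};
```

Using find instead of at() inside a try/catch keeps the "stream changed" case on the normal control path, which also fits the earlier comment that this situation is not really a warning.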
Refer to pytorch/pytorch#147820 and pytorch/pytorch#150398.
To launch kernels on the current stream and reduce the CPU overhead introduced by recordStream, an async option is introduced.

For example, in an allreduce operation between two ranks: rank0 corresponds to device0, using the current device's stream0 to create the communicator and preserving stream0.

When async = true:
rank0 and rank1 perform the collective using stream0, which is associated with the communicator.
To keep stream0 from reading unready tensors (e.g., from rank1), synchronization with the current stream is required.
recordStream must be used for stream tracking, or the tensors need to be temporarily stored (e.g., in reduce_scatter or all2all).

When async = false:
rank0 and rank1 use their respective current streams for collectives (i.e., rank0 uses stream0, rank1 uses stream1).

Previously, we defaulted to async = true. Now, the async option is explicitly introduced and set to false by default, leveraging the current stream to avoid the overhead of stream synchronization.
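The two modes in the description can be summarized with a small, self-contained sketch. `StreamChoice` and `chooseStream` are hypothetical names used only for illustration; they are not part of ProcessGroupXCCL, and the flags simply restate the requirements listed above for each mode.

```cpp
// Hypothetical summary of what each mode implies for a single collective.
struct StreamChoice {
  int streamId;                 // stream the collective is enqueued on
  bool needsStreamSync;         // comm stream must first wait on the current stream
  bool needsTensorTracking;     // recordStream or temporary storage of tensors
};

// async = true: use the stream preserved with the communicator, which requires
// synchronizing with the current stream and tracking tensor lifetimes.
// async = false: use the caller's current stream, so neither step is needed.
StreamChoice chooseStream(bool asyncOp, int commStreamId, int currentStreamId) {
  if (asyncOp) {
    return {commStreamId, true, true};
  }
  return {currentStreamId, false, false};
}
```

For instance, with rank1 running on stream1 while the communicator preserved stream0, chooseStream(false, 0, 1) selects stream1 with no extra synchronization, while chooseStream(true, 0, 1) selects stream0 and flags both extra steps.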