Intro async flag and use current stream avoid stream sync #1546

Open
wants to merge 16 commits into
base: main
from

Conversation

Chao1Han
Contributor

@Chao1Han Chao1Han commented Apr 7, 2025

Refer to pytorch/pytorch#147820 and pytorch/pytorch#150398.
To launch kernels on the current stream and reduce the CPU overhead introduced by recordStream, an async option is introduced.

For example, in an allreduce operation between two ranks:

  • rank0 corresponds to device0, using the current device's stream0 to create the communicator and preserving stream0.

When async = true:

  • Both rank0 and rank1 perform the collective using stream0, which is associated with the communicator.
  • To keep stream0 from reading tensors that are not ready yet (e.g., on rank1), stream0 must first synchronize with the current stream.
  • After the collective completes, the input tensors must not be released prematurely: either recordStream is used for stream tracking, or the tensors are stored temporarily (e.g., in reduce_scatter or all2all).

When async = false:

  • Both rank0 and rank1 use their respective current streams for collectives (i.e., rank0 uses stream0, rank1 uses stream1).
  • In this case, the collective op handles synchronization implicitly.

Previously, we defaulted to async = true. Now, the async option is explicitly introduced and set to false by default, leveraging the current stream to avoid the overhead of stream synchronization.
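
The two modes described above can be sketched as a small decision function. This is only an illustrative model, not the code in this PR: `ToyStream`, `RankState`, `LaunchPlan`, and `planCollective` are hypothetical names, and the "no extra tracking when async = false" outcome is an assumption drawn from the description rather than from the implementation.

```cpp
#include <cassert>

// Hypothetical stand-in for a device stream; only the id matters here.
struct ToyStream {
  int id;
};

// Hypothetical per-rank state: the stream preserved when the communicator
// was created, and the stream that is current when the collective is called.
struct RankState {
  ToyStream commStream;
  ToyStream currentStream;
};

// What the launch has to do, according to the PR description.
struct LaunchPlan {
  ToyStream stream;       // stream the collective is enqueued on
  bool syncWithCurrent;   // must the launch stream wait on the current stream?
  bool keepInputsAlive;   // recordStream or stash inputs until completion?
};

LaunchPlan planCollective(const RankState& rank, bool asyncOp) {
  if (asyncOp) {
    // async = true: run on the communicator's stream, synchronize with the
    // current stream first, and track the inputs afterwards.
    return {rank.commStream, true, true};
  }
  // async = false: run on the caller's current stream; ordering on that
  // stream is implicit, so this sketch assumes no extra tracking is needed.
  return {rank.currentStream, false, false};
}
```

With a rank whose communicator stream is 0 and whose current stream is 1, `planCollective(rank, true)` picks stream 0 with both flags set, while `planCollective(rank, false)` picks stream 1 with neither.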

@Chao1Han Chao1Han force-pushed the xccl/record_stream branch from 19216e1 to 5fe91cb Compare April 7, 2025 07:25
@Chao1Han Chao1Han changed the title [wip] Xccl/record stream Intro async flag and use current stream avoid stream sync Apr 9, 2025
@Chao1Han
Contributor Author

Chao1Han commented Apr 9, 2025

}

bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
synchronize();

Do we still need to sync if the compute stream is used for communication?

Contributor Author

Not needed. Using the current stream means async=false, so line 632 returns null rather than a work object, and the frontend checks whether it needs to call wait():
https://github.com/pytorch/pytorch/blob/4273e5d15cfcb282b2795684874ea439d8620999/torch/distributed/distributed_c10d.py#L2882-L2887
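
A minimal sketch of the contract described in this reply, under the assumption that the backend hands back a null handle in the synchronous case and the caller checks for it before waiting. `ToyWork`, `launchCollective`, and `waitIfNeeded` are hypothetical names, not the ProcessGroupXCCL or c10d API.

```cpp
#include <cassert>
#include <memory>

// Hypothetical work handle; wait() only records that it was called.
struct ToyWork {
  bool waited = false;
  void wait() { waited = true; }
};

// Backend side (assumed): with asyncOp == false the collective already runs
// on the caller's current stream, so there is no handle to return.
std::shared_ptr<ToyWork> launchCollective(bool asyncOp) {
  if (!asyncOp) {
    return nullptr;
  }
  return std::make_shared<ToyWork>();
}

// Frontend side (assumed): only call wait() when a handle came back.
// Returns true if wait() was actually invoked.
bool waitIfNeeded(const std::shared_ptr<ToyWork>& work) {
  if (work == nullptr) {
    return false;
  }
  work->wait();
  return true;
}
```

In this model the synchronous path never reaches `wait()`, so nothing on that path would trigger the `synchronize()` call the question asks about.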


// asyncOp=false will always use current stream; getStream will return current
// stream
c10::Stream stream = asyncOp

In some special cases, such as FSDP2, the same communicator will have a different current stream on each device. How do we make sure the correct compute stream is returned for communication?

cclstream =
std::make_unique<ccl::stream>(xcclStreamsMap_.at(StreamKey).second);
} catch (...) {
LOG(WARNING) << "Current stream id changed, create new ccl stream";

Why is this a warning? I think info would be more suitable. A warning usually means something is unsafe and may cause unpredictable results.

auto stream = xcclStreamsMap_.at(key).first;
auto cclstream = xcclStreamsMap_.at(key).second;
auto StreamKey = asyncOp ? key
: key + "_" +

If you put this code at 63823e4#diff-29271b6f1608f7ad940c9cd242ce24dcc68bba932348e39dc0b524604cc78c6aR568, you don't need to construct StreamKey twice.

auto StreamKey = asyncOp ? key
: key + "_" +
std::to_string(at::xpu::getCurrentXPUStream(device.index()).id());
auto stream = asyncOp ? xcclStreamsMap_.at(StreamKey).first

Suggested change
auto stream = asyncOp ? xcclStreamsMap_.at(StreamKey).first
if (asyncOp) {
  stream = xcclStreamsMap_.at(key).first;
  cclstream = xcclStreamsMap_.at(key).second;
  syncStream(device, xcclEventsMap_[key], stream);
} else {
  // Key the cache on the communicator key plus the current stream id.
  auto current_stream = at::xpu::getCurrentXPUStream(device.index());
  auto StreamKey = key + "_" + std::to_string(current_stream.id());
  if (xcclStreamsMap_.find(StreamKey) != xcclStreamsMap_.end()) {
    stream = xcclStreamsMap_.at(StreamKey).first;
    cclstream = xcclStreamsMap_.at(StreamKey).second;
  } else {
    // update xcclStreamsMap_ with current stream key
    cclstream = std::make_unique<ccl::stream>(ccl::create_stream(current_stream.queue()));
    std::lock_guard<std::mutex> lock(mutex_);
    xcclStreamsMap_.emplace(
        StreamKey, std::make_pair(at::xpu::XPUStream(current_stream), *cclstream));
  }
}
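
The lookup-or-create pattern in this suggestion can be modeled without any XPU or oneCCL types. The sketch below is an assumption-laden toy: `ToyStreamCache` and its integer handles are hypothetical, and only the key layout (`key + "_" + stream id` when asyncOp is false) is taken from the snippet under review. Unlike the suggestion, it holds the mutex across the lookup as well as the insert, which is a deliberate simplification and not necessarily what the PR should do.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>

// Toy cache: one entry per communicator key when asyncOp is true, one entry
// per (communicator key, current stream id) pair when asyncOp is false.
class ToyStreamCache {
 public:
  // Same key layout as the reviewed snippet.
  static std::string makeKey(const std::string& key, bool asyncOp,
                             int currentStreamId) {
    return asyncOp ? key : key + "_" + std::to_string(currentStreamId);
  }

  // Returns the cached handle for the key, creating it on first use.
  // The key is built exactly once per call.
  int getOrCreate(const std::string& key, bool asyncOp, int currentStreamId) {
    const std::string streamKey = makeKey(key, asyncOp, currentStreamId);
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handles_.find(streamKey);
    if (it != handles_.end()) {
      return it->second;
    }
    const int handle = nextHandle_++;
    handles_.emplace(streamKey, handle);
    return handle;
  }

  std::size_t size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return handles_.size();
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, int> handles_;
  int nextHandle_ = 0;
};
```

In this model a communicator used with asyncOp = false from two different current streams gets two separate entries, which is the situation the FSDP2 question above is concerned with.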

Contributor Author

Change done.
