-
Notifications
You must be signed in to change notification settings - Fork 511
[JAX][Draft] Async issuing D2H memcpy for grouped_gemm group_sizes array #2213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
21ea9a3
b0de623
63d2832
25a15cd
90ac202
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,12 +280,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, | |
.Attr<JAXX_Collective_Op>("collective_op"), | ||
FFI_CudaGraph_Traits); | ||
|
||
// Stages an asynchronous device-to-host download of the grouped-GEMM group_sizes array.
//
// Two mutually exclusive modes, selected by which pointer is non-null:
//   * dev_group_sizes != nullptr: record an event on `stream`, make compute stream 0 wait on
//     it, then issue a cudaMemcpyAsync of `num_gemms` int32 entries into an internal pinned
//     host buffer and re-record the event so a later call can wait for the copy.
//   * host_group_sizes != nullptr: block on the previously recorded event, then memcpy the
//     staged values from the internal pinned buffer into the caller's buffer.
//
// Returns the number of GEMMs staged/copied; 0 if nothing has been staged yet or if both
// pointers are null.
//
// Notes:
//  - Uses function-local static state (event + pinned buffer + host_num_gemms), so this is
//    NOT thread-safe and supports only one staged download in flight at a time.
//  - The cudaEventSynchronize on the retrieval path may break cudaGraph capture.
size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes,
                                int32_t *host_group_sizes) {
  static std::once_flag init_flag;
  static cudaEvent_t d2h_event;
  static size_t host_num_gemms = 0;
  // Upper bound on how many group sizes the internal staging buffer can hold.
  static const size_t max_num_gemms = 1024;
  static int32_t *host_group_sizes_internal = nullptr;
  // One-time creation of the event and the page-locked (pinned) host staging buffer;
  // pinned memory is required for cudaMemcpyAsync D2H to be truly asynchronous.
  auto init = [&]() {
    NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
    NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
  };
  std::call_once(init_flag, init);

  NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr,
             "Only one of dev_group_sizes and host_group_sizes can be non-nullptr.");

  if (dev_group_sizes != nullptr) {
    // Staging mode: enqueue the async D2H copy on compute stream 0.
    NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ",
               "supported number ", max_num_gemms, " to be downloaded in advance.");
    host_num_gemms = num_gemms;
    // Make compute stream 0 wait for work already enqueued on the caller's stream so the
    // copy observes the final group_sizes values.
    cudaStream_t compute_stream_0 = nvte_get_compute_stream(0);
    NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event));
    // Async copy group_sizes from device to the pinned host buffer.
    size_t copy_bytes = sizeof(int32_t) * num_gemms;
    NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes,
                                    cudaMemcpyDeviceToHost, compute_stream_0));
    // Re-record the event so the retrieval path can wait for this copy to complete.
    NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0));
    return num_gemms;
  }

  if (host_group_sizes != nullptr) {
    // Retrieval mode: nothing staged yet means nothing to copy.
    if (host_num_gemms == 0) return 0;
    NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
               " does not match the previous value ", host_num_gemms, ".");
    // Wait for the async copy to finish, then copy group_sizes to the user buffer.
    // Note: This may break cudaGraph.
    NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event));
    memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms);
    return host_num_gemms;
  }

  // Both pointers null: the original code fell off the end of a non-void function here
  // (undefined behavior). Treat it explicitly as "nothing to do".
  return 0;
}
|
||
// FFI entry point that pre-stages the grouped-GEMM group_sizes array on the host.
// Kicks off the asynchronous D2H download via GroupedGemmGetGroupSizes (staging mode,
// host_group_sizes == nullptr); dummy_output is not written by this handler.
Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes,
                                       Result_Type dummy_output, size_t num_gemms) {
  auto *group_sizes_dev = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
  GroupedGemmGetGroupSizes(stream, num_gemms, group_sizes_dev, /*host_group_sizes=*/nullptr);
  return ffi_with_cuda_error_check();
}

// Registers the D2H group-sizes prefetch op with XLA's FFI machinery.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Ret<Buffer_Type>()      // dummy_output
                                  .Attr<int64_t>("num_gemms"));
|
||
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, | ||
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, | ||
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, | ||
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, | ||
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, | ||
bool is_grouped_dense_wgrad) { | ||
bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { | ||
// Notes on matrix layouts and transpose: | ||
// Jax uses row-major data_layout, on entering this function, each input matrix pair: | ||
// A: row-major [m, k] for N - [k, m] for T | ||
|
@@ -406,11 +465,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type | |
|
||
size_t dim_list_bytes = sizeof(int32_t) * num_gemms; | ||
std::vector<int32_t> dim_list_host(num_gemms); | ||
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data()); | ||
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, | ||
stream); | ||
// Note: This may break cudaGraph. | ||
cudaStreamSynchronize(stream); | ||
size_t host_num_gemms = 0; | ||
if (use_async_d2h_group_sizes) { | ||
host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); | ||
NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, | ||
" does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); | ||
} else { | ||
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data()); | ||
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, | ||
stream); | ||
// Note: This may break cudaGraph. | ||
cudaStreamSynchronize(stream); | ||
} | ||
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); | ||
if (!is_grouped_dense_wgrad) { | ||
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, | ||
|
@@ -669,7 +735,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, | |
.Attr<bool>("rhs_is_trans") | ||
.Attr<JAXX_Scaling_Mode>("scaling_mode") | ||
.Attr<bool>("has_bias") | ||
.Attr<bool>("is_grouped_dense_wgrad")); | ||
.Attr<bool>("is_grouped_dense_wgrad") | ||
.Attr<bool>("use_async_d2h_group_sizes")); | ||
|
||
} // namespace jax | ||
} // namespace transformer_engine |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this pinned-host allocation (the `cudaMallocHost` in `GroupedGemmGetGroupSizes`) causes any issues, we could consider moving it into the FFI prepare phase.