diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7a4fa268af..124e0248b6 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1366,14 +1366,22 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) + num_gemms = input_shape[0] + _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( + group_sizes, + num_gemms=num_gemms, + ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + prim_out = jax.jit( + tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") + )( lhs, rhs, group_sizes, contracting_dims, + use_async_d2h_group_sizes=True, ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 865efe89da..7fe433bcc6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -58,6 +58,7 @@ "collective_gemm_bootstrap", "noop_collective_op_set", "gemm", + "grouped_gemm_copy_group_sizes", "grouped_gemm", "gemm_uses_jax_dot", "sanitize_dims", @@ -1237,6 +1238,63 @@ def _te_gemm( ) +class GroupedGemmCopySizesPrimitive(BasePrimitive): + """ + Primitive for async copying group sizes from device to host + """ + + name = "te_grouped_gemm_d2h_group_sizes_ffi" + multiple_results = False + impl_static_args = (1,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + group_sizes_aval, + *, + num_gemms, + ): + del num_gemms + out_aval = group_sizes_aval + return out_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + out = GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs) + return out + + @staticmethod + def lowering( + ctx, + group_sizes, + num_gemms, + ): + return jax.ffi.ffi_lowering( + GroupedGemmCopySizesPrimitive.name, + operand_output_aliases={0: 0}, # Mark num_gemms as the output + )( + ctx, + group_sizes, + num_gemms=num_gemms, + ) + + @staticmethod + def impl( + group_sizes, + num_gemms, + ): + assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + +register_primitive(GroupedGemmCopySizesPrimitive) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -1244,7 +1302,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -1267,6 +1325,7 @@ def abstract( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): """ Grouped GEMM operation. @@ -1294,7 +1353,7 @@ def abstract( A jnp.ndarray containing the result of the grouped GEMM operation """ del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval - del K, lhs_is_trans, rhs_is_trans, has_bias + del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_alignment_padding = 256 @@ -1341,6 +1400,7 @@ def lowering( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( @@ -1354,6 +1414,7 @@ def lowering( scaling_mode=scaling_mode.value, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @staticmethod @@ -1374,6 +1435,7 @@ def impl( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): assert GroupedGemmPrimitive.inner_primitive is not None (out, _) = GroupedGemmPrimitive.inner_primitive.bind( @@ -1393,6 +1455,7 @@ def impl( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return (out,) @@ -1661,6 +1724,24 @@ def gemm( return clean_outputs +def grouped_gemm_copy_group_sizes( + group_sizes: jnp.ndarray, + num_gemms: int, +) -> jnp.ndarray: + """ + Async copy group sizes from device to host + + Args: + group_sizes: 1D array containing the sizes of each group + num_gemms: number of grouped gemm calls to be made + """ + out = GroupedGemmCopySizesPrimitive.outer_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -1671,6 +1752,7 @@ def grouped_gemm( preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + use_async_d2h_group_sizes: bool = False, ) -> jnp.ndarray: """ Grouped GEMM operation. @@ -1854,5 +1936,6 @@ def grouped_gemm( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index bbfc62120a..3ce6dee731 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); // Cudnn helpers diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f2007efcf6..993ec1377d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -284,12 +284,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("collective_op"), FFI_CudaGraph_Traits); +size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, + int32_t *host_group_sizes) { + static std::once_flag init_flag; + static cudaEvent_t d2h_event; + static size_t host_num_gemms; + static const size_t max_num_gemms = 1024; + //static int32_t host_group_sizes_internal[max_num_gemms]; + static int32_t *host_group_sizes_internal = nullptr; + auto init = [&]() { + NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event)); + NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms)); + }; + std::call_once(init_flag, init); + + NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr, + "Only one of dev_group_sizes and host_group_sizes can be non-nullptr."); + + if (dev_group_sizes != nullptr) { + NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ", + "supported number ", max_num_gemms, " to be downloaded in advance."); + host_num_gemms = num_gemms; + // Wait for current compute stream to finish + cudaStream_t compute_stream_0 = nvte_get_compute_stream(0); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event)); + // Async copy group_sizes from device to host + size_t copy_bytes = sizeof(int32_t) * num_gemms; + NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes, + cudaMemcpyDeviceToHost, compute_stream_0)); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0)); + return num_gemms; + } + + if (host_group_sizes != nullptr) { + if (host_num_gemms == 0) return 0; + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the previous value ", host_num_gemms, "."); + // Wait for the async copy to finish, then copy group_sizes to user buffer + // Note: This may break cudaGraph. + NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event)); + memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms); + return host_num_gemms; + } +} + +Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes, + Result_Type dummy_output, size_t num_gemms) { + int32_t *dev_group_sizes = reinterpret_cast(group_sizes.untyped_data()); + GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // group_sizes + .Ret() // dummy_output + .Attr("num_gemms")); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad) { + bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -410,11 +469,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); if (!is_grouped_dense_wgrad) { NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, @@ -673,7 +739,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad")); + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 23d46b3384..f6b1acd439 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -69,6 +69,9 @@ pybind11::dict Registrations() { pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM + dict["te_grouped_gemm_d2h_group_sizes_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler)); dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));