10 changes: 9 additions & 1 deletion tests/jax/test_custom_call_compute.py
@@ -1366,14 +1366,22 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, input_shape, layout
)
num_gemms = input_shape[0]
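# Stage the async D2H copy of group_sizes so grouped_gemm below can read the host-side copy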
_ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
group_sizes,
num_gemms=num_gemms,
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

# jitting grouped_gemm
-prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
+prim_out = jax.jit(
+    tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
+)(
lhs,
rhs,
group_sizes,
contracting_dims,
use_async_d2h_group_sizes=True,
)

self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
87 changes: 85 additions & 2 deletions transformer_engine/jax/cpp_extensions/gemm.py
@@ -58,6 +58,7 @@
"collective_gemm_bootstrap",
"noop_collective_op_set",
"gemm",
"grouped_gemm_copy_group_sizes",
"grouped_gemm",
"gemm_uses_jax_dot",
"sanitize_dims",
@@ -1234,14 +1235,71 @@ def _te_gemm(
)


class GroupedGemmCopySizesPrimitive(BasePrimitive):
"""
Primitive for async copying group sizes from device to host
"""

name = "te_grouped_gemm_d2h_group_sizes_ffi"
multiple_results = False
impl_static_args = (1,)
inner_primitive = None
outer_primitive = None

@staticmethod
def abstract(
group_sizes_aval,
*,
num_gemms,
):
del num_gemms
out_aval = group_sizes_aval
return out_aval

@staticmethod
def outer_abstract(*args, **kwargs):
out = GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs)
return out

@staticmethod
def lowering(
ctx,
group_sizes,
num_gemms,
):
return jax.ffi.ffi_lowering(
GroupedGemmCopySizesPrimitive.name,
operand_output_aliases={0: 0},  # Alias the group_sizes operand to the output (pass-through)
)(
ctx,
group_sizes,
num_gemms=num_gemms,
)

@staticmethod
def impl(
group_sizes,
num_gemms,
):
assert GroupedGemmCopySizesPrimitive.inner_primitive is not None
out = GroupedGemmCopySizesPrimitive.inner_primitive.bind(
group_sizes,
num_gemms=num_gemms,
)
return out


register_primitive(GroupedGemmCopySizesPrimitive)


class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
"""

name = "te_grouped_gemm_ffi"
multiple_results = True
-impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
+impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
inner_primitive = None
outer_primitive = None

@@ -1264,6 +1322,7 @@ def abstract(
out_dtype,
has_bias,
is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
):
"""
Grouped GEMM operation.
@@ -1291,7 +1350,7 @@
A jnp.ndarray containing the result of the grouped GEMM operation
"""
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
-del K, lhs_is_trans, rhs_is_trans, has_bias
+del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_alignment_padding = 256
@@ -1338,6 +1397,7 @@ def lowering(
out_dtype,
has_bias,
is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
):
del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
@@ -1351,6 +1411,7 @@
scaling_mode=scaling_mode.value,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
)

@staticmethod
@@ -1371,6 +1432,7 @@ def impl(
out_dtype,
has_bias,
is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
):
assert GroupedGemmPrimitive.inner_primitive is not None
(out, _) = GroupedGemmPrimitive.inner_primitive.bind(
@@ -1390,6 +1452,7 @@
out_dtype=out_dtype,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
)
return (out,)

@@ -1657,6 +1720,24 @@ def gemm(
return clean_outputs


def grouped_gemm_copy_group_sizes(
group_sizes: jnp.ndarray,
num_gemms: int,
) -> jnp.ndarray:
"""
Async copy group sizes from device to host

Args:
group_sizes: 1D array containing the sizes of each group
num_gemms: number of grouped gemm calls to be made
"""
out = GroupedGemmCopySizesPrimitive.outer_primitive.bind(
group_sizes,
num_gemms=num_gemms,
)
return out
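
A minimal end-to-end sketch of the intended call pattern, mirroring the unit test above; the import path, shapes, and contracting_dims here are illustrative assumptions, not values taken from this PR:

import jax
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex  # assumed import path

num_gemms = 4
group_sizes = jnp.array([128, 256, 64, 64], dtype=jnp.int32)  # rows per group
lhs = jnp.zeros((512, 256), dtype=jnp.bfloat16)  # [M, K], M = sum(group_sizes)
rhs = jnp.zeros((num_gemms, 256, 128), dtype=jnp.bfloat16)  # [G, K, N]

# Stage the async D2H copy of group_sizes as early as possible...
_ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
    group_sizes, num_gemms=num_gemms
)
# ...so grouped_gemm can wait on the recorded event and read the pinned host
# buffer instead of issuing a blocking cudaMemcpy + cudaStreamSynchronize.
out = jax.jit(
    tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(lhs, rhs, group_sizes, ((1,), (1,)), use_async_d2h_group_sizes=True)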


def grouped_gemm(
lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
@@ -1667,6 +1748,7 @@ def grouped_gemm(
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
use_async_d2h_group_sizes: bool = False,
) -> jnp.ndarray:
"""
Grouped GEMM operation.
@@ -1850,5 +1932,6 @@
out_dtype=out_dtype,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
)
return out
1 change: 1 addition & 0 deletions transformer_engine/jax/csrc/extensions.h
@@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);

// Cudnn helpers
81 changes: 74 additions & 7 deletions transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -280,12 +280,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits);

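// Two-phase helper (call protocol): grouped_gemm_copy_group_sizes invokes this with
// dev_group_sizes set, scheduling the D2H memcpy on an internal compute stream and
// recording d2h_event; GroupedGemmFFI later invokes it with host_group_sizes set,
// synchronizing on d2h_event and copying the cached values into the caller's buffer.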
size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes,
int32_t *host_group_sizes) {
static std::once_flag init_flag;
static cudaEvent_t d2h_event;
static size_t host_num_gemms;
static const size_t max_num_gemms = 1024;
static int32_t *host_group_sizes_internal = nullptr;  // pinned host staging buffer
auto init = [&]() {
NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
};
[Review comment – Collaborator]: If this causes any issues, we could consider moving this allocation into the FFI prepare phase.

std::call_once(init_flag, init);

NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr,
           "At most one of dev_group_sizes and host_group_sizes may be non-nullptr.");

if (dev_group_sizes != nullptr) {
NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ",
"supported number ", max_num_gemms, " to be downloaded in advance.");
host_num_gemms = num_gemms;
// Make the internal compute stream wait until group_sizes is ready on the caller's stream
cudaStream_t compute_stream_0 = nvte_get_compute_stream(0);
[Review comment – Collaborator]: @mingxu1067 could you check if this causes the same stream sync issue as last time, when we used compute_stream(0) instead of the stream given by XLA?


NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event));
// Async copy group_sizes from device to host
size_t copy_bytes = sizeof(int32_t) * num_gemms;
NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes,
cudaMemcpyDeviceToHost, compute_stream_0));
NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0));
return num_gemms;
}

if (host_group_sizes != nullptr) {
if (host_num_gemms == 0) return 0;
NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
" does not match the previous value ", host_num_gemms, ".");
// Wait for the async copy to finish, then copy group_sizes to user buffer
// Note: This may break cudaGraph.
NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event));
memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms);
return host_num_gemms;
}
// Both pointers were nullptr; nothing to do (also keeps every path returning a value).
return 0;
}

Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes,
Result_Type dummy_output, size_t num_gemms) {
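// dummy_output aliases group_sizes (operand_output_aliases={0: 0} on the Python side),
// so this handler only schedules the async copy and produces no new data.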
int32_t *dev_group_sizes = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr);
return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // group_sizes
.Ret<Buffer_Type>() // dummy_output
.Attr<int64_t>("num_gemms"));

Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
-bool is_grouped_dense_wgrad) {
+bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) {
// Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major [m, k] for N - [k, m] for T
@@ -406,11 +465,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type

size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms);
-auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
-cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
-                stream);
-// Note: This may break cudaGraph.
-cudaStreamSynchronize(stream);
+size_t host_num_gemms = 0;
+if (use_async_d2h_group_sizes) {
+  host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data());
+  NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
+             " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, ".");
+} else {
+  auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
+  cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
+                  stream);
+  // Note: This may break cudaGraph.
+  cudaStreamSynchronize(stream);
+}
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
@@ -669,7 +735,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("has_bias")
-.Attr<bool>("is_grouped_dense_wgrad"));
+.Attr<bool>("is_grouped_dense_wgrad")
+.Attr<bool>("use_async_d2h_group_sizes"));

} // namespace jax
} // namespace transformer_engine
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -69,6 +69,9 @@ pybind11::dict Registrations() {
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));

// Grouped GEMM
dict["te_grouped_gemm_d2h_group_sizes_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));