
Commit f99e6bb

slayton58 authored and Chao1Han committed
Add scaled_grouped_mm_v2 and python API (pytorch#165154)
Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <[email protected]>
Pull Request resolved: pytorch#165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
1 parent 40de942 commit f99e6bb
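For orientation, here is a minimal, hedged sketch of how the newly declared raw op could be driven from Python, based only on the `_scaled_grouped_mm_v2` schema added to native_functions.yaml below. The integer recipe/swizzle codes, scale shapes, and group offsets are illustrative assumptions; in practice they would come from the `ScalingType`/`SwizzleType` enums and the new `torch.nn.functional.scaled_grouped_mm` wrapper, whose Python signature is not among the files shown here.

```
import torch

# Assumed stand-ins for the integer codes of the ScalingType / SwizzleType enums;
# real values should be taken from torch.nn.functional.ScalingType / SwizzleType.
ROWWISE_RECIPE = 1
NO_SWIZZLE = 0

M, K, N, G = 64, 128, 256, 2

# FP8 operands: mat_a is row-major (M, K); mat_b is built as (G, N, K) and transposed
# so it presents as "transposed" (G, K, N), which _scaled_grouped_mm_cuda_v2 checks for.
mat_a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
mat_b = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)

# fp32 row-wise scales; shapes are a best guess for the rowwise recipe and may need adjusting.
scale_a = torch.ones(M, device="cuda", dtype=torch.float32)
scale_b = torch.ones(G, N, device="cuda", dtype=torch.float32)

# int32, 1D group offsets over the M rows of mat_a (required whenever either operand is 2d).
offs = torch.tensor([32, 64], device="cuda", dtype=torch.int32)

out = torch.ops.aten._scaled_grouped_mm_v2(
    mat_a, mat_b,
    [scale_a], [ROWWISE_RECIPE], [NO_SWIZZLE],
    [scale_b], [ROWWISE_RECIPE], [NO_SWIZZLE],
    offs=offs,
    out_dtype=torch.bfloat16,  # bf16 is the only supported output dtype
)
print(out.shape)  # expected (M, N) for the 2d-3d grouped case
```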

File tree

7 files changed (+334, −14 lines)


aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 153 additions & 0 deletions
@@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
     const Tensor& mat_a,
     const Tensor& mat_b,
     const Tensor& scale_a,
+    const SwizzleType& swizzle_a,
     const Tensor& scale_b,
+    const SwizzleType& swizzle_b,
     const std::optional<at::Tensor>& offs,
     Tensor& out) {
   const bool a_is_2d = mat_a.dim() == 2;
@@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
   TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
   TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
   TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
+  // MXFP8 expects float8_e8m0fnu scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
+      "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
+#ifdef USE_ROCM
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
+      "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
+#else
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
+      "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
+#endif
 
 #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
   fbgemm_gpu::mx8mx8bf16_grouped_mm(
@@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
     const std::optional<Tensor>& bias,
     bool use_fast_accum,
     Tensor& out) {
+  // FP8 per-tensor and per-row scaling expect fp32 scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
+      "For grouped FP8 rowwise, both scales must be float32 tensors");
 #ifndef USE_ROCM
   return _f8_f8_bf16_rowwise_grouped_mm_cuda(
       mat_a,
@@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda(
 #endif
 
   if (is_mx8mx8bf16) {
+    // Note: Passing implied SwizzleType here, correctness of scale previously checked
+    // in `check_scale` call
     return _mx8_mx8_bf16_grouped_mm_fbgemm(
         mat_a,
         mat_b,
         scale_a,
+        SwizzleType::SWIZZLE_32_4_4,
         scale_b,
+        SwizzleType::SWIZZLE_32_4_4,
         offs.value(),
         out);
   }
@@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda(
       out);
 }
 
+namespace {
+
+std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
+  { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
+  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+
+} // anonymous namespace
+
+Tensor
+_scaled_grouped_mm_cuda_v2(
+    const Tensor& mat_a, const Tensor& mat_b,
+    ArrayRef<Tensor> scale_a,
+    IntArrayRef scale_recipe_a,
+    IntArrayRef swizzle_a,
+    ArrayRef<Tensor> scale_b,
+    IntArrayRef scale_recipe_b,
+    IntArrayRef swizzle_b,
+    const std::optional<Tensor>& offs,
+    const std::optional<Tensor>& bias,
+    const std::optional<c10::ScalarType> out_dtype,
+    IntArrayRef contraction_dim,
+    bool use_fast_accum) {
+  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
+  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
+
+  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
+  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
+  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
+  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
+  const bool a_is_2d = mat_a.dim() == 2;
+  const bool b_is_2d = mat_b.dim() == 2;
+
+  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
+  if (!a_is_2d || !b_is_2d) {
+    if (contraction_dim.size() > 0) {
+      const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
+      TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
+          "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
+          mat_b.size(dim_b));
+      // Note: only (-1, -2) is currently supported
+      TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
+    } else {
+      TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
+    }
+  }
+  TORCH_CHECK_VALUE(
+    mat_a.size(-1) % 16 == 0,
+    "Expected trailing dimension of mat_a to be divisible by 16 ",
+    "but got mat1 shape: (",
+    mat_a.sizes(),
+    ").");
+  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
+    "Expected mat_b shape to be divisible by 16 ",
+    "but got mat_b shape: (",
+    mat_b.sizes(),
+    ").");
+
+  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
+
+  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
+  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
+  //       routines
+  if (offs.has_value()) {
+    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
+    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
+  }
+
+  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
+  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
+
+  // Conversion of implicitly-defined enums to explicit
+  auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
+  auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
+  auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
+  auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
+
+  // at this point we can start working out what we want to be doing
+  // Try to do as few steps as possible.
+  // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
+  // Do this via a list of defined (name, acceptance, concrete_impl) tuples.
+  ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
+  for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
+    const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
+    bool ok = accept_fn(mat_a.scalar_type(),
+                        scale_recipe_a_enum,
+                        scale_a,
+                        mat_b.scalar_type(),
+                        scale_recipe_b_enum,
+                        scale_b);
+    if (ok) {
+      gemm_impl = scaled_gemm_impl;
+      break;
+    }
+  }
+  TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
+      "No gemm implementation was found");
+
+  switch (gemm_impl) {
+    case ScaledGemmImplementation::ROWWISE_ROWWISE: {
+      const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
+      _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
+      _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
+      return _f8_f8_bf16_rowwise_grouped_mm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          scale_b[0],
+          offs,
+          bias,
+          use_fast_accum,
+          out);
+    }
+    case ScaledGemmImplementation::MXFP8_MXFP8: {
+      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
+      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
+      return _mx8_mx8_bf16_grouped_mm_fbgemm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          swizzle_a_enum[0],
+          scale_b[0],
+          swizzle_b_enum[0],
+          offs.value(),
+          out);
+    }
+    default:
+      TORCH_CHECK_NOT_IMPLEMENTED(false,
+          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
+  }
+}
+
 Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
     const std::optional<at::Tensor>& offs,
     const std::optional<at::Tensor>& bias,
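As the hunk above shows, the v2 entry point picks its kernel by scanning a small table of (name, acceptance function, implementation) tuples and taking the first match, rather than branching ad hoc per format. The following standalone Python sketch mirrors that dispatch pattern in miniature; the entry names, dtypes, and recipe strings are illustrative stand-ins, not the actual PyTorch internals.

```
from typing import Callable, NamedTuple

class KernelEntry(NamedTuple):
    name: str
    accepts: Callable[[str, str], bool]  # (operand dtype, scale recipe) -> ok
    run: Callable[[], str]               # stand-in for the concrete grouped-GEMM kernel

# Mirrors scale_grouped_kernel_dispatch: support is deliberately sparse,
# only explicitly enumerated combinations are allowed.
DISPATCH = [
    KernelEntry("rowwise_rowwise",
                lambda dtype, recipe: dtype == "float8_e4m3fn" and recipe == "rowwise",
                lambda: "f8_f8_bf16_rowwise_grouped_mm"),
    KernelEntry("mxfp8_mxfp8",
                lambda dtype, recipe: dtype == "float8_e4m3fn" and recipe == "mxfp8",
                lambda: "mx8_mx8_bf16_grouped_mm"),
]

def dispatch(dtype: str, recipe: str) -> str:
    # First acceptance function that says "ok" wins; otherwise report no implementation.
    for entry in DISPATCH:
        if entry.accepts(dtype, recipe):
            return entry.run()
    raise NotImplementedError("No gemm implementation was found")

print(dispatch("float8_e4m3fn", "mxfp8"))  # -> mx8_mx8_bf16_grouped_mm
```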

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions
@@ -7183,6 +7183,12 @@
     CUDA: _scaled_grouped_mm_cuda
   tags: needs_exact_strides
 
+- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _scaled_grouped_mm_cuda_v2
+  tags: needs_exact_strides
+
 - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
   variants: function
   dispatch:
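Once this entry is registered, the generated op can be inspected from Python to confirm the binding matches the declaration above; a small sketch, assuming a build of PyTorch that includes this commit:

```
import torch

# Printing the default overload's schema should reproduce the signature declared above.
op = torch.ops.aten._scaled_grouped_mm_v2.default
print(op._schema)
```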

docs/source/nn.functional.rst

Lines changed: 1 addition & 0 deletions
@@ -228,3 +228,4 @@ Low-Precision functions
     ScalingType
     SwizzleType
     scaled_mm
+    scaled_grouped_mm

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 1 addition & 0 deletions
@@ -524,6 +524,7 @@ aten::_scaled_dot_product_flash_attention_for_cpu_backward
 aten::_scaled_dot_product_fused_attention_overrideable
 aten::_scaled_dot_product_fused_attention_overrideable_backward
 aten::_scaled_grouped_mm
+aten::_scaled_grouped_mm_v2
 aten::_scaled_mm
 aten::_scaled_mm.out
 aten::_scaled_mm_v2
