@@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
     const Tensor& mat_a,
     const Tensor& mat_b,
     const Tensor& scale_a,
+    const SwizzleType& swizzle_a,
     const Tensor& scale_b,
+    const SwizzleType& swizzle_b,
     const std::optional<at::Tensor>& offs,
     Tensor& out) {
   const bool a_is_2d = mat_a.dim() == 2;
@@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
   TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
   TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
   TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
+  // MXFP8 expects float8_e8m0fnu scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
+      "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
+#ifdef USE_ROCM
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
+      "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
+#else
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
+      "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
+#endif

 #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
   fbgemm_gpu::mx8mx8bf16_grouped_mm(
@@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
     const std::optional<Tensor>& bias,
     bool use_fast_accum,
     Tensor& out) {
+  // FP8 per-tensor and per-row scaling expect fp32 scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
+      "For grouped FP8 rowwise, both scales must be float32 tensors");
 #ifndef USE_ROCM
   return _f8_f8_bf16_rowwise_grouped_mm_cuda(
       mat_a,
@@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda(
 #endif

   if (is_mx8mx8bf16) {
+    // Note: Passing implied SwizzleType here, correctness of scale previously checked
+    // in `check_scale` call
     return _mx8_mx8_bf16_grouped_mm_fbgemm(
         mat_a,
         mat_b,
         scale_a,
+        SwizzleType::SWIZZLE_32_4_4,
         scale_b,
+        SwizzleType::SWIZZLE_32_4_4,
         offs.value(),
         out);
   }
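The v1 path above can hardcode the CUDA swizzle layout because `check_scale` has already validated the blocked scale format. As a quick illustration of the platform split that the new checks in `_mx8_mx8_bf16_grouped_mm_fbgemm` enforce, here is a minimal standalone C++ sketch; it is not part of the patch, and both the local `SwizzleType` enum and the `expected_mxfp8_swizzle` helper are illustrative stand-ins rather than the ATen definitions.

// Standalone sketch only: SwizzleType here is a local stand-in, not the ATen enum.
#include <cassert>

enum class SwizzleType { NO_SWIZZLE, SWIZZLE_32_4_4 };

// Hypothetical helper mirroring the #ifdef USE_ROCM split added in this patch:
// the CUDA (FBGEMM) MXFP8 grouped kernel takes 32x4x4-swizzled e8m0 scales,
// while the ROCm path takes unswizzled scales.
constexpr SwizzleType expected_mxfp8_swizzle(bool is_rocm) {
  return is_rocm ? SwizzleType::NO_SWIZZLE : SwizzleType::SWIZZLE_32_4_4;
}

int main() {
  assert(expected_mxfp8_swizzle(/*is_rocm=*/false) == SwizzleType::SWIZZLE_32_4_4);
  assert(expected_mxfp8_swizzle(/*is_rocm=*/true) == SwizzleType::NO_SWIZZLE);
  return 0;
}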
@@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda(
   out);
 }

+namespace {
+
+std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
+    { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
+    { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+
+} // anonymous namespace
+
+Tensor
+_scaled_grouped_mm_cuda_v2(
+    const Tensor& mat_a, const Tensor& mat_b,
+    ArrayRef<Tensor> scale_a,
+    IntArrayRef scale_recipe_a,
+    IntArrayRef swizzle_a,
+    ArrayRef<Tensor> scale_b,
+    IntArrayRef scale_recipe_b,
+    IntArrayRef swizzle_b,
+    const std::optional<Tensor>& offs,
+    const std::optional<Tensor>& bias,
+    const std::optional<c10::ScalarType> out_dtype,
+    IntArrayRef contraction_dim,
+    bool use_fast_accum) {
+  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
+  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
+
+  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
+  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
+  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
+  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
+  const bool a_is_2d = mat_a.dim() == 2;
+  const bool b_is_2d = mat_b.dim() == 2;
+
+  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
+  if (!a_is_2d || !b_is_2d) {
+    if (contraction_dim.size() > 0) {
+      const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
+      TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
+          "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
+          mat_b.size(dim_b));
+      // Note: only (-1, -2) is currently supported
+      TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
+    } else {
+      TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
+    }
+  }
+  TORCH_CHECK_VALUE(
+      mat_a.size(-1) % 16 == 0,
+      "Expected trailing dimension of mat_a to be divisible by 16 ",
+      "but got mat1 shape: (",
+      mat_a.sizes(),
+      ").");
+  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
+      "Expected mat_b shape to be divisible by 16 ",
+      "but got mat_b shape: (",
+      mat_b.sizes(),
+      ").");
+
+  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
+
+  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
+  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
+  //       routines
+  if (offs.has_value()) {
+    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
+    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
+  }
+
+  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
+  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
+
+  // Conversion of implicitly-defined enums to explicit
+  auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
+  auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
+  auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
+  auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
+
+  // at this point we can start working out what we want to be doing
+  // Try to do as few steps as possible.
+  // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
+  // Do this via a list of defined (name, acceptance, concrete_impl) tuples.
+  ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
+  for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
+    const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
+    bool ok = accept_fn(mat_a.scalar_type(),
+                        scale_recipe_a_enum,
+                        scale_a,
+                        mat_b.scalar_type(),
+                        scale_recipe_b_enum,
+                        scale_b);
+    if (ok) {
+      gemm_impl = scaled_gemm_impl;
+      break;
+    }
+  }
+  TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
+      "No gemm implementation was found");
+
+  switch (gemm_impl) {
+    case ScaledGemmImplementation::ROWWISE_ROWWISE: {
+      const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
+      _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */, scale_multiplier);
+      _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */, scale_multiplier);
+      return _f8_f8_bf16_rowwise_grouped_mm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          scale_b[0],
+          offs,
+          bias,
+          use_fast_accum,
+          out);
+    }
+    case ScaledGemmImplementation::MXFP8_MXFP8: {
+      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
+      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
+      return _mx8_mx8_bf16_grouped_mm_fbgemm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          swizzle_a_enum[0],
+          scale_b[0],
+          swizzle_b_enum[0],
+          offs.value(),
+          out);
+    }
+    default:
+      TORCH_CHECK_NOT_IMPLEMENTED(false,
+          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
+  }
+}
+
 Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
     const std::optional<at::Tensor>& offs,
     const std::optional<at::Tensor>& bias,
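For readers following the new `_scaled_grouped_mm_cuda_v2` entry point: kernel selection is table-driven, walking (name, acceptance function, implementation) tuples until one accepts, exactly as `scale_grouped_kernel_dispatch` does above. Below is a self-contained sketch of that pattern; the `Impl` and `Recipe` enums and the acceptance lambdas are simplified stand-ins, not the real ATen types or `check_*_recipe` functions.

// Simplified, standalone illustration of the dispatch-table pattern used by
// _scaled_grouped_mm_cuda_v2. All names here are local stand-ins.
#include <array>
#include <cstdio>
#include <functional>
#include <string>
#include <tuple>

enum class Impl { NONE, ROWWISE_ROWWISE, MXFP8_MXFP8 };
enum class Recipe { ROWWISE, BLOCKWISE_1X32 };

using AcceptFn = std::function<bool(Recipe, Recipe)>;

// Dispatch table: the first entry whose acceptance function returns true wins.
const std::array<std::tuple<std::string, AcceptFn, Impl>, 2> kDispatch = {{
    {"rowwise_rowwise",
     [](Recipe a, Recipe b) { return a == Recipe::ROWWISE && b == Recipe::ROWWISE; },
     Impl::ROWWISE_ROWWISE},
    {"mxfp8_mxfp8",
     [](Recipe a, Recipe b) { return a == Recipe::BLOCKWISE_1X32 && b == Recipe::BLOCKWISE_1X32; },
     Impl::MXFP8_MXFP8},
}};

// Mirrors the selection loop: fall through to NONE, which the caller turns
// into a "No gemm implementation was found" error.
Impl select_impl(Recipe recipe_a, Recipe recipe_b) {
  for (const auto& [name, accept, impl] : kDispatch) {
    if (accept(recipe_a, recipe_b)) {
      std::printf("matched dispatch entry: %s\n", name.c_str());
      return impl;
    }
  }
  return Impl::NONE;
}

int main() {
  select_impl(Recipe::ROWWISE, Recipe::ROWWISE);                // -> ROWWISE_ROWWISE
  select_impl(Recipe::BLOCKWISE_1X32, Recipe::BLOCKWISE_1X32);  // -> MXFP8_MXFP8
  return 0;
}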