
Conversation

@diptorupd
Collaborator

This PR does an initial refactoring to improve the documentation and code structure of the MMA ops in prefill.cuh.


Copilot AI left a comment


Pull Request Overview

This PR performs an initial refactoring to improve the documentation and code structure of MMA (Matrix Multiply-Accumulate) operations used in attention kernels. The refactoring extracts MMA layout definitions into dedicated header files and updates the prefill kernel to use the new abstractions.

Key Changes:

  • Introduces platform-agnostic MMA layout type definitions with backend-specific implementations for HIP and CUDA
  • Replaces hard-coded constants with named constants from layout structs (e.g., HALF_ELEMS_PER_THREAD → FRAG_SIZE, VECTOR_BIT_WIDTH → FRAG_BITWIDTH); a short sketch of this usage appears after this list
  • Adds comprehensive documentation for fragment layouts and thread-to-element ownership patterns
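
As a rough sketch of what the second bullet describes (illustrative only: FRAG_SIZE and FRAG_BITWIDTH are named in this PR, but the concrete values, the half specialization shown, and the surrounding declarations are assumptions rather than code from the changed files):

#include <cstdint>
#include <cuda_fp16.h>

enum class WmmaOp { M16N8K16, M16N16K16 };

// Primary template left undefined; each backend header provides the specializations.
template <WmmaOp Op, typename T>
struct wmma_a_layout;

// Hypothetical CUDA specialization: the A operand of an m16n8k16 mma.sync covers a
// 16x16 half tile per warp, i.e. 8 half elements (128 bits) per thread.
template <>
struct wmma_a_layout<WmmaOp::M16N8K16, half> {
  static constexpr uint32_t FRAG_SIZE = 8;        // half elements owned by each thread
  static constexpr uint32_t FRAG_BITWIDTH = 128;  // bits moved per fragment load
};

// Kernel code can then replace literal 8 / 128 with the named constants:
//   using a_frag = wmma_a_layout<WmmaOp::M16N8K16, half>;
//   ... a_frag::FRAG_SIZE ... a_frag::FRAG_BITWIDTH ...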

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 13 comments.

  • libflashinfer/include/gpu_iface/mma_layout.hpp: Defines base templates and enums for the MMA layout abstractions with a platform-agnostic interface
  • libflashinfer/include/gpu_iface/backend/hip/mma_layout_hip.h: Implements HIP/CDNA3-specific MFMA layout specializations for M16N16K16 operations
  • libflashinfer/include/gpu_iface/backend/cuda/mma_layout_cuda.h: Implements CUDA Tensor Core layout specializations for M16N8K16 and M16N16K16 operations
  • libflashinfer/include/flashinfer/attention/generic/prefill.cuh: Updates the attention kernel to use the new MMA layout constants and adds the mma_layout.hpp include

A sketch of how this split between the platform-agnostic header and the backend headers might be wired up follows the list.
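
// Illustrative sketch only: the macro check and include paths below are assumptions
// for this write-up, not the repository's actual layout.
// mma_layout.hpp
#pragma once

namespace gpu_iface {

enum class WmmaOp { M16N8K16, M16N16K16 };

// Platform-agnostic primary templates; backends specialize them.
template <WmmaOp Op, typename T> struct wmma_a_layout;
template <WmmaOp Op, typename T> struct wmma_b_layout;
template <WmmaOp Op, typename T> struct wmma_c_layout;

}  // namespace gpu_iface

// Backend headers fill in the specializations (MFMA layouts for CDNA3,
// Tensor Core layouts for NVIDIA GPUs).
#if defined(__HIP_PLATFORM_AMD__)
#include "gpu_iface/backend/hip/mma_layout_hip.h"
#else
#include "gpu_iface/backend/cuda/mma_layout_cuda.h"
#endif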


};

template <>
struct wmma_c_layout<WmmaOp::M16N8K16, __nv_bfloat16> {};

Copilot AI Nov 14, 2025


Empty template specialization for wmma_c_layout<WmmaOp::M16N8K16, __nv_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- struct wmma_c_layout<WmmaOp::M16N8K16, __nv_bfloat16> {};
+ struct wmma_c_layout<WmmaOp::M16N8K16, __nv_bfloat16> {
+   // Placeholder: bfloat16 support for wmma_c_layout<M16N8K16> is not yet implemented.
+   static_assert(sizeof(__nv_bfloat16) == 0,
+                 "wmma_c_layout<WmmaOp::M16N8K16, __nv_bfloat16> is not implemented yet. bfloat16 support is planned for future releases.");
+ };
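
A caveat on the static_assert pattern suggested here: in a full explicit specialization the condition sizeof(__nv_bfloat16) == 0 does not depend on any template parameter, so the assertion fires as soon as the header is compiled rather than when the placeholder is used. A minimal sketch of a deferred-error alternative, using a hypothetical dependent_false_v helper that is not part of this PR:

enum class WmmaOp { M16N8K16, M16N16K16 };

// The helper is dependent on T, so the assertion below is only evaluated when the
// primary template is actually instantiated for an unsupported type.
template <typename T>
inline constexpr bool dependent_false_v = false;

template <WmmaOp Op, typename T>
struct wmma_c_layout {
  static_assert(dependent_false_v<T>,
                "wmma_c_layout is not specialized for this (WmmaOp, T) combination.");
};

// Supported combinations are then provided as explicit specializations, e.g.
//   template <> struct wmma_c_layout<WmmaOp::M16N8K16, float> { /* ... */ };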

Comment on lines +270 to +272
struct wmma_a_layout<WmmaOp::M16N16K16, __nv_bfloat16> {};

/*!

Copilot AI Nov 14, 2025


Empty template specialization for wmma_a_layout<WmmaOp::M16N16K16, __nv_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- struct wmma_a_layout<WmmaOp::M16N16K16, __nv_bfloat16> {};
+ struct wmma_a_layout<WmmaOp::M16N16K16, __nv_bfloat16> {
+   // bfloat16 support for M16N16K16 is not yet implemented.
+   static_assert(sizeof(__nv_bfloat16) == 0,
+                 "wmma_a_layout<WmmaOp::M16N16K16, __nv_bfloat16> is not implemented: bfloat16 support is not yet available.");
+ };

  /*!

};

template <>
struct wmma_b_layout<WmmaOp::M16N16K16, __nv_bfloat16> {};

Copilot AI Nov 14, 2025


Empty template specialization for wmma_b_layout<WmmaOp::M16N16K16, __nv_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- struct wmma_b_layout<WmmaOp::M16N16K16, __nv_bfloat16> {};
+ struct wmma_b_layout<WmmaOp::M16N16K16, __nv_bfloat16> {
+   // bfloat16 support for wmma_b_layout<WmmaOp::M16N16K16, __nv_bfloat16> is not yet implemented.
+   // This is a placeholder specialization. If you see this error, bfloat16 support must be added.
+   static_assert(sizeof(__nv_bfloat16) == 0,
+                 "wmma_b_layout<WmmaOp::M16N16K16, __nv_bfloat16> is not implemented: bfloat16 support is not available yet.");
+ };

Comment on lines +106 to +107
template <>
struct wmma_a_layout<WmmaOp::M16N16K16, hip_bfloat16> {};

Copilot AI Nov 14, 2025


Empty template specialization for wmma_a_layout<WmmaOp::M16N16K16, hip_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- template <>
- struct wmma_a_layout<WmmaOp::M16N16K16, hip_bfloat16> {};
+ // Placeholder: bfloat16 support for wmma_a_layout is not yet implemented.
+ template <>
+ struct wmma_a_layout<WmmaOp::M16N16K16, hip_bfloat16> {
+   static_assert(sizeof(hip_bfloat16) == 0,
+                 "wmma_a_layout<WmmaOp::M16N16K16, hip_bfloat16> is not implemented: bfloat16 support is not yet available for this specialization.");
+ };

Comment on lines +177 to +178
template <>
struct wmma_b_layout<WmmaOp::M16N16K16, hip_bfloat16> {};

Copilot AI Nov 14, 2025


Empty template specialization for wmma_b_layout<WmmaOp::M16N16K16, hip_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- template <>
- struct wmma_b_layout<WmmaOp::M16N16K16, hip_bfloat16> {};
+ // Placeholder: bfloat16 support for wmma_b_layout is not yet implemented.
+ template <>
+ struct wmma_b_layout<WmmaOp::M16N16K16, hip_bfloat16> {
+   static_assert(sizeof(hip_bfloat16) == 0,
+                 "wmma_b_layout<WmmaOp::M16N16K16, hip_bfloat16> is not implemented: bfloat16 support is not yet available.");
+ };

Comment on lines +349 to +350
template <>
struct wmma_c_layout<WmmaOp::M16N16K16, __nv_bfloat16> {};

Copilot AI Nov 14, 2025


Empty template specialization for wmma_c_layout<WmmaOp::M16N16K16, __nv_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- template <>
- struct wmma_c_layout<WmmaOp::M16N16K16, __nv_bfloat16> {};
+ // Placeholder: bfloat16 support for wmma_c_layout<WmmaOp::M16N16K16, __nv_bfloat16> is not yet implemented.
+ template <>
+ struct wmma_c_layout<WmmaOp::M16N16K16, __nv_bfloat16> {
+   static_assert(sizeof(__nv_bfloat16) == 0,
+                 "wmma_c_layout<WmmaOp::M16N16K16, __nv_bfloat16> is not implemented yet. bfloat16 support is planned for a future release.");
+ };

using IdType = IdType_;
using AttentionVariant = AttentionVariant_;

static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE;

Copilot AI Nov 14, 2025


The newly added line uses WARP_SIZE without qualification, but this constant may not be defined in the current scope. The old code used platform-specific values (64 for HIP, 32 for CUDA). Consider using a qualified constant like gpu_iface::WARP_SIZE or defining WARP_SIZE within KernelTraits, or keeping the platform-specific approach that was removed.

Suggested change
- static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE;
+ static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * gpu_iface::kWarpSize;
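
If no such shared constant exists yet, a minimal sketch of how it could be defined (the gpu_iface namespace and kWarpSize name follow the suggestion above; the header placement and macro check are assumptions, and the 64/32 values are the platform-specific ones mentioned in the comment):

#include <cstdint>

namespace gpu_iface {

#if defined(__HIP_PLATFORM_AMD__)
inline constexpr uint32_t kWarpSize = 64;  // CDNA wavefront size
#else
inline constexpr uint32_t kWarpSize = 32;  // NVIDIA warp size
#endif

}  // namespace gpu_iface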

@@ -0,0 +1,113 @@
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.

Copilot AI Nov 14, 2025


The title "Intial refactoring for mma layout types." contains a spelling error. "Intial" should be "Initial".

M16N16K16, // Logical 16×16 output tile, k-dim=16 (may be composite)
};

/*! * \brief Generic WMMA layout providing generic op-level traits

Copilot AI Nov 14, 2025


The comment on line 41 has a formatting issue. There's an extra space and asterisk at the start: /*! * \brief should be /*! \brief (remove the extra space and asterisk after the opening /*!).

Suggested change
- /*! * \brief Generic WMMA layout providing generic op-level traits
+ /*! \brief Generic WMMA layout providing generic op-level traits

};

template <>
struct wmma_b_layout<WmmaOp::M16N8K16, __nv_bfloat16> {};

Copilot AI Nov 14, 2025


Empty template specialization for wmma_b_layout<WmmaOp::M16N8K16, __nv_bfloat16> is declared but not implemented. If bfloat16 support is not yet implemented, consider adding a comment explaining this is a placeholder for future implementation, or use static_assert to provide a clear error message if this specialization is used.

Suggested change
- struct wmma_b_layout<WmmaOp::M16N8K16, __nv_bfloat16> {};
+ struct wmma_b_layout<WmmaOp::M16N8K16, __nv_bfloat16> {
+   // Placeholder: bfloat16 support for wmma_b_layout<M16N8K16> is not yet implemented.
+   static_assert(
+       sizeof(__nv_bfloat16) == 0,
+       "wmma_b_layout<WmmaOp::M16N8K16, __nv_bfloat16> is not implemented: bfloat16 support is not yet available for this operation.");
+ };
