Support A/B Quantization in Blockscale GEMM #3343
base: develop
Conversation
    ck_tile::pk_int4_t,
    ck_tile::half_t,
    ck_tile::bf8_t>{});
    using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
For the quant group size, we need two quant group sizes: one for A and one for B.
Also, this example aims to use 1D quantization for A and 2D quantization for B.
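To make the suggestion concrete, here is a minimal stand-in for what two separate group shapes could look like. The `QuantGroupShape` below is a simplified mock of the real `ck_tile::QuantGroupShape` (which takes a `ck_tile::sequence`), and the alias names `AQuantGroupSize`/`BQuantGroupSize` are illustrative, not the actual ck_tile identifiers:

```cpp
// Simplified stand-in for ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>;
// the real template lives in ck_tile, these names are for illustration only.
template <int M_, int N_, int K_>
struct QuantGroupShape
{
    static constexpr int kM = M_;
    static constexpr int kN = N_;
    static constexpr int kK = K_;
};

// A-side: 1D (per-token) groups -- one scale per 128 activation elements along K.
using AQuantGroupSize = QuantGroupShape<1, 1, 128>;

// B-side: 2D (blockwise) groups -- one scale per 128x128 tile of the weight.
using BQuantGroupSize = QuantGroupShape<1, 128, 128>;
```

With two aliases, the kernel can pick the A or B group shape independently at each use site instead of sharing one `QuantGroupSize`.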
    };
    template <typename PrecType>
    struct GemmConfig_ABQuant_Prefill : public GemmConfigBase
We could directly reuse the existing quant prefill config; there is no need to add another one.
    has_hot_loop_v,
    tail_number_v>,
    ck_tile::GemmABQuantPipelineProblem<typename TypeConfig::ADataType,
                                        typename TypeConfig::QDataType, // For AQ
We should have two quant types here: one for AQuant and one for BQuant.
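As a sketch of that suggestion, the single `QDataType` parameter could be split into separate A-scale and B-scale types. The struct and member names below are hypothetical and only mirror the shape of `ck_tile::GemmABQuantPipelineProblem`, not its real signature:

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical sketch: split the single scale type into AQDataType and
// BQDataType so A and B can carry different quantization scale formats.
template <typename ADataType_,
          typename AQDataType_, // scale type for A (e.g. float)
          typename BDataType_,
          typename BQDataType_> // scale type for B (e.g. float)
struct GemmABQuantPipelineProblemSketch
{
    using ADataType  = ADataType_;
    using AQDataType = AQDataType_;
    using BDataType  = BDataType_;
    using BQDataType = BQDataType_;
};

// Example instantiation: int8 data on both sides, float scales on both sides.
using DemoProblem = GemmABQuantPipelineProblemSketch<std::int8_t, float, std::int8_t, float>;
```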
    GemmConfig::PreshuffleB)
    {
        throw std::runtime_error(
            "Preshuffling weight matrix is not supported for AQuant or RowColQuant");
Should we also add ABQuant here?
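A minimal sketch of the extended guard. The `QuantMode` enum and the function names here are illustrative stand-ins (the actual code branches on `GemmConfig::PreshuffleB` and the quant mode template parameters); the point is only that `ABQuant` joins the rejected set:

```cpp
#include <stdexcept>

// Illustrative stand-in for the kernel's quant-mode dispatch.
enum class QuantMode { None, AQuant, BQuant, ABQuant, RowColQuant };

// Preshuffling applies to the weight (B) matrix, so any mode that
// quantizes A's scales per token -- AQuant, RowColQuant, and now
// ABQuant -- is incompatible with PreshuffleB.
inline void check_preshuffle_b(bool preshuffle_b, QuantMode mode)
{
    if(preshuffle_b &&
       (mode == QuantMode::AQuant || mode == QuantMode::RowColQuant ||
        mode == QuantMode::ABQuant))
    {
        throw std::runtime_error(
            "Preshuffling weight matrix is not supported for AQuant, "
            "ABQuant or RowColQuant");
    }
}

// Helper used below to observe whether the guard fires.
inline bool preshuffle_rejected(bool preshuffle_b, QuantMode mode)
{
    try
    {
        check_preshuffle_b(preshuffle_b, mode);
        return false;
    }
    catch(const std::runtime_error&)
    {
        return true;
    }
}
```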
    static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
    static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
    static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
Since we have two different quant group sizes, we should have one AQKPerBlock and one BQKPerBlock. When AQKPerBlock matches BQKPerBlock, we can merge the loops.
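A small sketch of that arithmetic, with illustrative names and example group sizes (the real values come from the kernel's tile configuration, not these constants):

```cpp
// Hypothetical per-block scale counts when A and B have independent
// quant group sizes along K. Values are example choices only.
constexpr int KPerBlock    = 128;
constexpr int AQuantGroupK = 128; // 1D per-token groups of 128 along K
constexpr int BQuantGroupK = 128; // K extent of B's 128x128 blockwise groups

// One scale-advance count per operand instead of a shared AQPerBlock.
constexpr int AQKPerBlock = KPerBlock / AQuantGroupK;
constexpr int BQKPerBlock = KPerBlock / BQuantGroupK;

// When the two counts agree, a single fused loop can apply both the
// A-scale and the B-scale per iteration; otherwise the scale loops
// must advance at different rates.
constexpr bool CanMergeScaleLoops = (AQKPerBlock == BQKPerBlock);
```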
    if constexpr(Traits::TransposeC) // transposed C
    {
        index_t reg_offset =
            Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
We will not provide PreshuffleQuant for the A matrix, since it is not the weight.
    struct AQPicker
    {
        CK_TILE_DEVICE
        AQPicker(AQBlockTensor& aq_block_tensor_) : aq_block_tensor(aq_block_tensor_)
For this part, we could also factor out a common function shared with include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp.
    // Create DRAM tile window for AQ
    template <typename AQDramBlockWindowTmp>
    CK_TILE_DEVICE constexpr auto
    GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
This is the same as in the AQuant pipeline, so we could merge it.
    // Create DRAM tile window for BQ
    template <typename BQDramBlockWindowTmp>
    CK_TILE_DEVICE constexpr auto
    GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
This is the same as in the BQuant pipeline, so we could merge it.
    namespace ck_tile {
    struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
We could share this policy with the AQuant and BQuant pipelines.
Proposed changes
This commit introduces support for input (A) and weight (B) quantization within the Blockscale GEMM kernel pipeline.
Motivation:
This feature is essential for high-performance inference of large language models (LLMs), as it allows us to use 8-bit or 4-bit data types for both activation and weight tensors. By quantizing both A and B, memory traffic and footprint are reduced for activations and weights alike.
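For reference, the mathematical semantics of A/B blockscale quantization can be sketched as a plain scalar loop (this is a simplified model, not the kernel: it assumes a single shared group size along K, with one scale per A-row K-group and one scale per B column per K-group; the function and parameter names are illustrative):

```cpp
#include <cstdint>
#include <vector>

// Reference A/B blockscale GEMM semantics:
//   C[m][n] = sum over K-groups g of
//             a_scale[m][g] * b_scale[g][n] * sum_{k in g} A_q[m][k] * B_q[k][n]
// a:       M x K   quantized activations (row-major)
// a_scale: M x (K/group_k) per-row, per-K-group scales
// b:       K x N   quantized weights (row-major)
// b_scale: (K/group_k) x N per-K-group, per-column scales
inline std::vector<float> abquant_gemm_ref(const std::vector<std::int8_t>& a,
                                           const std::vector<float>& a_scale,
                                           const std::vector<std::int8_t>& b,
                                           const std::vector<float>& b_scale,
                                           int M, int N, int K, int group_k)
{
    const int kg = K / group_k; // number of quant groups along K
    std::vector<float> c(static_cast<std::size_t>(M) * N, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int g = 0; g < kg; ++g)
            {
                // Accumulate the raw int8 products within one K-group...
                float acc = 0.0f;
                for(int k = g * group_k; k < (g + 1) * group_k; ++k)
                    acc += static_cast<float>(a[m * K + k]) *
                           static_cast<float>(b[k * N + n]);
                // ...then apply both the A-scale and the B-scale once per group.
                c[m * N + n] += acc * a_scale[m * kg + g] * b_scale[g * N + n];
            }
    return c;
}
```

The kernel under review implements the same contraction with tiled pipelines and packed low-bit types; this scalar version only pins down what the scales mean.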
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have run `clang-format` on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion explaining why you chose this solution and what alternatives you considered.