@@ -144,12 +144,6 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
                 }
             }

-            // FIXME: Disable split-k for now.
-            if (options.mClusterDimZ != 1)
-            {
-                continue;
-            }
-
             if (options.mFusedAct)
             {
                 if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType))
@@ -158,14 +152,29 @@
                 }
             }

+            // FIXME: Disables a few static scheduler kernels (schedS) that appear to have issues;
[Review comment, Collaborator] Do we have an nvbug to trace it? Add the nvbug ID if one exists.
[Reply, Collaborator Author] I don't have the corresponding nvbug; this issue came up during this PR.

+            // found after commit e257cb3533; still under investigation. Offending kernels:
+            // bmm_E2m1_E2m1E2m1_Fp32_t128x64x256_s6_et128x64_m128x64x64_cga1x1x1_16dp256b_TN_transOut_schedS_bN_ldgsts_tmaOpt_clmp_swiGlu_dynBatch_sm100a
+            // bmm_MxE4m3_MxE2m1MxE4m3_Fp32_t128x64x256_s3_et128x64_m128x64x32_cga1x1x1_16dp256b_TN_transOut_schedS_biasM_bN_ldgsts_tmaOpt_clmp_swiGlu_dynBatch_sm100f
+            if (options.mTileScheduler == TileScheduler::Static && options.mUseTmaOobOpt == true
+                && options.mTileN == 64)
+            {
+                continue;
+            }
+
             if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM)
             {
                 mPassingConfigIndices.push_back(i);
             }
         }
     }

-    TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
+    TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(),
+        "No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, mUseDeepSeekFp8: %d, "
+        "mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d",
+        tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(),
+        tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput,
+        mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize);
 }

 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
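Note on the new exclusion above: read in isolation, the added check is a single predicate over a candidate config. The sketch below merely restates it as a standalone helper to spell out the three criteria (static tile scheduler, TMA out-of-bounds optimization, 64-wide N tile); it is not part of the change, and it assumes the BatchedGemmOptions and TileScheduler types from the surrounding batched GEMM headers are visible.

// Sketch only: a config is skipped when it combines the static tile scheduler with the
// TMA out-of-bounds optimization and a 64-wide N tile, the combination observed to
// misbehave after commit e257cb3533.
static bool isSuspectStaticSchedulerKernel(BatchedGemmOptions const& options)
{
    return options.mTileScheduler == TileScheduler::Static && options.mUseTmaOobOpt && options.mTileN == 64;
}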
@@ -277,7 +286,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     auto envVarVal = std::getenv("TLLM_BATCHED_GEMM_PRINT_NAME");
     if (envVarVal && std::atoi(envVarVal) == 1)
     {
-        TLLM_LOG_INFO("numBatches %d Gemm %d %d %d Kernel %s\n", numBatches, m, n, k, config.mFunctionName);
+        TLLM_LOG_INFO("NumBatches %d, MaxNumCtasInBatchDim %d, ShapeMNK %d %d %d, Kernel %s", numBatches,
+            maxNumCtasInBatchDim, m, n, k, config.mFunctionName);
     }
     // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
     bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
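The printout in this hunk is gated on the TLLM_BATCHED_GEMM_PRINT_NAME environment variable. A hypothetical way to enable it from a C++ test driver is sketched below (using POSIX setenv; the variable name comes from the code above, everything else is illustrative):

#include <cstdlib>

int main()
{
    // Ask the batched GEMM runner to log the selected kernel for every run() call.
    // Must be set before the runner is invoked; "1" enables the printout.
    setenv("TLLM_BATCHED_GEMM_PRINT_NAME", "1", /*overwrite=*/1);

    // ... construct TrtllmGenBatchedGemmRunner and call run() as usual ...
    return 0;
}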
@@ -76,11 +76,12 @@ class TrtllmGenBatchedGemmRunner
         int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
         void* workspace, CUstream stream, int device, int32_t configIndex);

-    // NVFP4 per-block scaling GEMM
+    // Block-scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
         void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
         int32_t configIndex);

+    // Block-scaling GEMM with SwiGLU activation
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
         void const* b, void const* sfB, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
         float const* clampLimit, void* c, void* outSfC, void* workspace, CUstream stream, int device,
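For orientation, a hypothetical call site for the first overload above is sketched here. The wrapper function name, the device pointers, the workspace, the stream, the config index, and the shapes are all placeholders assumed to be prepared by the caller; only the run() signature itself comes from the header.

// Sketch: invoke the block-scaling overload on three batches of 128, 64, and 256 tokens.
void runBlockScaledBmm(TrtllmGenBatchedGemmRunner& runner, void const* dA, void const* dSfA, void const* dB,
    void const* dSfB, void* dC, void* dOutSfC, void* workspace, CUstream stream, int32_t configIndex)
{
    std::vector<int32_t> const batchedTokens{128, 64, 256};
    // a/sfA and b/sfB are the block-scaled inputs and their scale factors; c/outSfC receive
    // the output and its scale factors. configIndex would come from a prior config-selection step.
    runner.run(/*m=*/256, /*n=*/7168, /*k=*/2048, batchedTokens, dA, dSfA, dB, dSfB, dC, dOutSfC, workspace, stream,
        /*device=*/0, configIndex);
}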
@@ -530,6 +530,9 @@ class BatchedGemmInterface
         return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner);
     }

+    // Creates GemmOptions from kernel and data.
+    BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
+
     // Returns the number of CTAs of the current kernel.
     int32_t getNumCtas(
         BatchedGemmOptions const& options, std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const
@@ -541,9 +544,6 @@
     // Returns true if the configuration of the cubin can be executed for the given params.
     bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const;

-    // Creates GemmOptions from kernel and data.
-    BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
-
 private:
     // Aligns the pointer to the alignment
     template <typename Dtype>
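Taken together with isValidConfig and getNumCtas, the relocated getOptionsFromConfigAndData supports a select-then-launch pattern. A rough sketch follows; the helper name, the configs array, gemmData, and maxNumCtasInBatchDim are placeholders and not part of this change, and the BatchedGemmInterface header is assumed to be included.

// Sketch: find the first config that is valid for the given data and report its CTA count.
int32_t pickFirstValidConfigCtas(BatchedGemmInterface const& bmm, BatchedGemmConfig const* configs, size_t numConfigs,
    BatchedGemmData const& gemmData, std::optional<int32_t> maxNumCtasInBatchDim)
{
    for (size_t i = 0; i < numConfigs; ++i)
    {
        if (!bmm.isValidConfig(configs[i], gemmData))
        {
            continue;
        }
        // Reconstruct the effective options for this kernel/data combination, then query the launch size.
        auto const options = bmm.getOptionsFromConfigAndData(configs[i], gemmData);
        return bmm.getNumCtas(options, maxNumCtasInBatchDim);
    }
    return -1; // no valid config found
}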