Skip to content

Commit 1e38791

Browse files
rosenrodtdominicshanshan
authored andcommitted
fix: MoE autotune fallback failed to query default heuristic (NVIDIA#5520)
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent a820cee commit 1e38791

File tree

1 file changed

+10
-2
lines changed
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe

1 file changed

+10
-2
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,17 @@ std::vector<int64_t> Runner::getValidConfigIndices(
469469
int64_t Runner::getDefaultValidConfigIndex(
470470
int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numLocalExperts, int32_t numTokens) const
471471
{
472-
auto const validIndices = getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
473472

474-
return validIndices[0];
473+
int32_t indexGemm1
474+
= mPermuteGemm1.getDefaultValidConfigIndex(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
475+
int32_t indexGemm2
476+
= mGemm2.getDefaultValidConfigIndex(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
477+
478+
auto it = std::find_if(mPassingConfigs.begin(), mPassingConfigs.end(),
479+
[indexGemm1, indexGemm2](MoEConfig cfg)
480+
{ return (cfg.gemm1Config == indexGemm1 && cfg.gemm2Config == indexGemm2); });
481+
TLLM_CHECK_WITH_INFO(it != mPassingConfigs.end(), "No compatible configs found for the block scale MoE runner.");
482+
return std::distance(mPassingConfigs.begin(), it);
475483
}
476484

477485
void Runner::run(

0 commit comments

Comments
 (0)