Skip to content
Merged
Prev Previous commit
review comment
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
  • Loading branch information
Tabrizian committed Sep 26, 2025
commit 66044e8b0f10fc213215b8f782d421cfdb8b88a0
9 changes: 4 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,7 @@ class WindowBlockManager
return mIsSWA;
}

[[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey);
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

//! \brief Unpin blocks by starting from a block id and walking prev pointers.
void unpinBlocksById(KVCacheBlock::IdType blockId);
Expand Down Expand Up @@ -1194,7 +1193,7 @@ class BlockManager
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
}

[[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey, SizeType32 windowSize)
{
return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByBlockKey(blockKey);
Expand Down Expand Up @@ -1491,7 +1490,7 @@ class BaseKVCacheManager

[[nodiscard]] virtual CacheType getCacheType() const = 0;

[[nodiscard]] virtual std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
[[nodiscard]] virtual std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey, SizeType32 windowSize)
= 0;

Expand Down Expand Up @@ -1794,7 +1793,7 @@ class KVCacheManager : public BaseKVCacheManager
mBlockManager.flushIterationEvents();
}

std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey, SizeType32 windowSize) override
{
return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BlockRange
{
auto const windowSize = firstWindowSize(cacheManager);
// Find the last block in the reuse tree for the provided full sequence of block keys
auto lastBlock = *cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
auto lastBlock = cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
// TODO: handle the case where the last block is not found
TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
int32_t const numBlocksToCollect = indexFromEnd + 1;
Expand Down
5 changes: 2 additions & 3 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,8 +1096,7 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block)
return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr;
}

std::optional<std::shared_ptr<KVCacheBlock>> WindowBlockManager::findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey)
std::shared_ptr<KVCacheBlock> WindowBlockManager::findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey)
{
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
auto blockedUniqueTokens
Expand All @@ -1118,7 +1117,7 @@ std::optional<std::shared_ptr<KVCacheBlock>> WindowBlockManager::findBlocksInReu

if (matchingBlock == nullptr)
{
return std::nullopt;
return nullptr;
}

searchRoot = std::move(matchingBlock);
Expand Down
5 changes: 2 additions & 3 deletions cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,9 +639,8 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByBlockKeysTest)
inputTokens->pop_back();
BlockKey fullKey{*inputTokens};
auto const foundFull = kvCacheManager.findBlocksInReuseTreeByBlockKey(fullKey, maxAttentionWindow);
ASSERT_TRUE(foundFull.has_value());
ASSERT_NE(foundFull.value(), nullptr);
auto const& lastBlock = foundFull.value();
ASSERT_NE(foundFull, nullptr);
auto const& lastBlock = foundFull;

// Check the chain back to previous blocks
auto const prev2 = lastBlock->getPrevBlock();
Expand Down
4 changes: 3 additions & 1 deletion jenkins/Build.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,19 @@ def BUILD_CONFIGS = [
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON",
(TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
],
(CONFIG_LINUX_AARCH64_PYBIND): [
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl",
(TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
],
(CONFIG_LINUX_AARCH64_LLVM) : [
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_CUDA_HOST_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD",
(TARNAME) : "llvm-TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "6", // TODO: Remove after fix the build OOM issue on SBSA
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
],
]

Expand Down