Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/nanobind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ set(SRCS
executor/executorConfig.cpp
executor/request.cpp
runtime/bindings.cpp
testing/modelSpecBinding.cpp
runtime/hostfunc.cpp
runtime/moeBindings.cpp
testing/modelSpecBinding.cpp
userbuffers/bindings.cpp
thop/bindings.cpp
../runtime/ipcNvlsMemory.cu
Expand Down
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
.def("name", [](AssignReqSeqSlots const&) { return AssignReqSeqSlots::name; });

// Python binding for the AllocateKvCache algorithm object, registered under AllocateKvCache::name.
// The constructor and __call__ are bound with nb::call_guard<nb::gil_scoped_release>, so the
// wrapped C++ call runs with the Python GIL released. The "name" accessor is bound without a guard.
nb::class_<AllocateKvCache>(m, AllocateKvCache::name)
    .def(nb::init<>(), nb::call_guard<nb::gil_scoped_release>())
    // cross_kv_cache_manager is optional on the Python side and defaults to std::nullopt.
    .def("__call__", &AllocateKvCache::operator(), nb::arg("kv_cache_manager"), nb::arg("context_requests"),
        nb::arg("generation_requests"), nb::arg("model_config"), nb::arg("cross_kv_cache_manager") = std::nullopt,
        nb::call_guard<nb::gil_scoped_release>())
    .def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });

nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
Expand Down
119 changes: 73 additions & 46 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,23 +335,28 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"),
nb::arg("is_cross_attention"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"),
nb::arg("window_size_to_layers"), nb::arg("allotted_primary_mem_bytes"),
nb::arg("allotted_secondary_mem_bytes"), nb::arg("extra_cost_memory"), nb::arg("kv_factor"))
.def("allocate_pools", &BaseKVCacheManager::allocatePools)
.def("release_pools", &BaseKVCacheManager::releasePools)
.def("start_scheduling", &BaseKVCacheManager::startScheduling)
nb::arg("allotted_secondary_mem_bytes"), nb::arg("extra_cost_memory"), nb::arg("kv_factor"),
nb::call_guard<nb::gil_scoped_release>())
.def("allocate_pools", &BaseKVCacheManager::allocatePools, nb::call_guard<nb::gil_scoped_release>())
.def("release_pools", &BaseKVCacheManager::releasePools, nb::call_guard<nb::gil_scoped_release>())
.def("start_scheduling", &BaseKVCacheManager::startScheduling, nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro("tokens_per_block", &BaseKVCacheManager::getTokensPerBlock)
.def_prop_ro("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks)
.def_prop_ro("num_pools", &BaseKVCacheManager::getNumPools)
.def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats)
.def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats, nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro("max_blocks_per_seq",
[](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; })
.def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep)
.def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion)
.def("add_token", &BaseKVCacheManager::addToken)
.def("add_sequence", &BaseKVCacheManager::addSequence)
.def("remove_sequence", &BaseKVCacheManager::removeSequence)
.def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence)
.def("get_block_pool_pointers",
.def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep,
nb::call_guard<nb::gil_scoped_release>())
.def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion,
nb::call_guard<nb::gil_scoped_release>())
.def("add_token", &BaseKVCacheManager::addToken, nb::call_guard<nb::gil_scoped_release>())
.def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard<nb::gil_scoped_release>())
.def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard<nb::gil_scoped_release>())
.def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_block_pool_pointers",
[](tbk::BaseKVCacheManager& self)
{
std::optional<at::Tensor> block_pool_pointers{std::nullopt};
Expand All @@ -362,8 +367,10 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
block_pool_pointers = tr::Torch::tensor(_tensor);
}
return block_pool_pointers;
})
.def("get_block_scale_pool_pointers",
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_block_scale_pool_pointers",
[](tbk::BaseKVCacheManager& self)
{
std::optional<at::Tensor> block_scale_pool_pointers{std::nullopt};
Expand All @@ -374,8 +381,10 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
block_scale_pool_pointers = tr::Torch::tensor(_tensor);
}
return block_scale_pool_pointers;
})
.def("get_layer_to_pool_mapping",
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_layer_to_pool_mapping",
[](tbk::BaseKVCacheManager& self)
{
std::optional<at::Tensor> layer_to_pool_mapping{std::nullopt};
Expand All @@ -386,33 +395,43 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
layer_to_pool_mapping = tr::Torch::tensor(_tensor);
}
return layer_to_pool_mapping;
})
.def("get_primary_pool_data",
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_primary_pool_data",
[](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor
{
auto pool = tr::Torch::tensor(self.getPrimaryPool(layer_idx));
auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
return pool.index({torch::indexing::Slice(), pool_layer_idx});
})
.def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); })
.def("get_block_offsets_of_batch",
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); },
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_block_offsets_of_batch",
[](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
SizeType32 beamWidth)
{
auto _output = from_torch(output);
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
self.getBlockOffsetsOfBatch(*(_output.value()), firstBatchSlotIdx, batchSize, beamWidth);
})
.def("copy_block_offsets",
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"copy_block_offsets",
[](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 outputSlotOffset,
tb::LlmRequest::RequestIdType requestId)
{
auto _output = from_torch(output);
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
auto maxBlockCount = self.copyBlockOffsets(*(_output.value()), outputSlotOffset, requestId);
return maxBlockCount;
})
.def("copy_batch_block_offsets",
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"copy_batch_block_offsets",
[](tbk::BaseKVCacheManager& self, at::Tensor output,
std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 const beamWidth,
SizeType32 const offset)
Expand All @@ -423,7 +442,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
{
self.copyBlockOffsets(*(_output.value()), i * beamWidth + offset, requestIds[i]);
}
})
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"get_latest_events",
[](tbk::BaseKVCacheManager& self, std::optional<double> timeout_ms = std::nullopt)
Expand All @@ -434,15 +454,18 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
}
return self.getLatestEvents(std::nullopt);
},
nb::arg("timeout_ms") = std::nullopt)
nb::arg("timeout_ms") = std::nullopt, nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro("enable_block_reuse", &BaseKVCacheManager::isEnableBlockReuse)
.def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache)
.def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache, nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro("cross_kv", &BaseKVCacheManager::isCrossKv)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks)
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds)
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds)
.def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds)
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents);
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
nb::call_guard<nb::gil_scoped_release>())
.def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds,
nb::call_guard<nb::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
nb::call_guard<nb::gil_scoped_release>());

nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
.def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })
Expand Down Expand Up @@ -474,35 +497,39 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true,
nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt,
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr);
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr,
nb::call_guard<nb::gil_scoped_release>());
}

// Registers the Python bindings for the PEFT cache manager classes on module `m`:
// BasePeftCacheManager, PeftCacheManager and NoOpPeftCacheManager.
//
// Every constructor and method is bound with nb::call_guard<nb::gil_scoped_release>, so the
// wrapped C++ call runs with the Python GIL released. Read-only properties are bound without
// a call guard.
void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
{
    // NOTE(review): PyBasePeftCacheManager is presumably the trampoline that lets Python
    // subclasses override the base-class virtuals -- confirm against its definition.
    nb::class_<tb::BasePeftCacheManager, PyBasePeftCacheManager>(m, "BasePeftCacheManager")
        .def("add_request_peft", &tb::BasePeftCacheManager::addRequestPeft, nb::arg("request"),
            nb::arg("try_gpu_cache") = true, nb::call_guard<nb::gil_scoped_release>())
        // Bound through a lambda; the GIL release comes from the call guard, so the lambda
        // body does not create its own nb::gil_scoped_release.
        .def(
            "ensure_batch",
            [](tb::BasePeftCacheManager& self, tb::RequestVector const& contextRequests,
                tb::RequestVector const& generationRequests, bool resetGpuCache)
            { return self.ensureBatch(contextRequests, generationRequests, resetGpuCache); },
            nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false,
            nb::call_guard<nb::gil_scoped_release>())
        .def(
            "reset_device_cache", &tb::BasePeftCacheManager::resetDeviceCache, nb::call_guard<nb::gil_scoped_release>())
        .def("mark_request_done", &tb::BasePeftCacheManager::markRequestDone, nb::arg("request"),
            nb::arg("pause") = false, nb::call_guard<nb::gil_scoped_release>())
        .def_prop_ro("max_device_pages", &tb::BasePeftCacheManager::getMaxDevicePages)
        .def_prop_ro("max_host_pages", &tb::BasePeftCacheManager::getMaxHostPages)
        .def("determine_num_pages", &tb::BasePeftCacheManager::determineNumPages, nb::arg("request"),
            nb::call_guard<nb::gil_scoped_release>())
        .def_prop_ro("enabled", &tb::BasePeftCacheManager::enabled);

    // Concrete manager, exposed as a subclass of BasePeftCacheManager.
    nb::class_<tb::PeftCacheManager, tb::BasePeftCacheManager>(m, "PeftCacheManager")
        .def(nb::init<tb::PeftCacheManagerConfig, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
            nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"),
            nb::call_guard<nb::gil_scoped_release>())
        // The Python keyword is "taskId" (camelCase), unlike the snake_case keywords above;
        // kept as-is because it is part of the Python-facing interface.
        .def("is_task_cached", &tb::PeftCacheManager::isTaskCached, nb::arg("taskId"),
            nb::call_guard<nb::gil_scoped_release>());

    // Default-constructible subclass; only the constructor is bound here.
    nb::class_<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")
        .def(nb::init<>(), nb::call_guard<nb::gil_scoped_release>());
}
Loading