Skip to content

Commit 23decaf

Browse files
committed
Have ability to cancel disagg requests if KV cache resources are exhausted
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
1 parent 121140c commit 23decaf

File tree

11 files changed

+169
-14
lines changed

11 files changed

+169
-14
lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,16 +1465,19 @@ class CacheTransceiverConfig
14651465
NIXL = 3
14661466
};
14671467
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
1468-
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);
1468+
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
1469+
std::optional<int> kvTransferSenderFutureTimeoutMs = std::nullopt);
14691470

14701471
bool operator==(CacheTransceiverConfig const& other) const;
14711472
void setBackendType(std::optional<BackendType> backendType);
14721473
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
14731474
void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);
1475+
void setKvTransferSenderFutureTimeoutMs(std::optional<int> kvTransferSenderFutureTimeoutMs);
14741476

1475-
[[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
14761477
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
14771478
[[nodiscard]] std::optional<BackendType> getBackendType() const;
1479+
[[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
1480+
[[nodiscard]] std::optional<int> getKvTransferSenderFutureTimeoutMs() const;
14781481

14791482
private:
14801483
std::optional<BackendType> mBackendType;
@@ -1483,6 +1486,9 @@ class CacheTransceiverConfig
14831486
/// transfer may be degraded.
14841487
std::optional<size_t> mMaxTokensInBuffer;
14851488
std::optional<int> mKvTransferTimeoutMs;
1489+
/// @brief Timeout in milliseconds to wait for the sender future to be ready when the scheduled batch size is 0. This
1490+
/// allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms
1491+
std::optional<int> mKvTransferSenderFutureTimeoutMs;
14861492
};
14871493

14881494
/// @brief Configuration class for the model executor

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,13 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
419419
void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
420420
{
421421
bool blockAll = !atLeastRequestNum.has_value();
422+
std::optional<int> senderFutureTimeoutMs = std::nullopt;
423+
// If blockAll is true, we want to block and not use a timeout
424+
if (!blockAll && mCacheTransceiverConfig.has_value())
425+
{
426+
senderFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
427+
}
428+
422429
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
423430
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
424431
for (auto&& [request, future] : mSenderFutures)
@@ -476,16 +483,36 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
476483
{
477484
try
478485
{
479-
future.get();
480-
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
486+
// Wait for up to a specified timeout
487+
auto status = future.wait_for(std::chrono::milliseconds(senderFutureTimeoutMs.value_or(0)));
488+
if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
489+
{
490+
future.get();
491+
request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
492+
it = mSenderFutures.erase(it);
493+
}
494+
else if (status == std::future_status::timeout)
495+
{
496+
TLLM_LOG_WARNING("Timed out waiting for context transfer for request %ld after %d milliseconds.",
497+
request->mRequestId, senderFutureTimeoutMs.value());
498+
++it;
499+
}
500+
else
501+
{
502+
TLLM_LOG_ERROR(
503+
"Future returned unexpected status for request %ld. Marking as error", request->mRequestId);
504+
505+
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
506+
it = mSenderFutures.erase(it);
507+
}
481508
}
482509
catch (std::exception const& e)
483510
{
484511
TLLM_LOG_ERROR(
485512
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
486513
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
514+
it = mSenderFutures.erase(it);
487515
}
488-
it = mSenderFutures.erase(it);
489516
}
490517
else
491518
{

cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
namespace tensorrt_llm::executor
2222
{
2323

24-
CacheTransceiverConfig::CacheTransceiverConfig(
25-
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
24+
CacheTransceiverConfig::CacheTransceiverConfig(std::optional<BackendType> backendType,
25+
std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs,
26+
std::optional<int> kvTransferSenderFutureTimeoutMs)
2627
: mBackendType(backendType)
2728
, mMaxTokensInBuffer(maxNumTokens)
2829
, mKvTransferTimeoutMs(kvTransferTimeoutMs)
30+
, mKvTransferSenderFutureTimeoutMs(kvTransferSenderFutureTimeoutMs)
2931
{
3032
}
3133

@@ -54,6 +56,15 @@ void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransfe
5456
mKvTransferTimeoutMs = kvTransferTimeoutMs;
5557
}
5658

59+
void CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs(std::optional<int> kvTransferSenderFutureTimeoutMs)
60+
{
61+
if (kvTransferSenderFutureTimeoutMs.has_value() && kvTransferSenderFutureTimeoutMs.value() <= 0)
62+
{
63+
TLLM_THROW("kvTransferSenderFutureTimeoutMs must be positive");
64+
}
65+
mKvTransferSenderFutureTimeoutMs = kvTransferSenderFutureTimeoutMs;
66+
}
67+
5768
std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
5869
{
5970
return mBackendType;
@@ -69,4 +80,8 @@ std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
6980
return mKvTransferTimeoutMs;
7081
}
7182

83+
std::optional<int> CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs() const
84+
{
85+
return mKvTransferSenderFutureTimeoutMs;
86+
}
7287
} // namespace tensorrt_llm::executor

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,20 +1290,26 @@ CacheTransceiverConfig Serialization::deserializeCacheTransceiverConfig(std::ist
12901290
{
12911291
auto backendType = su::deserialize<std::optional<CacheTransceiverConfig::BackendType>>(is);
12921292
auto maxTokensInBuffer = su::deserialize<std::optional<size_t>>(is);
1293-
return CacheTransceiverConfig{backendType, maxTokensInBuffer};
1293+
auto kvTransferTimeoutMs = su::deserialize<std::optional<int>>(is);
1294+
auto kvTransferSenderFutureTimeoutMs = su::deserialize<std::optional<int>>(is);
1295+
return CacheTransceiverConfig{backendType, maxTokensInBuffer, kvTransferTimeoutMs, kvTransferSenderFutureTimeoutMs};
12941296
}
12951297

12961298
void Serialization::serialize(CacheTransceiverConfig const& cacheTransceiverConfig, std::ostream& os)
12971299
{
12981300
su::serialize(cacheTransceiverConfig.getBackendType(), os);
12991301
su::serialize(cacheTransceiverConfig.getMaxTokensInBuffer(), os);
1302+
su::serialize(cacheTransceiverConfig.getKvTransferTimeoutMs(), os);
1303+
su::serialize(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs(), os);
13001304
}
13011305

13021306
size_t Serialization::serializedSize(CacheTransceiverConfig const& cacheTransceiverConfig)
13031307
{
13041308
size_t totalSize = 0;
13051309
totalSize += su::serializedSize(cacheTransceiverConfig.getBackendType());
13061310
totalSize += su::serializedSize(cacheTransceiverConfig.getMaxTokensInBuffer());
1311+
totalSize += su::serializedSize(cacheTransceiverConfig.getKvTransferTimeoutMs());
1312+
totalSize += su::serializedSize(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs());
13071313
return totalSize;
13081314
}
13091315

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,15 +465,19 @@ void initConfigBindings(nb::module_& m)
465465

466466
nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
467467
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
468-
std::optional<int>>(),
468+
std::optional<int>, std::optional<int>>(),
469469
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt,
470-
nb::arg("kv_transfer_timeout_ms") = std::nullopt)
470+
nb::arg("kv_transfer_timeout_ms") = std::nullopt,
471+
nb::arg("kv_transfer_sender_future_timeout_ms") = std::nullopt)
471472
.def_prop_rw(
472473
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
473474
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
474475
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
475476
.def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
476477
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
478+
.def_prop_rw("kv_transfer_sender_future_timeout_ms",
479+
&tle::CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs,
480+
&tle::CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs)
477481
.def("__getstate__", cacheTransceiverConfigGetstate)
478482
.def("__setstate__", cacheTransceiverConfigSetstate);
479483

cpp/tensorrt_llm/pybind/executor/executorConfig.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,13 +449,17 @@ void initConfigBindings(pybind11::module_& m)
449449
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
450450
std::optional<int>, std::optional<int>>(),
451451
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
452-
py::arg("kv_transfer_timeout_ms") = std::nullopt)
452+
py::arg("kv_transfer_timeout_ms") = std::nullopt,
453+
py::arg("kv_transfer_sender_future_timeout_ms") = std::nullopt)
453454
.def_property(
454455
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
455456
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
456457
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
457458
.def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
458459
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
460+
.def_property("kv_transfer_sender_future_timeout_ms",
461+
&tle::CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs,
462+
&tle::CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs)
459463
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));
460464

461465
auto executorConfigGetState = [](py::object const& self)

cpp/tests/unit_tests/executor/serializeUtilsTest.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ TEST(SerializeUtilsTest, ExecutorConfig)
789789
texec::GuidedDecodingConfig(
790790
texec::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, std::initializer_list<std::string>{"eos"}),
791791
std::vector{tensorrt_llm::executor::AdditionalModelOutput{"output_name"}},
792-
texec::CacheTransceiverConfig(std::nullopt, 1024), true, true, true);
792+
texec::CacheTransceiverConfig(std::nullopt, 1024, 100, 1000), true, true, true);
793793
auto executorConfig2 = serializeDeserialize(executorConfig);
794794

795795
EXPECT_EQ(executorConfig.getMaxBeamWidth(), executorConfig2.getMaxBeamWidth());
@@ -1028,10 +1028,13 @@ TEST(SerializeUtilsTest, MethodReturnType)
10281028
TEST(SerializeUtilsTest, CacheTransceiverConfig)
10291029
{
10301030
texec::CacheTransceiverConfig cacheTransceiverConfig(
1031-
tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, 1024);
1031+
tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, 1024, 100, 1000);
10321032
auto cacheTransceiverConfig2 = serializeDeserialize(cacheTransceiverConfig);
10331033
EXPECT_EQ(cacheTransceiverConfig.getBackendType(), cacheTransceiverConfig2.getBackendType());
10341034
EXPECT_EQ(cacheTransceiverConfig.getMaxTokensInBuffer(), cacheTransceiverConfig2.getMaxTokensInBuffer());
1035+
EXPECT_EQ(cacheTransceiverConfig.getKvTransferTimeoutMs(), cacheTransceiverConfig2.getKvTransferTimeoutMs());
1036+
EXPECT_EQ(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs(),
1037+
cacheTransceiverConfig2.getKvTransferSenderFutureTimeoutMs());
10351038
}
10361039

10371040
TEST(SerializeUtilsTest, BlockKeyBasic)

examples/disaggregated/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ cache_transceiver_config:
1919
# KV cache transfer timeout in milliseconds
2020
# For requests, if they do not send/receive the KV cache in time they are cancelled and cleaned up
2121
kv_transfer_timeout_ms: <int>
22+
# Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms
23+
kv_transfer_sender_future_timeout_ms: <int>
2224
```
2325
2426
The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
111111
pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)
112112

113113
self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
114+
self.kv_transfer_sender_future_timeout_ms = cache_transceiver_config.kv_transfer_sender_future_timeout_ms
114115
self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
115116
total_num_kv_heads_per_layer, head_dim,
116117
tokens_per_block, world_config,

tensorrt_llm/llmapi/llm_args.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1563,11 +1563,20 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
15631563
"Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."
15641564
)
15651565

1566+
kv_transfer_sender_future_timeout_ms: Optional[int] = Field(
1567+
default=1000,
1568+
gt=0,
1569+
description=
1570+
"Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms"
1571+
)
1572+
15661573
def _to_pybind(self):
15671574
return _CacheTransceiverConfig(
15681575
backend=_CacheTransceiverBackendType.from_string(self.backend),
15691576
max_tokens_in_buffer=self.max_tokens_in_buffer,
1570-
kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)
1577+
kv_transfer_timeout_ms=self.kv_transfer_timeout_ms,
1578+
kv_transfer_sender_future_timeout_ms=self.
1579+
kv_transfer_sender_future_timeout_ms)
15711580

15721581

15731582
@dataclass

0 commit comments

Comments
 (0)