Skip to content

Commit 7210604

Browse files
committed
GDS_MT backend support for LoopbackAgent
1 parent 4a1b091 commit 7210604

File tree

4 files changed

+15
-6
lines changed

4 files changed

+15
-6
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ class BlockManager
872872
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
873873
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
874874
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
875-
bool enablePartialReuse = true, bool copyOnPartialReuse = true);
875+
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool multiThreadReuse = false);
876876

877877
//! \brief Calculate the proportional share each window size receives of the total memory pool
878878
//! \details Example: (uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})

cpp/include/tensorrt_llm/executor/transferAgent.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ struct BaseAgentConfig
270270
{
271271
std::string mName;
272272
bool useProgThread;
273+
bool multiThread;
273274
};
274275

275276
class BaseTransferAgent

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,15 +407,15 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
407407
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
408408
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
409409
std::shared_ptr<KVCacheEventManager> eventManager, bool enableHashKey, bool enablePartialReuse,
410-
bool copyOnPartialReuse)
410+
bool copyOnPartialReuse, bool multiThreadReuse)
411411
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
412412
, mTokensPerBlock{tokensPerBlock}
413413
, mEventManager{std::move(eventManager)}
414414
, mStream{stream}
415415
, mCacheType{cacheType}
416416
{
417417
mAgentName = std::string("GDSAgent");
418-
BaseAgentConfig config{mAgentName, true};
418+
BaseAgentConfig config{mAgentName, true, multiThreadReuse};
419419
mLoopbackAgent = makeLoopbackAgent("nixl", &config);
420420

421421
auto const uniqueWindowSizeToLayers

cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,11 +501,19 @@ NixlLoopbackAgent::NixlLoopbackAgent(BaseAgentConfig const& config)
501501
init["batch_limit"] = std::to_string(128);
502502
init["max_request_size"] = std::to_string(16 * 1024 * 1024);
503503

504-
status = mRawAgent->createBackend("GDS", init, backend);
505-
if (status != NIXL_SUCCESS || !backend)
504+
if (config.multiThread)
506505
{
507-
TLLM_THROW("Failed to create NIXL backend, status = %d", status);
506+
status = mRawAgent->createBackend("GDS_MT", init, backend);
507+
if (status != NIXL_SUCCESS || !backend)
508+
TLLM_THROW("Failed to create NIXL GDS_MT backend, status = %d", status);
508509
}
510+
else
511+
{
512+
status = mRawAgent->createBackend("GDS", init, backend);
513+
if (status != NIXL_SUCCESS || !backend)
514+
TLLM_THROW("Failed to create NIXL GDS backend, status = %d", status);
515+
}
516+
509517
TLLM_LOG_INFO("NixlLoopbackAgent::NixlLoopbackAgent mAddress: %s", mAddress.c_str());
510518
}
511519

0 commit comments

Comments
 (0)