@@ -168,7 +168,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
168168
169169 std::optional<size_t > maxNumTokens = mCacheTransceiverConfig .value ().getMaxTokensInBuffer ();
170170
171- mCacheTransBufferManager = std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens);
171+ mCacheTransBufferManagers .push_back (
172+ std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens));
173+ if (isMLA && cacheManager->isEnableIndexerKCache ())
174+ {
175+ mCacheTransBufferManagers .push_back (
176+ std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens, true ));
177+ }
178+ mCacheTransBufferManagerPtrs .clear ();
179+ mCacheTransBufferManagerPtrs .reserve (mCacheTransBufferManagers .size ());
180+ for (auto & manager : mCacheTransBufferManagers )
181+ {
182+ mCacheTransBufferManagerPtrs .push_back (manager.get ());
183+ }
172184 if (backendType.value () == executor::CacheTransceiverConfig::BackendType::UCX)
173185 {
174186 std::lock_guard<std::mutex> lock (mDllMutex );
@@ -191,7 +203,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
191203 else if (backendType.value () == executor::CacheTransceiverConfig::BackendType::NIXL)
192204 {
193205 mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
194- mCacheTransBufferManager . get () , *mCacheState );
206+ mCacheTransBufferManagerPtrs , *mCacheState );
195207 TLLM_LOG_INFO (" NIXL Connection Manager created" );
196208 }
197209 else if (backendType.value () == executor::CacheTransceiverConfig::BackendType::MPI)
@@ -206,7 +218,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
206218 }
207219
208220 auto makeFormatter = [cacheManager, isMLA, this ]()
209- { return createCacheFormatter (cacheManager, mCacheTransBufferManager . get () , isMLA); };
221+ { return createCacheFormatter (cacheManager, mCacheTransBufferManagerPtrs , isMLA); };
210222
211223 mCacheSender = std::make_unique<CacheSender>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ());
212224 mCacheReceiver
0 commit comments