@@ -112,7 +112,11 @@ __device__ inline KVTilePartLoader::KVTilePartLoader(
112112 , tensorMap{tensorMap}
113113#if USE_PAGED_KV_CACHE
114114 , nbPages{nbPages}
115+ #if PAGED_KV_CACHE_LAYOUT == 1
116+ , baseOffset{idxReq * cacheList.maxNbPagesPerSeq }
117+ #else
115118 , baseOffset{((idxReq * beamWidth) * 2 ) * cacheList.maxNbPagesPerSeq }
119+ #endif
116120#else
117121 , baseOffset{(idxReq * beamWidth) * 2 }
118122#endif
@@ -139,7 +143,11 @@ __device__ inline void KVTilePartLoader::loadData(Array2D<LdGrain, nbTokens, gra
139143 uint32_t const offset = nbTokens * (idxTile % exactDiv (tokensPerPage, nbTokens));
140144 if (warpElectSync ())
141145 {
146+ #if PAGED_KV_CACHE_LAYOUT == 1
147+ tma::loadAsync (&dst, tensorMap, DimsLE<4 >{idxElemBeg, idxHeadGrp, offset, (uint32_t ) pages[0 ]}, bar);
148+ #else
142149 tma::loadAsync (&dst, tensorMap, DimsLE<4 >{idxElemBeg, offset, idxHeadGrp, (uint32_t ) pages[0 ]}, bar);
150+ #endif
143151 }
144152 }
145153 else
@@ -149,8 +157,13 @@ __device__ inline void KVTilePartLoader::loadData(Array2D<LdGrain, nbTokens, gra
149157 {
150158 if (warpElectSync ())
151159 {
160+ #if PAGED_KV_CACHE_LAYOUT == 1
161+ tma::loadAsync (&dst (tokensPerPage * i, 0 ), tensorMap,
162+ DimsLE<4 >{idxElemBeg, idxHeadGrp, 0 , (uint32_t ) pages[i]}, bar);
163+ #else
152164 tma::loadAsync (&dst (tokensPerPage * i, 0 ), tensorMap,
153165 DimsLE<4 >{idxElemBeg, 0 , idxHeadGrp, (uint32_t ) pages[i]}, bar);
166+ #endif
154167 }
155168 }
156169 }
@@ -1859,13 +1872,18 @@ CUtensorMap makeTensorMapForQ(
18591872#endif // IS_MLA
18601873
18611874void launchMLA (cudaDeviceProp const & prop,
1862- uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
1875+ uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
18631876 float qScale, OutputHead* output, InputHead const * q,
1864- float * attentionSinks, // [headGrpSize], not supported.
18651877#if USE_PAGED_KV_CACHE
1866- GMemCacheHead* pool, // global pool of pages
1878+ #if PAGED_KV_CACHE_LAYOUT == 1
1879+ GMemCacheHead* kCacheVLLM , // K cache pool for VLLM layout
1880+ GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout
1881+ #else
1882+ GMemCacheHead* pool, // global pool of pages
1883+ #endif
18671884 KVCachePageIndex const *
1868- kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
1885+ kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
1886+ // [batchSize][maxNbPagesPerSeq] (Layout 1)
18691887#else
18701888 GMemKVCacheHead* kvCacheData,
18711889#endif
@@ -1916,7 +1934,11 @@ void launchMLA(cudaDeviceProp const& prop,
19161934 auto const launchCfg = makeLaunchConfig (dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0 );
19171935#if USE_PAGED_KV_CACHE
19181936 uint32_t const maxNbPagesPerSeq = exactDiv (maxSeqLen, tokensPerPage);
1937+ #if PAGED_KV_CACHE_LAYOUT == 1
1938+ KVCacheList<true > const cacheList{kCacheVLLM , vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq};
1939+ #else
19191940 KVCacheList<true > const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
1941+ #endif
19201942 auto const dtype = []
19211943 {
19221944 if (std::is_same_v<CacheElem, half>)
@@ -1936,10 +1958,17 @@ void launchMLA(cudaDeviceProp const& prop,
19361958
19371959 auto const tensorMapQ
19381960 = makeTensorMapForQ (q, dtype, validElemsPerHead, headGrpSize * inputSeqLen * batchSize, partElemsK);
1961+ #if PAGED_KV_CACHE_LAYOUT == 1
1962+ auto const tensorMapK = makeTensorMapForPagedKVCache (
1963+ kCacheVLLM , dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
1964+ auto const tensorMapV = makeTensorMapForPagedKVCache (
1965+ vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
1966+ #else
19391967 auto const tensorMapK = makeTensorMapForPagedKVCache (
19401968 pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
19411969 auto const tensorMapV = makeTensorMapForPagedKVCache (
19421970 pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
1971+ #endif
19431972
19441973 uint32_t const nbCgas = exactDiv (dimGrid.x , 4 ) * dimGrid.y * dimGrid.z ;
19451974 auto const cgaXBuf = static_cast <Vec<CgaXBuffer, nbProducerCtasPerCga>*>(scratch);