Skip to content

Commit aafb23e

Browse files
committed
Add vLLM KV cache layout for XQA MLA
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
1 parent 50d4e5b commit aafb23e

File tree

4 files changed

+80
-5
lines changed

4 files changed

+80
-5
lines changed

cpp/kernels/xqa/mha.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,14 @@ void launchMLA(cudaDeviceProp const& prop,
169169
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
170170
float qScale, OutputHead* output, InputHead const* q,
171171
#if USE_PAGED_KV_CACHE
172+
#if PAGED_KV_CACHE_LAYOUT == 1
173+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
174+
#else
172175
GMemCacheHead* pool, // global pool of pages
176+
#endif
173177
KVCachePageIndex const*
174-
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
178+
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
179+
// [batchSize][maxNbPagesPerSeq] (Layout 1)
175180
#else
176181
GMemKVCacheHead* kvCacheData,
177182
#endif

cpp/kernels/xqa/mla_sm120.cu

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,11 @@ __device__ inline KVTilePartLoader::KVTilePartLoader(
112112
, tensorMap{tensorMap}
113113
#if USE_PAGED_KV_CACHE
114114
, nbPages{nbPages}
115+
#if PAGED_KV_CACHE_LAYOUT == 1
116+
, baseOffset{idxReq * cacheList.maxNbPagesPerSeq}
117+
#else
115118
, baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq}
119+
#endif
116120
#else
117121
, baseOffset{(idxReq * beamWidth) * 2}
118122
#endif
@@ -139,7 +143,11 @@ __device__ inline void KVTilePartLoader::loadData(Array2D<LdGrain, nbTokens, gra
139143
uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens));
140144
if (warpElectSync())
141145
{
146+
#if PAGED_KV_CACHE_LAYOUT == 1
147+
tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, idxHeadGrp, offset, (uint32_t) pages[0]}, bar);
148+
#else
142149
tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, offset, idxHeadGrp, (uint32_t) pages[0]}, bar);
150+
#endif
143151
}
144152
}
145153
else
@@ -149,8 +157,13 @@ __device__ inline void KVTilePartLoader::loadData(Array2D<LdGrain, nbTokens, gra
149157
{
150158
if (warpElectSync())
151159
{
160+
#if PAGED_KV_CACHE_LAYOUT == 1
161+
tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
162+
DimsLE<4>{idxElemBeg, idxHeadGrp, 0, (uint32_t) pages[i]}, bar);
163+
#else
152164
tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
153165
DimsLE<4>{idxElemBeg, 0, idxHeadGrp, (uint32_t) pages[i]}, bar);
166+
#endif
154167
}
155168
}
156169
}
@@ -1859,13 +1872,18 @@ CUtensorMap makeTensorMapForQ(
18591872
#endif // IS_MLA
18601873

18611874
void launchMLA(cudaDeviceProp const& prop,
1862-
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
1875+
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
18631876
float qScale, OutputHead* output, InputHead const* q,
1864-
float* attentionSinks, // [headGrpSize], not supported.
18651877
#if USE_PAGED_KV_CACHE
1866-
GMemCacheHead* pool, // global pool of pages
1878+
#if PAGED_KV_CACHE_LAYOUT == 1
1879+
GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout
1880+
GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout
1881+
#else
1882+
GMemCacheHead* pool, // global pool of pages
1883+
#endif
18671884
KVCachePageIndex const*
1868-
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
1885+
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
1886+
// [batchSize][maxNbPagesPerSeq] (Layout 1)
18691887
#else
18701888
GMemKVCacheHead* kvCacheData,
18711889
#endif
@@ -1916,7 +1934,11 @@ void launchMLA(cudaDeviceProp const& prop,
19161934
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
19171935
#if USE_PAGED_KV_CACHE
19181936
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
1937+
#if PAGED_KV_CACHE_LAYOUT == 1
1938+
KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq};
1939+
#else
19191940
KVCacheList<true> const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
1941+
#endif
19201942
auto const dtype = []
19211943
{
19221944
if (std::is_same_v<CacheElem, half>)
@@ -1936,10 +1958,17 @@ void launchMLA(cudaDeviceProp const& prop,
19361958

19371959
auto const tensorMapQ
19381960
= makeTensorMapForQ(q, dtype, validElemsPerHead, headGrpSize * inputSeqLen * batchSize, partElemsK);
1961+
#if PAGED_KV_CACHE_LAYOUT == 1
1962+
auto const tensorMapK = makeTensorMapForPagedKVCache(
1963+
kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
1964+
auto const tensorMapV = makeTensorMapForPagedKVCache(
1965+
vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
1966+
#else
19391967
auto const tensorMapK = makeTensorMapForPagedKVCache(
19401968
pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
19411969
auto const tensorMapV = makeTensorMapForPagedKVCache(
19421970
pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
1971+
#endif
19431972

19441973
uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z;
19451974
auto const cgaXBuf = static_cast<Vec<CgaXBuffer, nbProducerCtasPerCga>*>(scratch);

cpp/kernels/xqa/test/test.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,22 @@ class ManagedMemBuf
7979
{
8080
if (!isTracing)
8181
{
82+
#if CUDA_VERSION >= 13000
83+
cudaMemLocation location;
84+
if (dstDevice == cudaCpuDeviceId)
85+
{
86+
location.type = cudaMemLocationTypeHost;
87+
location.id = 0;
88+
}
89+
else
90+
{
91+
location.type = cudaMemLocationTypeDevice;
92+
location.id = dstDevice;
93+
}
94+
checkCuda(cudaMemPrefetchAsync(get(), sizeof(T) * size(), location, 0, stream));
95+
#else
8296
checkCuda(cudaMemPrefetchAsync(get(), sizeof(T) * size(), dstDevice, stream));
97+
#endif
8398
}
8499
}
85100

@@ -507,6 +522,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
507522
#endif
508523
#if IS_MLA
509524
#if USE_PAGED_KV_CACHE
525+
#if PAGED_KV_CACHE_LAYOUT == 1
526+
// VLLM format: K and V share the same pageList, no copy needed
527+
#else
510528
for (uint32_t idxReq = 0; idxReq < batchSize; idxReq++)
511529
{
512530
for (uint32_t idxBeam = 0; idxBeam < beamWidth; idxBeam++)
@@ -517,6 +535,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
517535
}
518536
}
519537
}
538+
#endif
520539
#else
521540
static_assert(false, "not implemented");
522541
#endif
@@ -691,7 +710,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
691710
#else
692711
&output[0][0][0], &qHeads[0][0][0],
693712
#endif
713+
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
714+
cacheKHeads.get(), cacheVHeads.get(),
715+
#else
694716
cacheHeads.get(),
717+
#endif
695718
#if USE_PAGED_KV_CACHE
696719
pageListArg,
697720
#endif
@@ -790,7 +813,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
790813
float ms;
791814
checkCuda(cudaEventElapsedTime(&ms, tic, toc));
792815
ms /= nbIters;
816+
#if CUDA_VERSION >= 13000
817+
int memoryClockRateKHz;
818+
checkCuda(cudaDeviceGetAttribute(&memoryClockRateKHz, cudaDevAttrMemoryClockRate, device));
819+
float const bandwidth = 2.f * prop.memoryBusWidth * memoryClockRateKHz * 1000 / 8;
820+
#else
793821
float const bandwidth = 2.f * prop.memoryBusWidth * prop.memoryClockRate * 1000 / 8;
822+
#endif
794823
#if BEAM_WIDTH == 1
795824
size_t nbLoadedCacheTokens = seqLen * beamWidth * batchSize;
796825
#else
@@ -819,7 +848,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
819848
{
820849
printf("done\n");
821850
printf("time: %f ms\n", ms);
851+
#if CUDA_VERSION >= 13000
852+
printf("mem bus width = %d\nmem clock rate = %d\n", prop.memoryBusWidth, memoryClockRateKHz);
853+
#else
822854
printf("mem bus width = %d\nmem clock rate = %d\n", prop.memoryBusWidth, prop.memoryClockRate);
855+
#endif
823856
printf("bandwidth = %e\n", (float) bandwidth);
824857
printf("traffic=%e\n", (float) totalTraffic);
825858
}

cpp/kernels/xqa/test/warmup.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ __global__ void kernel_warmup(uint64_t cycles)
1212

1313
void warmup(cudaDeviceProp const& prop, float ms, cudaStream_t stream = nullptr)
1414
{
15+
#if CUDA_VERSION >= 13000
16+
int device;
17+
checkCuda(cudaGetDevice(&device));
18+
int clockRateKHz;
19+
checkCuda(cudaDeviceGetAttribute(&clockRateKHz, cudaDevAttrClockRate, device));
20+
uint64_t const nbCycles = std::round(clockRateKHz * ms); // clockRate is in kHz
21+
#else
1522
uint64_t const nbCycles = std::round(prop.clockRate * ms); // clockRate is in kHz
23+
#endif
1624
kernel_warmup<<<16, 128, 0, stream>>>(nbCycles);
1725
checkCuda(cudaGetLastError());
1826
}

0 commit comments

Comments (0)