
Conversation

@qsang-nv (Collaborator) commented Oct 22, 2025

Summary by CodeRabbit

  • New Features

    • Added support for a new KV cache layout option for memory management
    • Enabled CUDA 13.0+ memory prefetching with location-aware operations
  • Refactor

    • Updated internal KV cache handling with conditional layout paths
    • Adjusted method signatures to accommodate multiple KV cache configurations
  • Chores

    • Improved CUDA version compatibility for device attribute queries
    • Updated memory clock rate computation logic for newer CUDA versions

Description

Add the vLLM KV cache layout for the XQA MLA kernel. Also add support in the XQA unit tests for CUDA versions >= 13.0.

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@qsang-nv (Collaborator, Author) commented

/bot run --disable_fail_fast

coderabbitai bot commented Oct 22, 2025

📝 Walkthrough

Changes introduce support for a new KV cache layout (PAGED_KV_CACHE_LAYOUT == 1) with separate K/V cache pools, update KV cache indexing and tensor map creation logic, and add CUDA 13.0+ support for memory prefetching and clock rate queries across kernel implementations and test infrastructure.
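
For orientation, here is a hedged sketch of the signature split this implies. The parameter names kCacheVLLM, vCacheVLLM, and pool come from this PR; GMemCacheHead and KVCachePageIndex are placeholders, and all other parameters are elided, so this is not the real mha.h prototype:

```cpp
#include <cstdint>

// Sketch only: shows the conditional-parameter pattern, not the actual
// launchMLA declaration. Placeholder types; remaining parameters elided.
struct GMemCacheHead;
using KVCachePageIndex = int32_t;

#if PAGED_KV_CACHE_LAYOUT == 1
void launchMLA(GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
    KVCachePageIndex const* kvCachePageList);
#else
void launchMLA(GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList);
#endif
```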

Changes

  • KV Cache Layout Support (cpp/kernels/xqa/mha.h, cpp/kernels/xqa/mla_sm120.cu): Added conditional parameters to the launchMLA function signature: introduces kCacheVLLM and vCacheVLLM pointers when PAGED_KV_CACHE_LAYOUT == 1. Modified KV cache indexing in KVTilePartLoader to branch on layout type for the baseOffset calculation and page loading. Updated tensor map creation to use separate K/V caches for layout 1 vs. a pooled cache for layout 0.
  • CUDA 13.0+ Memory Support (cpp/kernels/xqa/test/test.cpp): Added a CUDA 13.0+ path for memory prefetching using cudaMemLocation and cudaMemPrefetchAsync with a location parameter. Added conditional memory clock rate retrieval via cudaDevAttrMemoryClockRate for bandwidth computation. Includes conditional compilation blocks to exclude layout-specific code paths based on PAGED_KV_CACHE_LAYOUT.
  • CUDA 13.0+ Clock Rate Query (cpp/kernels/xqa/test/warmup.cu): Added a CUDA version-gated path to compute nbCycles using the device clock rate from cudaDeviceGetAttribute (CUDA >= 13000); preserves the existing prop.clockRate-based behavior for earlier CUDA versions. A sketch of both CUDA 13 gates follows this list.
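
The two CUDA 13.0+ items above share one gating pattern: switch on CUDA_VERSION and call the newer API. A minimal sketch under that assumption, using only standard CUDA runtime calls; the function names are illustrative, this is not the actual code from test.cpp or warmup.cu, and error checking is omitted:

```cpp
#include <cstddef>
#include <cuda.h> // defines CUDA_VERSION
#include <cuda_runtime.h>

// Prefetch managed memory to a device. CUDA 13 replaces the int-device
// overload of cudaMemPrefetchAsync with a cudaMemLocation-based one.
void prefetchToDevice(void const* ptr, size_t bytes, int device, cudaStream_t stream)
{
#if CUDA_VERSION >= 13000
    cudaMemLocation loc{};
    loc.type = cudaMemLocationTypeDevice;
    loc.id = device;
    cudaMemPrefetchAsync(ptr, bytes, loc, /*flags=*/0, stream);
#else
    cudaMemPrefetchAsync(ptr, bytes, device, stream);
#endif
}

// Memory clock in kHz. cudaDeviceProp::memoryClockRate is unavailable in
// newer CUDA, so query the device attribute instead; the SM clock used by
// warmup.cu follows the same pattern with cudaDevAttrClockRate.
int memoryClockKHz(int device)
{
    int clkKHz = 0;
#if CUDA_VERSION >= 13000
    cudaDeviceGetAttribute(&clkKHz, cudaDevAttrMemoryClockRate, device);
#else
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);
    clkKHz = prop.memoryClockRate;
#endif
    return clkKHz;
}
```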

Sequence Diagram(s)

sequenceDiagram
    participant Caller as launchMLA Caller
    participant Launcher as launchMLA
    participant Loader as KVTilePartLoader
    participant TensorMap as makeTensorMapForPagedKVCache

    rect rgb(240, 248, 255)
    Note over Launcher: PAGED_KV_CACHE_LAYOUT == 1 Path (VLLM)
    Caller->>Launcher: kCacheVLLM, vCacheVLLM, kvCachePageList
    Launcher->>Launcher: Construct KVCacheList from kCacheVLLM/vCacheVLLM
    Launcher->>TensorMap: Create tensor maps with kCacheVLLM/vCacheVLLM
    TensorMap-->>Launcher: K/V tensor maps
    Launcher->>Loader: Initialize with layout 1 baseOffset<br/>(idxReq * maxNbPagesPerSeq)
    Loader->>Loader: Load pages with layout 1 indexing
    end

    rect rgb(240, 255, 240)
    Note over Launcher: Legacy Path (Layout 0 or Pool)
    Caller->>Launcher: pool, kvCachePageList
    Launcher->>Launcher: Construct KVCacheList from pool
    Launcher->>TensorMap: Create tensor maps with pool
    TensorMap-->>Launcher: K/V tensor maps
    Launcher->>Loader: Initialize with layout 0 baseOffset<br/>(formula with beamWidth, 2)
    Loader->>Loader: Load pages with layout 0 indexing
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

The changes span multiple files and introduce conditional logic for two distinct KV cache layouts plus CUDA version-specific paths, requiring verification of indexing formulas and page addressing across both layout branches. However, changes follow a consistent pattern of adding parallel conditional paths rather than heterogeneous modifications, and test file changes are localized to specific CUDA version blocks.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 20.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check: ✅ Passed. The PR title "[None][feat] Add vLLM KV Pool support for XQA mla kernel" clearly and specifically describes the primary change in the changeset. According to the raw summary, the changes introduce support for a new KV cache layout (PAGED_KV_CACHE_LAYOUT == 1) with separate K/V cache pools (kCacheVLLM and vCacheVLLM) in the XQA MLA kernel. The title is concise, uses the correct format as specified in the template, and accurately reflects the main purpose of the changeset without being vague or overly broad.
  • Description Check: ✅ Passed. The PR description provides an explanation of the changes ("Add vllm kv layout for xqa mla kernel. Also add support in xqa unittest for cuda version >= 13.0.") which directly corresponds to the modifications shown in the raw summary. The title format follows the template requirements with "[None][feat]" and a descriptive summary. However, the Test Coverage section is not filled out: the placeholder comment is present but no specific tests are listed, which is a required section in the template. Despite this gap, the core description section is present and on-topic, and the author has indicated completion of the PR checklist by checking the final checkbox.

coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
cpp/kernels/xqa/test/warmup.cu (1)

7-10: Warmup loop condition is inverted; the loop exits immediately.

Busy-wait should spin until clock64() reaches tic + cycles.

Apply:

-    while (tic + cycles < clock64())
+    while (clock64() < tic + cycles)
     {
     }
cpp/kernels/xqa/test/test.cpp (1)

569-587: Build break for layout 1 when USE_INPUT_KV==1: wrong page list shape and cache array used.

This block indexes pageList[i][0][kv][...] and writes cacheHeads[...], but with layout 1 we have pageList[batch][page] and separate cacheKHeads/cacheVHeads. This will not compile when USE_INPUT_KV && USE_PAGED_KV_CACHE && PAGED_KV_CACHE_LAYOUT==1.

Apply:

 #if USE_INPUT_KV
@@
-        for (int kv = 0; kv < 2; kv++)
-        {
-            for (int j = 0; j < nbKHeads; j++)
-            {
-#if USE_PAGED_KV_CACHE
-                uint32_t const pageIdx = pageList[i][0][kv][pos / tokensPerPage];
-                uint32_t const idxHead = tokensPerPage * (nbKHeads * pageIdx + j) + pos % tokensPerPage;
-#else
-                uint32_t const idxHead = maxSeqLen * (nbKHeads * i + j) + pos;
-#endif
-                cacheHeads[idxHead].fill(CacheElem(128.F));
-            }
-        }
+        for (int kv = 0; kv < 2; kv++)
+        {
+            for (int j = 0; j < nbKHeads; j++)
+            {
+#if USE_PAGED_KV_CACHE
+#if PAGED_KV_CACHE_LAYOUT == 1
+                uint32_t const pageIdx = pageList[i][pos / tokensPerPage];
+                uint32_t const idxHead = pageIdx * tokensPerPage * nbKHeads
+                                       + (pos % tokensPerPage) * nbKHeads + j;
+                auto& cacheRef = (kv == 0) ? cacheKHeads[idxHead] : cacheVHeads[idxHead];
+                cacheRef.fill(CacheElem(128.F));
+#else
+                uint32_t const pageIdx = pageList[i][0][kv][pos / tokensPerPage];
+                uint32_t const idxHead = tokensPerPage * (nbKHeads * pageIdx + j) + pos % tokensPerPage;
+                cacheHeads[idxHead].fill(CacheElem(128.F));
+#endif
+#else
+                uint32_t const idxHead = maxSeqLen * (nbKHeads * i + j) + pos;
+                cacheHeads[idxHead].fill(CacheElem(128.F));
+#endif
+            }
+        }
🧹 Nitpick comments (1)
cpp/kernels/xqa/mla_sm120.cu (1)

1973-1980: Scratch layout depends on grid size; ensure caller-provided scratch is large enough.

cgaXBuf and partialResults are carved from scratch using nbCgas. Consider asserting or documenting the required scratch size to avoid out-of-bounds access on large grids.
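
A sketch of the suggested guard; the buffer types and per-CGA sizes below are hypothetical stand-ins, not the kernel's actual scratch layout:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for the per-CGA buffers carved out of scratch.
struct CgaXBuf { unsigned char bytes[1024]; };
struct PartialResult { float v[128]; };

// Fail fast if the caller-provided scratch cannot hold what the launch
// carves out for nbCgas clusters.
inline void checkScratchSize(size_t scratchBytes, size_t nbCgas)
{
    size_t const required = nbCgas * (sizeof(CgaXBuf) + sizeof(PartialResult));
    assert(scratchBytes >= required && "scratch too small for this grid size");
}
```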

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6cf1c3f and aafb23e.

📒 Files selected for processing (4)
  • cpp/kernels/xqa/mha.h (1 hunks)
  • cpp/kernels/xqa/mla_sm120.cu (6 hunks)
  • cpp/kernels/xqa/test/test.cpp (6 hunks)
  • cpp/kernels/xqa/test/warmup.cu (1 hunks)
🧰 Additional context used
📓 Path-based instructions (7)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/kernels/xqa/test/warmup.cu
  • cpp/kernels/xqa/test/test.cpp
  • cpp/kernels/xqa/mha.h
  • cpp/kernels/xqa/mla_sm120.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/kernels/xqa/test/warmup.cu
  • cpp/kernels/xqa/test/test.cpp
  • cpp/kernels/xqa/mha.h
  • cpp/kernels/xqa/mla_sm120.cu
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • cpp/kernels/xqa/test/warmup.cu
  • cpp/kernels/xqa/test/test.cpp
  • cpp/kernels/xqa/mha.h
  • cpp/kernels/xqa/mla_sm120.cu
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • cpp/kernels/xqa/test/warmup.cu
  • cpp/kernels/xqa/test/test.cpp
  • cpp/kernels/xqa/mha.h
  • cpp/kernels/xqa/mla_sm120.cu
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/kernels/xqa/test/test.cpp
  • cpp/kernels/xqa/mha.h
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.

Files:

  • cpp/kernels/xqa/mha.h
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/kernels/xqa/mha.h
🧬 Code graph analysis (1)
cpp/kernels/xqa/mla_sm120.cu (1)
cpp/kernels/xqa/tma.h (2)
  • loadAsync (74-129)
  • loadAsync (132-191)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (7)
cpp/kernels/xqa/test/test.cpp (3)

82-97: CUDA 13+ managed prefetch API usage looks correct.

Using cudaMemLocation with cudaMemPrefetchAsync(ptr, size, location, flags=0, stream) is consistent with CUDA 13+. Fallback path retained for older CUDA.

Please confirm our CI toolchains define CUDA_VERSION >= 13000 where this overload exists (e.g., R13+). If some builders use older CUDA headers, guard failures could occur.


235-241: KV page list shape and initialization for layout 1 (VLLM) look consistent.

  • total pages, buffer sizing, and linear init/shuffle align with [batchSize][nbPagesPerSeq] shape.
  • pageListArg pointers match the expected kernel API per layout.

Also applies to: 262-271, 309-336, 339-355
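
As a reference point, a self-contained sketch of the linear-init-then-shuffle pattern for a flat [batchSize][nbPagesPerSeq] page list; the function name and fixed seed are illustrative, not taken from test.cpp:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Build a flat [batchSize][nbPagesPerSeq] page list: assign page indices
// linearly, then shuffle so sequences land on non-contiguous pool pages.
std::vector<uint32_t> makePageList(uint32_t batchSize, uint32_t nbPagesPerSeq)
{
    std::vector<uint32_t> pageList(size_t{batchSize} * nbPagesPerSeq);
    std::iota(pageList.begin(), pageList.end(), 0u);
    std::mt19937 rng{42}; // fixed seed: placeholder for test reproducibility
    std::shuffle(pageList.begin(), pageList.end(), rng);
    return pageList; // entry for request i, page j: pageList[i * nbPagesPerSeq + j]
}
```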


713-717: All launchMLA/launchMHA call sites are consistently updated; no additional action needed.

Verification confirms both call sites (lines 713-717 for launchMLA and lines 759-763 for launchMHA) correctly pass separate K/V cache heads (cacheKHeads.get(), cacheVHeads.get()) under PAGED_KV_CACHE_LAYOUT == 1, falling back to the single pool otherwise. Both match the function signature expectations in mha.h lines 174-176, where layout 1 expects two separate parameters (kCacheVLLM, vCacheVLLM). No inconsistencies found.

cpp/kernels/xqa/mla_sm120.cu (3)

115-119: Correct base offset for layout 1 in KVTilePartLoader.

For VLLM layout, baseOffset = idxReq * maxNbPagesPerSeq matches [batch][page] indexing. Good.
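
For contrast with the layout-0 path, a sketch of both base-offset computations. The layout-1 branch restates the formula above; the layout-0 branch is only a reconstruction from the walkthrough's "(formula with beamWidth, 2)" note and has not been verified against mla_sm120.cu:

```cpp
#include <cstdint>

// Illustrative only: per-request base offset into the page list.
uint32_t pageListBaseOffset(uint32_t idxReq, uint32_t beamWidth, uint32_t maxNbPagesPerSeq)
{
#if PAGED_KV_CACHE_LAYOUT == 1
    // pageList shape: [batch][page]
    return idxReq * maxNbPagesPerSeq;
#else
    // pageList shape: [batch][beam][2 (K/V)][page]; the beam and K/V terms
    // are assumed to be added by the caller and are omitted here.
    return idxReq * beamWidth * 2 * maxNbPagesPerSeq;
#endif
}
```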


1936-1942: Host launchMLA: tensor map creation matches layout 1 vs pool paths.

Separate K/V tensor maps for layout 1 are correctly used; pool fallback unchanged.

Also applies to: 1962-1971


146-151: No issues found—tensor map dimension orders are correctly matched to TMA loads.

The verification confirms that both layout branches correctly map dimension indices to tensor coordinates:

  • Layout 1: {idxElemBeg, idxHeadGrp, offset, pages[i]} → tensor dims {headElems, nbKHeads, tokensPerPage, pageId}
  • Layout 0: {idxElemBeg, offset, idxHeadGrp, pages[i]} → tensor dims {headElems, tokensPerPage, nbKHeads, pageId}

Both single-page (lines 146–151) and multi-page (lines 160–167) cases are consistent with their respective tensor map encodings in tensorMap.cpp.
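
Stated as code, the two dimension orders from the bullets above, innermost dimension first as a tensor-map global-dim array lists them (the example sizes are placeholders):

```cpp
#include <cstdint>

// Illustrative sizes only; what matters is the ordering per layout.
constexpr uint64_t headElems = 576, nbKHeads = 1, tokensPerPage = 64, nbPages = 128;

#if PAGED_KV_CACHE_LAYOUT == 1
constexpr uint64_t globalDims[4] = {headElems, nbKHeads, tokensPerPage, nbPages};
#else
constexpr uint64_t globalDims[4] = {headElems, tokensPerPage, nbKHeads, nbPages};
#endif
```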

cpp/kernels/xqa/mha.h (1)

172-180: All function signatures match between declarations and definitions.

Verified that launchMLA, launchMHA, and launchHopperF8MHA have consistent signatures across all files:

  • launchMLA: mha.h:168 matches mla_sm120.cu:1874
  • launchMHA: mha.h:88 matches mha.cu:2680
  • launchHopperF8MHA: mha.h:128 matches mha_sm90.cu:3225

All three functions correctly use the same conditional structure for PAGED_KV_CACHE_LAYOUT==1 (taking kCacheVLLM and vCacheVLLM parameters) versus Layout 0 (taking pool). The public API extension is properly implemented.

@tensorrt-cicd (Collaborator) commented

PR_Github #22114 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: --disable_fail_fast

@qsang-nv (Collaborator, Author) commented

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented

PR_Github #22115 [ run ] triggered by Bot. Commit: aafb23e

@tensorrt-cicd (Collaborator) commented

PR_Github #22115 [ run ] completed with state SUCCESS. Commit: aafb23e
/LLM/main/L0_MergeRequest_PR pipeline #16676 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

qsang-nv requested a review from lowsfer on October 22, 2025 at 05:58
lowsfer merged commit 07edac2 into NVIDIA:main on Oct 22, 2025
8 of 9 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request Oct 24, 2025
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>