Skip to content

Conversation

@MrGeva
Copy link
Collaborator

@MrGeva MrGeva commented Oct 30, 2025

Root Cause
The trtllm_allreduce() function was creating a new AllReduce module instance on every call. During CUDA graph warmup (which involves multiple forward passes), this caused repeated initialization of MNNVL workspace allocation.
When strategy=AUTO, the AllReduce.init() attempts to initialize MNNVLAllReduce, which calls get_allreduce_mnnvl_workspace(). This function performs CPU synchronization operations (torch.cuda.synchronize() and comm.Barrier()) that are incompatible with CUDA graph capture. The code even includes a comment acknowledging this: "CPU barrier since we assume this should not be called in cuda graph".
Repeated module initialization during warmup → repeated barrier calls → hang.
Solution
Implemented module-level caching for AllReduce instances in trtllm.py. The cache key is (rank, world_size, dtype) to handle different tensor configurations.
Key changes:
Added _allreduce_cache dictionary to store AllReduce modules
Modified trtllm_allreduce() to check cache before creating new instances
Each unique configuration creates module only once, before CUDA graph capture
Subsequent calls during warmup reuse the cached module without re-initialization
With caching, the MNNVL workspace allocation (with its CPU synchronization) happens once before CUDA graph capture begins. During warmup and capture, only the cached module's forward pass executes, which contains no blocking CPU operations.

Why this works
This fix benefits all AllReduce strategies (AUTO, NCCL, MIN_LATENCY, MNNVL) as they all allocate workspace during initialization, though NCCL was less affected as it skips MNNVL initialization entirely.

@coderabbitai summary

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@MrGeva MrGeva requested a review from a team as a code owner October 30, 2025 15:53
@MrGeva MrGeva requested a review from nzmora-nvidia October 30, 2025 15:53
@MrGeva
Copy link
Collaborator Author

MrGeva commented Oct 30, 2025

/bot run

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 30, 2025

📝 Walkthrough

Walkthrough

Introduces module-level caching for AllReduce operators in trtllm_allreduce, replacing per-call creation with lazy initialization and reuse based on configuration key (rank, world_size, dtype), improving CUDA graph compatibility.

Changes

Cohort / File(s) Change Summary
AllReduce Caching Infrastructure
tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
Added module-level _allreduce_cache dictionary for storing AllReduce operator instances. Refactored trtllm_allreduce to lazily construct and cache AllReduce operations per configuration key instead of creating new instances per call. Switched to AllReduceStrategy.AUTO and included tensor dtype in cache key. Behavior of trtllm_allgather remains unchanged.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant trtllm_allreduce
    participant _allreduce_cache
    participant AllReduce

    rect rgb(200, 240, 255)
    note over trtllm_allreduce: Old Behavior (Per-call)
    Caller->>trtllm_allreduce: call with (x, rank, world_size, dtype)
    trtllm_allreduce->>AllReduce: create new instance
    AllReduce-->>trtllm_allreduce: instance
    trtllm_allreduce->>AllReduce: execute
    AllReduce-->>trtllm_allreduce: result
    trtllm_allreduce-->>Caller: return result
    end

    rect rgb(200, 255, 220)
    note over trtllm_allreduce: New Behavior (Cached)
    Caller->>trtllm_allreduce: call with (x, rank, world_size, dtype)
    trtllm_allreduce->>_allreduce_cache: lookup key=(rank, world_size, dtype)
    alt Cache Hit
        _allreduce_cache-->>trtllm_allreduce: cached instance
    else Cache Miss
        trtllm_allreduce->>AllReduce: create new instance
        AllReduce-->>trtllm_allreduce: instance
        trtllm_allreduce->>_allreduce_cache: store instance
    end
    trtllm_allreduce->>AllReduce: execute
    AllReduce-->>trtllm_allreduce: result
    trtllm_allreduce-->>Caller: return result
    end
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Areas requiring attention:
    • Cache key construction and correctness—verify that (rank, world_size, dtype) tuple uniquely identifies configuration
    • Thread safety of cache access, if applicable in multi-threaded contexts
    • Cache initialization and potential memory implications with persistent AllReduce instances
    • Verification that AllReduceStrategy.AUTO is the appropriate replacement for prior strategy choice

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description provides excellent and detailed explanation of the root cause, the specific problem with CUDA graph compatibility, and a comprehensive solution with key changes. However, the description does not follow the required template structure: the detailed content is placed before the template rather than within the designated "Description" section, and more importantly, the "Test Coverage" section is completely empty with no tests listed. While the Root Cause and Solution sections are thorough and the PR Checklist is marked complete, an explicit required template section is missing, which represents a structural incompleteness of the description. The Test Coverage section must be filled out with relevant tests that safeguard the AllReduce caching changes. Additionally, the description content should be moved into the proper template sections (Description and Test Coverage) rather than placed outside the template structure. The author should document which test suites validate that the caching mechanism works correctly and doesn't introduce regressions in AllReduce functionality.
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The PR title "[#8781][fix] Cache the AllReduce wrapper to avoid re-allocating workspace which caused a hang" directly and specifically describes the main change in the pull request. It follows the repository's required format with the GitHub issue number, type indicator, and a clear summary. The title accurately reflects the core objective: implementing caching for AllReduce instances to fix a hang caused by repeated workspace re-allocation during CUDA graph warmup. The phrasing is concise and sufficiently specific that developers scanning the history would immediately understand the primary change without ambiguity.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9112cff and f85e5ca.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
🧠 Learnings (8)
📓 Common learnings
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks `is_adapter_in_cpu_cache()` and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: tests/unittest/_torch/multi_gpu/test_nccl_device.py:138-149
Timestamp: 2025-10-13T19:45:03.518Z
Learning: In test_nccl_device.py, the NCCL device AllReduce implementation compares the entire residual tensor on each rank, unlike the UB implementation which compares per-rank chunks. The residual chunking calculations in the test are intentionally overridden to reflect this design difference.
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
📚 Learning: 2025-08-14T06:36:40.701Z
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device allreduce implementation (cpp/tensorrt_llm/thop/allreduceOp.cpp), the goto pattern in runNCCLAllReduceDeviceFusion is intentionally used for future extensibility, allowing multiple switch cases to fallback to the default handler. While not aesthetically ideal, this pattern supports adding more fusion cases later that can reuse the same fallback logic.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
📚 Learning: 2025-10-13T19:45:03.518Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: tests/unittest/_torch/multi_gpu/test_nccl_device.py:138-149
Timestamp: 2025-10-13T19:45:03.518Z
Learning: In test_nccl_device.py, the NCCL device AllReduce implementation compares the entire residual tensor on each rank, unlike the UB implementation which compares per-rank chunks. The residual chunking calculations in the test are intentionally overridden to reflect this design difference.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
📚 Learning: 2025-09-24T03:31:28.908Z
Learnt from: tongyuantongyu
PR: NVIDIA/TensorRT-LLM#7520
File: tensorrt_llm/_torch/pyexecutor/resource_manager.py:605-613
Timestamp: 2025-09-24T03:31:28.908Z
Learning: In TensorRT-LLM Ray orchestrator mode, ProcessGroups are initialized with both Gloo and NCCL backends (e.g., "cuda:nccl,cpu:gloo"), allowing PyTorch distributed to automatically route CPU tensors through Gloo and GPU tensors through NCCL. This eliminates the need for manual device placement when performing allreduce operations on base types.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
📚 Learning: 2025-09-23T14:58:05.372Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:42-49
Timestamp: 2025-09-23T14:58:05.372Z
Learning: In TensorRT-LLM NCCL device kernels (cpp/tensorrt_llm/kernels/nccl_device/), the token partitioning intentionally uses ceil-like distribution (same token_per_rank for all ranks) to ensure all ranks launch the same number of blocks. This is required for optimal NCCL device API barrier performance, even though it may launch extra blocks for non-existent tokens on later ranks. Runtime bounds checking in the kernel (blockID validation) handles the overshoot cases.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
📚 Learning: 2025-09-16T09:30:09.716Z
Learnt from: tongyuantongyu
PR: NVIDIA/TensorRT-LLM#7763
File: cpp/tensorrt_llm/CMakeLists.txt:297-301
Timestamp: 2025-09-16T09:30:09.716Z
Learning: In the TensorRT-LLM project, NCCL libraries are loaded earlier by PyTorch libraries or the bindings library, so the main shared library doesn't need NCCL paths in its RPATH - the libraries will already be available in the process address space when needed.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py (3)
tensorrt_llm/mapping.py (1)
  • Mapping (336-493)
tensorrt_llm/_torch/distributed/ops.py (1)
  • AllReduce (455-617)
tensorrt_llm/functional.py (1)
  • AllReduceStrategy (3876-3885)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Check PR Checklist Resolution
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py (3)

11-14: Well-documented cache declaration.

The module-level cache with clear explanatory comments effectively communicates the purpose and criticality for CUDA graph compatibility. The naming convention and placement within the import try-block are appropriate.


1-1: Verify copyright header includes current year.

Ensure the NVIDIA Apache-2.0 copyright header at the top of this file includes 2025 as per coding guidelines.

As per coding guidelines


21-35: Caching logic is correct and verified as safe.

The cache key (rank, world_size, tensor.dtype) appropriately captures the configuration needed for AllReduce initialization. Verification confirms that AllReduce.forward() is stateless with respect to input tensors—it does not modify any instance attributes or accumulate state across calls. The method safely accepts different tensor inputs and all_reduce_params on each invocation, making instance reuse appropriate for avoiding CUDA graph hangs during capture.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23072 [ run ] triggered by Bot. Commit: f85e5ca

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23072 [ run ] completed with state SUCCESS. Commit: f85e5ca
/LLM/main/L0_MergeRequest_PR pipeline #17398 completed with status: 'FAILURE'

@suyoggupta
Copy link
Collaborator

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23225 [ run ] triggered by Bot. Commit: c7ab69c

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23225 [ run ] completed with state FAILURE. Commit: c7ab69c
/LLM/main/L0_MergeRequest_PR pipeline #17506 completed with status: 'FAILURE'

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23286 [ run ] triggered by Bot. Commit: c7ab69c

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23286 [ run ] completed with state FAILURE. Commit: c7ab69c
/LLM/main/L0_MergeRequest_PR pipeline #17547 completed with status: 'FAILURE'

@MrGeva
Copy link
Collaborator Author

MrGeva commented Nov 2, 2025

/bot run

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
@MrGeva MrGeva force-pushed the egeva/cache_allreduce branch from c7ab69c to 9d58924 Compare November 2, 2025 10:11
@MrGeva
Copy link
Collaborator Author

MrGeva commented Nov 2, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23287 [ run ] triggered by Bot. Commit: 9d58924

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23287 [ run ] completed with state SUCCESS. Commit: 9d58924
/LLM/main/L0_MergeRequest_PR pipeline #17548 completed with status: 'SUCCESS'

@MrGeva MrGeva merged commit f877823 into NVIDIA:main Nov 2, 2025
5 checks passed
fredricz-20070104 pushed a commit to fredricz-20070104/TensorRT-LLM that referenced this pull request Nov 5, 2025
… workspace which caused a hang (NVIDIA#8803)

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Signed-off-by: FredricZ-2007 <226039983+fredricz-20070104@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants