
Conversation

@suyoggupta
Collaborator

@suyoggupta suyoggupta commented Oct 19, 2025

  • Enable chunked prefill for mamba layers
  • Refactor the custom ops: move all the mamba ops to the mamba subdirectory
  • Enable the accuracy test for Nemotron-H with chunked prefill enabled

#8272

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
@suyoggupta suyoggupta changed the title Sg/ssm chunked prefill to [None][feat] Enable chunked prefill for SSMs in AutoDeploy Oct 19, 2025
@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #21789 [ run ] triggered by Bot. Commit: 4de161f

@suyoggupta suyoggupta marked this pull request as ready for review October 19, 2025 08:28
@suyoggupta suyoggupta requested a review from a team as a code owner October 19, 2025 08:28
@suyoggupta suyoggupta requested a review from Fridah-nv October 19, 2025 08:28
@suyoggupta suyoggupta changed the title [None][feat] Enable chunked prefill for SSMs in AutoDeploy to [#8272][feat] Enable chunked prefill for SSMs in AutoDeploy Oct 19, 2025
@coderabbitai
Contributor

coderabbitai bot commented Oct 19, 2025

📝 Walkthrough

This PR refactors Mamba-related custom operations across multiple backends (Torch, CUDA, Triton), introducing a new use_initial_states flag and renaming internal functions and operators. The prepare_metadata return signature expands from a 3-tuple to a 4-tuple. The custom_ops package initialization is extended to recursively discover and import nested modules. Related test files and model patches are updated accordingly.
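
For orientation, here is a minimal, self-contained sketch of the metadata expansion described above. The function name, argument list, and shapes are assumptions made for illustration, not the actual auto_deploy signatures.

import torch
from typing import Tuple


def prepare_metadata_sketch(
    input_pos: torch.Tensor,  # [num_seq] starting token position of each sequence
    seq_len: torch.Tensor,    # [num_seq] number of tokens each sequence contributes
    slot_idx: torch.Tensor,   # [num_seq] cache slot assigned to each sequence
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Previously a 3-tuple; the refactor adds a per-sequence use_initial_states flag."""
    seq_start = torch.cumsum(seq_len, dim=0) - seq_len  # start offset of each sequence
    use_initial_states = input_pos > 0  # nonzero start position => resume from cached state
    return seq_len, seq_start, slot_idx, use_initial_states


# Two sequences; the second resumes mid-prompt, as in chunked prefill.
lens, starts, slots, flags = prepare_metadata_sketch(
    torch.tensor([0, 64]), torch.tensor([16, 16]), torch.tensor([0, 1])
)
print(flags)  # tensor([False,  True])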

Changes

Custom ops package initialization
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
Adds recursive module discovery to walk and import all subpackages and nested modules under the package path, ensuring registration side-effects are applied.
Mamba causal convolution backends
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
CUDA backend extends prepare_metadata to return 4-tuple with use_initial_states flag and threads this through cached conv operations; Torch backend adjusts import paths and adds minor comment.
Torch Mamba operations
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
Renames internal functions (_torch_cached_ssm_transform → _torch_cached_ssm, _torch_ssm_transform_prefill → _torch_ssm_prefill, _torch_ssm_transform_decode → _torch_ssm_decode); updates prepare_metadata to return 4-tuple with use_initial_states; updates public op bindings and registration decorators.
Triton Mamba operations
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
Renames _triton_cached_ssm_transform → _triton_cached_ssm; adds use_initial_states parameter; updates deferred initialization logic to compute initial_states and chunk indices based on flag; updates public op dispatch targets and prepare_metadata arity.
Model patches
tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
Updates torch dispatch calls to renamed operators (torch_cached_ssm_transform → torch_cached_ssm, torch_ssm_transform → torch_ssm).
KV cache library
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
Removes blank line before return statement; whitespace-only change.
Integration tests
tests/integration/defs/accuracy/test_llm_api_autodeploy.py
Changes compile_model backend from "torch-opt" to "torch-cudagraph" in two tests; comments out skip logic for enable_chunked_prefill in TestNemotronH.
CUDA causal conv unit tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
Adds use_initial_states boolean parameter to cached conv operator call; updates prepare_metadata test to expect and unpack 4-tuple return value.
Torch causal conv unit tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
Updates import paths to use mamba namespace for causal conv functions.
Torch Mamba unit tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
Replaces torch_cached_ssm_transform with torch_cached_ssm; updates prepare_metadata unpacking to reflect 4-tuple return.
Triton Mamba unit tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
Updates public API calls to renamed operators; adds use_initial_states parameter to triton_cached_ssm call.

Sequence Diagram(s)

sequenceDiagram
    participant Test
    participant PrepareMetadata
    participant CachedOp
    participant Compute
    
    Test->>PrepareMetadata: Call prepare_metadata(input_pos, ...)
    rect rgb(220, 240, 220)
    Note over PrepareMetadata: Compute use_initial_states from<br/>input_pos > 0
    end
    PrepareMetadata-->>Test: (seq_len, seq_start, slot_idx, use_initial_states)
    Test->>CachedOp: Call cached_op(..., use_initial_states)
    
    alt use_initial_states == True
        CachedOp->>Compute: Use initial states from cache
    else
        CachedOp->>Compute: Use runtime-computed states
    end
    
    Compute-->>CachedOp: result
    CachedOp-->>Test: output
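A small runnable toy of the flow in the diagram, showing only how the fourth metadata output gates the initial-state read; every name below is illustrative and not the registered auto_deploy op.

import torch


def cached_op_toy(x, slot_idx, use_initial_states, state_cache):
    # Per sequence: resume from the cached state, or start from zeros for a fresh prefill.
    resumed = state_cache[slot_idx]  # [num_seq, state_dim]
    fresh = torch.zeros_like(resumed)
    init = torch.where(use_initial_states[:, None], resumed, fresh)
    return x + init  # stand-in for the real SSM/conv computation


state_cache = torch.randn(4, 8)  # 4 cache slots, toy state size 8
out = cached_op_toy(
    torch.zeros(2, 8),
    slot_idx=torch.tensor([0, 1]),
    use_initial_states=torch.tensor([False, True]),
    state_cache=state_cache,
)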

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~35 minutes

The changes involve a well-defined, consistent pattern across multiple files: function/operator renaming, parameter threading for use_initial_states, and metadata signature expansion. While the scope is broad (11 files), the heterogeneity is manageable—each change follows the same refactoring logic. Key areas requiring focused review are the logic density in the CUDA and Triton backends where the new flag determines conditional behavior, and the prepare_metadata implementations where the flag is derived and propagated.

Possibly related PRs

Suggested labels

AutoDeploy

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Description Check (⚠️ Warning): The pull request description provided by the author consists entirely of the repository template with all required sections left as empty placeholders. The Description section, Test Coverage section, and PR Checklist items contain no substantive content, only template comments and checkbox markers. No explanation of the issue, solution, affected tests, or verification steps has been provided, making the description essentially empty and off-topic to the actual changes in the PR. Resolution: fill in the PR description with substantive content. Specifically, provide a brief explanation of the issue being addressed and the solution implemented; list the relevant test cases that validate the changes (based on the raw summary, tests include test_cuda_causal_conv_cached_op.py, test_torch_mamba_cached_op.py, test_triton_mamba_cached_op.py, and test_llm_api_autodeploy.py); and complete the PR checklist by reviewing and marking the appropriate items. The CodeRabbit AI summary hint in the description indicates the author may want to use the "@coderabbitai summary" feature to generate a description if preferred.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 15.79%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
  • Title Check (✅ Passed): The pull request title "[None][feat] Enable chunked prefill for SSMs in AutoDeploy" follows the required format and directly corresponds to the main changes in the changeset. The changes across multiple files systematically introduce a use_initial_states parameter to SSM (State Space Model) implementations in AutoDeploy, specifically to support chunked prefill functionality. The title is concise, specific, and clearly communicates the primary objective of the PR without vagueness or generic language.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (1)

125-159: Add use_initial_states parameter to torch_cached_ssm operator.

Verification confirms the operators have asymmetric signatures by design. The triton_cached_ssm operator includes use_initial_states as a required parameter, while torch_cached_ssm lacks it entirely. The test correctly reflects this disparity (lines 127–141 omit it for torch; lines 144–159 include it for triton), but the underlying inconsistency must be resolved.

Update both _torch_cached_ssm and _torch_cached_ssm_fake in tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py to accept use_initial_states: torch.Tensor as a parameter (positioned after slot_idx, before ssm_state_cache, consistent with the triton signature).
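
For concreteness, a hedged sketch of what the aligned torch-side signature could look like; parameter names, ordering, and types are assumptions based on the comment above, not the repository's actual op definition.

import torch


def _torch_cached_ssm_sketch(
    hidden_states: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,  # new tensor argument, placed after slot_idx as suggested
    ssm_state_cache: torch.Tensor,
) -> torch.Tensor:
    # The fake/meta variant would mirror this signature so tracing stays consistent.
    raise NotImplementedError("illustrative signature only")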

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Per guidelines, prepend the 2025 NVIDIA Apache-2.0 header to all source files.

Apply at file top.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Add the required copyright header.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Add the required header.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Add the header at the top of the file.

🧹 Nitpick comments (9)
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)

12-16: Potential duplicate imports between top-level and recursive discovery.

The initial loop (lines 8-10) imports all top-level modules in __path__, then the recursive walk (lines 14-16) walks all packages including the same top-level modules. This results in duplicate imports and duplicate entries in __all__.

Consider simplifying to use only the recursive walk, which will discover both top-level and nested modules:

 __all__ = []
 
-for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
-    __all__.append(module_name)
-    importlib.import_module(f"{__name__}.{module_name}")
-
 # Recursively import subpackages and modules so their side-effects (e.g.,
 # op registrations) are applied even when nested in subdirectories.
 for _, full_name, _ in pkgutil.walk_packages(__path__, prefix=f"{__name__}."):
-    __all__.append(full_name)
+    # Extract relative module name for __all__
+    relative_name = full_name[len(__name__) + 1 :]
+    if relative_name not in __all__:
+        __all__.append(relative_name)
     importlib.import_module(full_name)

Alternatively, keep both loops but avoid duplicates in __all__ by checking membership before appending.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (1)

198-201: Unpacked use_initial_states is never validated.

The prepare_metadata operation now returns 4 values including use_initial_states, but the test doesn't validate this value. Either add an assertion to verify its value or prefix with underscore to indicate it's intentionally unused.

Apply this diff to indicate the variable is intentionally unused:

-    seq_len_s, seq_start, slot_s, use_initial_states = out
+    seq_len_s, seq_start, slot_s, _ = out

Or add validation:

     seq_len_s, seq_start, slot_s, use_initial_states = out
     assert seq_len_s.numel() == 2 and slot_s.numel() == 2
     assert torch.all(seq_start == torch.tensor([0, 2], device=device, dtype=seq_start.dtype))
+    assert use_initial_states.shape == torch.Size([2])
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (2)

70-83: Numerics/dtype: prefill computes in fp32 and returns fp32.

Downstream cached paths cast outputs back to input dtype; this uncached op does not. Either cast y to hidden_states.dtype before returning, or clearly document that this op’s contract is fp32.

Example:

-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor]:
@@
-    return y, ssm_state
+    return y.to(hidden_states.dtype), ssm_state.to(hidden_states.dtype)

183-194: Meta kernel shape/dtype LGTM; small nit.

Meta returns float32 like the kernel. Consider preserving hidden_states.dtype if you decide to align runtime output dtype.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (2)

77-80: prepare_metadata now returns 4 values; update docstring and callers.

Docstring still says 3‑tuple. Update to mention (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).

-    """Prepare metadata for cached causal conv (CUDA backend).
-
-    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
-    """
+    """Prepare metadata for cached causal conv (CUDA backend).
+
+    Returns (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).
+    """

221-221: Ruff ARG001: unused use_initial_states in fake.

Silence by prefixing with underscore or referencing it.

-def _cuda_cached_causal_conv1d_fake(
+def _cuda_cached_causal_conv1d_fake(
@@
-    use_initial_states: torch.Tensor,  # [num_seq]
+    use_initial_states: torch.Tensor,  # [num_seq]
@@
-):
+):
+    _ = use_initial_states  # silence ARG001 in fake
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (2)

100-110: Initial state gating and chunk index computation.

Logic is sound. Minor: using scalar 0 in torch.where is fine; consider zeros_like for dtype/device explicitness.

-            initial_states = torch.where(
+            zeros = torch.zeros_like(ssm_state_cache[slot_idx[:num_prefill]])
+            initial_states = torch.where(
                 use_initial_states[:num_prefill, None, None, None],
-                ssm_state_cache[slot_idx[:num_prefill]],
-                0,
+                ssm_state_cache[slot_idx[:num_prefill]],
+                zeros,
             )

173-200: Fake path includes use_initial_states; silence ARG001 if linted.

-):
+):
+    _ = use_initial_states
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)

114-144: prepare_metadata docstring still says 3‑tuple; now returns 4.

Update the docstring to reflect (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).

-    """Prepare metadata for cached SSM transform.
-
-    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
-    """
+    """Prepare metadata for cached SSM transform.
+
+    Returns (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).
+    """
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8523168 and 4de161f.

📒 Files selected for processing (13)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (8 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (8 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (0 hunks)
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (5 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (5 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
🧬 Code graph analysis (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_op_args (407-444)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (2)
  • _torch_causal_conv1d_decode (83-133)
  • _torch_causal_conv1d_prefill (49-80)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_op_args (407-444)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
  • input_pos (297-298)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)
  • _torch_cached_ssm_decode (31-102)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)
  • _torch_ssm_prefill (70-162)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (4)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_op_args (407-444)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)
  • _torch_ssm_prefill (70-162)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3)
  • input_pos (297-298)
  • MHACallable (692-696)
  • PrepareMetadataCallable (699-709)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (2)
  • get_cached_attention_op (227-228)
  • get_prepare_metadata_op (231-233)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (3)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
  • cu_seqlens_to_chunk_indices_offsets (24-85)
tensorrt_llm/_torch/modules/mamba/ssd_combined.py (1)
  • mamba_chunk_scan_combined (183-252)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_op_args (407-444)
🪛 Ruff (0.14.0)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py

87-87: Unused noqa directive (non-enabled: E501)

Remove unused noqa directive

(RUF100)


150-150: Unused noqa directive (non-enabled: E501)

Remove unused noqa directive

(RUF100)

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py

199-199: Unpacked variable use_initial_states is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

221-221: Unused function argument: use_initial_states

(ARG001)

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py

98-98: Unused noqa directive (non-enabled: E501)

Remove unused noqa directive

(RUF100)


170-170: Unused noqa directive (non-enabled: E501)

Remove unused noqa directive

(RUF100)

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

187-187: Unused function argument: use_initial_states

(ARG001)

🔇 Additional comments (33)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1)

19-20: LGTM! Import paths updated correctly.

The relative import paths have been adjusted to reflect the new module structure with the mamba subpackage.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (2)

87-90: LGTM! Import path updated to reflect new mamba subpackage.

The reference to _torch_causal_conv1d_decode has been correctly updated to the new module path under mamba.torch_backend_causal_conv.

Note: The # noqa: E501 directive on line 87 can be removed as it's unused per static analysis.


150-153: LGTM! Import path updated to reflect new mamba subpackage.

The reference to _torch_causal_conv1d_prefill has been correctly updated to the new module path under mamba.torch_backend_causal_conv.

Note: The # noqa: E501 directive on line 150 can be removed as it's unused per static analysis.

tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py (2)

103-121: LGTM! Operator renamed from torch_cached_ssm_transform to torch_cached_ssm.

The cached SSM operator call has been updated to use the new naming convention. The arguments and logic remain unchanged.


123-133: LGTM! Operator renamed from torch_ssm_transform to torch_ssm.

The uncached SSM operator call has been updated to use the new naming convention. The arguments and logic remain unchanged.

tests/integration/defs/accuracy/test_llm_api_autodeploy.py (2)

121-121: Backend changed from "torch-opt" to "torch-cudagraph" for NemotronH.

The compile backend has been updated for the Nemotron-H SSM model. Verify that this change is intentional and that "torch-cudagraph" is the correct backend for SSM models with chunked prefill enabled.

Note: TestLlama3_1_8B at line 61 still uses "torch-opt" - consider whether this inconsistency is intentional or if both should use the same backend.


143-145: Chunked prefill skip removed for SSM model testing.

The skip for chunked prefill has been commented out, enabling this test to run with enable_chunked_prefill=True. This aligns with the PR objective to enable chunked prefill for SSMs.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (5)

73-91: LGTM! Operator renamed from torch_cached_ssm_transform to torch_cached_ssm.

The cached SSM operator call has been updated to use the new naming convention in the generate-only test path.


98-100: LGTM! Helper function renamed to _torch_cached_ssm_decode.

The decode helper reference has been updated to match the new naming convention.

Note: The # noqa: E501 directive on line 98 can be removed as it's unused per static analysis.


139-157: LGTM! Operator renamed from torch_cached_ssm_transform to torch_cached_ssm.

The cached SSM operator call has been updated to use the new naming convention in the context flattened test path.


170-172: LGTM! Helper function renamed to _torch_ssm_prefill.

The prefill helper reference has been updated to match the new naming convention under the mamba subpackage.

Note: The # noqa: E501 directive on line 170 can be removed as it's unused per static analysis.


203-206: LGTM! Metadata preparation updated to return 4-tuple with use_initial_states.

The test correctly unpacks the expanded prepare_metadata return value that now includes use_initial_states as the 4th element, and validates the expected number of outputs.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (2)

56-70: LGTM! Operator renamed from torch_cached_ssm_transform to torch_cached_ssm.

The torch reference operator call has been updated to use the new naming convention.


73-87: LGTM! Operator renamed from triton_cached_ssm_transform to triton_cached_ssm.

The Triton operator call has been updated to use the new naming convention in the generate-only test.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (1)

62-85: LGTM! use_initial_states parameter added to CUDA cached causal conv path.

The new use_initial_states parameter has been correctly added to the operator call, aligning with the expanded signature for chunked prefill support.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)

165-179: Op name alignment OK, but ensure consistency across backends.

This registers auto_deploy::torch_ssm. Verify all descriptors refer to this (see TorchBackendSSM below for a mismatch).

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (5)

21-23: Import path hygiene LGTM.


92-93: Fake path output shape updated — good.

Shapes/dtypes match runtime path (bool).


106-107: New arg use_initial_states — verify schema and interface.

Arg is correctly threaded and used only in prefill. Ensure AttentionDescriptor.get_prepare_metadata_op arity is updated (it is below).


167-168: Correctly maps use_initial_states → has_initial_state.

Good propagation into causal_conv1d_fn.


263-265: prepare_metadata arity advertised as 4 — consistent with implementation.

Good.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (5)

8-8: New helper import (cu_seqlens_to_chunk_indices_offsets) — OK.


27-47: Custom op rename to auto_deploy::triton_cached_ssm looks good.

Signature includes use_initial_states; matches descriptor wiring below.


120-124: Plumbing chunk_indices/offsets into mamba_chunk_scan_combined — LGTM.


224-229: Source/cached op bindings correct for Triton.

Returns torch.ops.auto_deploy.torch_ssm and triton_cached_ssm — consistent with registrations.


232-234: Descriptor metadata arity updated to 4 — good.

Matches torch_ssm_prepare_metadata outputs.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (7)

15-27: Import path cleanups — OK.


28-28: Directly using _torch_ssm_prefill from torch_mamba.py — OK.

Keep internal helper separation clear in docs/comments.


31-43: Decode path math LGTM; compute is vectorized and caches updated in-place.

Small nit: consider .to(dtype) consistency for updated_ssm_state if cache dtype differs; you already cast on scatter.


139-144: use_initial_states derivation looks correct.

input_pos > 0 matches “use cached initial states for nonzero starting positions”.
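
A tiny runnable illustration of that derivation (values made up):

import torch

input_pos = torch.tensor([0, 7, 128, 0])  # starting position of each sequence in the batch
use_initial_states = input_pos > 0        # position 0 = fresh prefill; nonzero = resume from cached state
print(use_initial_states)                 # tensor([False,  True,  True, False])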


161-180: Cached op (torch_cached_ssm) — prefill/decode split and cache updates LGTM.

Prefill uses _torch_ssm_prefill; outputs are cast back to input dtype before scatter — good.


319-322: Descriptor metadata arity updated to 4 — consistent.

Good alignment with torch_ssm_prepare_metadata outputs.


358-364: Constant extraction via extract_op_args — OK.

Matches op schema (time_step_limit, chunk_size).

@tensorrt-cicd
Collaborator

PR_Github #21789 [ run ] completed with state SUCCESS. Commit: 4de161f
/LLM/main/L0_MergeRequest_PR pipeline #16426 completed with status: 'FAILURE'

@suyoggupta
Collaborator Author

/bot run

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #21812 [ run ] triggered by Bot. Commit: 7708f80

@tensorrt-cicd
Collaborator

PR_Github #21812 [ run ] completed with state SUCCESS. Commit: 7708f80
/LLM/main/L0_MergeRequest_PR pipeline #16441 completed with status: 'FAILURE'

@tensorrt-cicd
Collaborator

PR_Github #21884 [ run ] triggered by Bot. Commit: 9f535df

@tensorrt-cicd
Collaborator

PR_Github #21884 [ run ] completed with state SUCCESS. Commit: 9f535df
/LLM/main/L0_MergeRequest_PR pipeline #16498 completed with status: 'FAILURE'

Member

@lucaslie lucaslie left a comment

looks great. Just two minor suggestions

@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #21932 [ run ] triggered by Bot. Commit: d323315

@tensorrt-cicd
Collaborator

PR_Github #21931 [ run ] triggered by Bot. Commit: d323315

@tensorrt-cicd
Collaborator

PR_Github #21932 [ run ] completed with state ABORTED. Commit: d323315

@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #21935 [ run ] triggered by Bot. Commit: d323315

@tensorrt-cicd
Collaborator

PR_Github #21931 [ run ] completed with state ABORTED. Commit: d323315
LLM/main/L0_MergeRequest_PR #16533 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd
Collaborator

PR_Github #21935 [ run ] completed with state SUCCESS. Commit: d323315
/LLM/main/L0_MergeRequest_PR pipeline #16536 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@suyoggupta
Collaborator Author

/bot --reuse-pipeline

@github-actions

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
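
For example, a typical invocation combining options documented above (illustrative only): /bot run --disable-fail-fast --test-backend "pytorch"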

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@suyoggupta
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Collaborator

PR_Github #21947 [ reuse-pipeline ] triggered by Bot. Commit: 0b8d35a

@tensorrt-cicd
Collaborator

PR_Github #21947 [ reuse-pipeline ] completed with state SUCCESS. Commit: 0b8d35a
Reusing PR_Github #21935 for commit 0b8d35a

@suyoggupta suyoggupta merged commit 7050b1e into NVIDIA:main Oct 20, 2025
5 checks passed
@github-project-automation github-project-automation bot moved this from Backlog to Done in AutoDeploy Board Oct 20, 2025
govind-ramnarayan pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Oct 21, 2025
…IDIA#8477)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request Oct 24, 2025
…IDIA#8477)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
…IDIA#8477)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
…IDIA#8477)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>