[#8272][feat] Enable chunked prefill for SSMs in AutoDeploy #8477
Conversation
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
/bot run
PR_Github #21789 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This PR refactors Mamba-related custom operations across multiple backends (Torch, CUDA, Triton), introducing a new `use_initial_states` signal that is computed in `prepare_metadata` and threaded into the cached ops.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test
    participant PrepareMetadata
    participant CachedOp
    participant Compute
    Test->>PrepareMetadata: Call prepare_metadata(input_pos, ...)
    rect rgb(220, 240, 220)
        Note over PrepareMetadata: Compute use_initial_states from<br/>input_pos > 0
    end
    PrepareMetadata-->>Test: (seq_len, seq_start, slot_idx, use_initial_states)
    Test->>CachedOp: Call cached_op(..., use_initial_states)
    alt use_initial_states == True
        CachedOp->>Compute: Use initial states from cache
    else
        CachedOp->>Compute: Use runtime-computed states
    end
    Compute-->>CachedOp: result
    CachedOp-->>Test: output
```
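The flow above can be summarized with a small, self-contained sketch. This is a hypothetical illustration only: `sketch_prepare_metadata` and its argument layout are invented for clarity and the real op in the repository has a different signature; the one detail taken from the review is that `use_initial_states` is derived as `input_pos > 0`.

```python
import torch


def sketch_prepare_metadata(input_pos, seq_len, slot_idx, num_seq):
    """Toy stand-in for the real prepare_metadata op (names are illustrative)."""
    seq_len_sanitized = seq_len[:num_seq]
    seq_start = torch.zeros_like(seq_len_sanitized)
    seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], dim=0)
    slot_idx_sanitized = slot_idx[:num_seq]
    # A sequence resumes from a cached SSM state iff it does not start at position 0.
    use_initial_states = input_pos[:num_seq] > 0
    return seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states


# Example: two prefill chunks, the second one resuming mid-sequence.
out = sketch_prepare_metadata(
    input_pos=torch.tensor([0, 128]),
    seq_len=torch.tensor([64, 32]),
    slot_idx=torch.tensor([3, 7]),
    num_seq=2,
)
print(out[3])  # tensor([False,  True])
```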
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~35 minutes
The changes involve a well-defined, consistent pattern across multiple files: function/operator renaming and parameter threading for `use_initial_states`.

Possibly related PRs
Suggested labels
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (1)
125-159: Add `use_initial_states` parameter to the `torch_cached_ssm` operator.
Verification confirms the operators have asymmetric signatures by design. The `triton_cached_ssm` operator includes `use_initial_states` as a required parameter, while `torch_cached_ssm` lacks it entirely. The test correctly reflects this disparity (lines 127–141 omit it for torch; lines 144–159 include it for triton), but the underlying inconsistency must be resolved.
Update both `_torch_cached_ssm` and `_torch_cached_ssm_fake` in `tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py` to accept `use_initial_states: torch.Tensor` as a parameter (positioned after `slot_idx`, before `ssm_state_cache`, consistent with the triton signature).
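A minimal, hypothetical sketch of the suggested placement is shown below. The surrounding parameters and the op namespace are invented for illustration (the real op takes many more SSM inputs); only the relative position of `use_initial_states` follows the review suggestion.

```python
import torch


@torch.library.custom_op("example::cached_ssm_sketch", mutates_args=("ssm_state_cache",))
def cached_ssm_sketch(
    hidden_states: torch.Tensor,
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,  # new arg: after slot_idx, before ssm_state_cache
    ssm_state_cache: torch.Tensor,
) -> torch.Tensor:
    # Placeholder body; the real kernel would gate on use_initial_states.
    return hidden_states.clone()


@cached_ssm_sketch.register_fake
def _(hidden_states, seq_len, seq_start, slot_idx, use_initial_states, ssm_state_cache):
    # The fake must mirror the same signature so tracing sees the new argument.
    return torch.empty_like(hidden_states)
```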
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)
1-1: Missing NVIDIA Apache-2.0 header (2025). Per guidelines, prepend the 2025 NVIDIA Apache-2.0 header to all source files.
Apply at file top.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)
1-1: Missing NVIDIA Apache-2.0 header (2025). Add the required copyright header.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)
1-1: Missing NVIDIA Apache-2.0 header (2025). Add the required header.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)
1-1: Missing NVIDIA Apache-2.0 header (2025). Add the header at the top of the file.
🧹 Nitpick comments (9)
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)
12-16: Potential duplicate imports between top-level and recursive discovery.
The initial loop (lines 8-10) imports all top-level modules in `__path__`, then the recursive walk (lines 14-16) walks all packages, including the same top-level modules. This results in duplicate imports and duplicate entries in `__all__`.
Consider simplifying to use only the recursive walk, which will discover both top-level and nested modules:

```diff
 __all__ = []
-for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
-    __all__.append(module_name)
-    importlib.import_module(f"{__name__}.{module_name}")
-
 # Recursively import subpackages and modules so their side-effects (e.g.,
 # op registrations) are applied even when nested in subdirectories.
 for _, full_name, _ in pkgutil.walk_packages(__path__, prefix=f"{__name__}."):
-    __all__.append(full_name)
+    # Extract relative module name for __all__
+    relative_name = full_name.split(".", 1)[1] if "." in full_name else full_name
+    if relative_name not in __all__:
+        __all__.append(relative_name)
     importlib.import_module(full_name)
```

Alternatively, keep both loops but avoid duplicates in `__all__` by checking membership before appending.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (1)
198-201: Unpacked `use_initial_states` is never validated.
The `prepare_metadata` operation now returns 4 values including `use_initial_states`, but the test doesn't validate this value. Either add an assertion to verify its value or prefix it with an underscore to indicate it's intentionally unused.
Apply this diff to indicate the variable is intentionally unused:

```diff
-    seq_len_s, seq_start, slot_s, use_initial_states = out
+    seq_len_s, seq_start, slot_s, _ = out
```

Or add validation:

```diff
     seq_len_s, seq_start, slot_s, use_initial_states = out
     assert seq_len_s.numel() == 2 and slot_s.numel() == 2
     assert torch.all(seq_start == torch.tensor([0, 2], device=device, dtype=seq_start.dtype))
+    assert use_initial_states.shape == torch.Size([2])
```

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (2)
70-83: Numerics/dtype: prefill computes in fp32 and returns fp32 without casting back.
Downstream cached paths cast outputs back to the input dtype; this uncached op does not. Either cast y to hidden_states.dtype before returning, or clearly document that this op's contract is fp32.
Example:
```diff
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor]:
@@
-    return y, ssm_state
+    return y.to(hidden_states.dtype), ssm_state.to(hidden_states.dtype)
```
183-194: Meta kernel shape/dtype LGTM; small nit.
Meta returns float32 like the kernel. Consider preserving hidden_states.dtype if you decide to align the runtime output dtype.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (2)
77-80: prepare_metadata now returns 4 values; update docstring and callers.
The docstring still says 3-tuple. Update it to mention (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).

```diff
-    """Prepare metadata for cached causal conv (CUDA backend).
-
-    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
-    """
+    """Prepare metadata for cached causal conv (CUDA backend).
+
+    Returns (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).
+    """
```
221-221: Ruff ARG001: unused use_initial_states in fake.
Silence by prefixing with an underscore or referencing it.

```diff
-def _cuda_cached_causal_conv1d_fake(
+def _cuda_cached_causal_conv1d_fake(
@@
-    use_initial_states: torch.Tensor,  # [num_seq]
+    use_initial_states: torch.Tensor,  # [num_seq]
@@
-):
+):
+    _ = use_initial_states  # silence ARG001 in fake
```

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (2)
100-110: Initial state gating and chunk index computation.
Logic is sound. Minor: using scalar 0 in torch.where is fine; consider zeros_like for dtype/device explicitness.

```diff
-    initial_states = torch.where(
+    zeros = torch.zeros_like(ssm_state_cache[slot_idx[:num_prefill]])
+    initial_states = torch.where(
         use_initial_states[:num_prefill, None, None, None],
-        ssm_state_cache[slot_idx[:num_prefill]],
-        0,
+        ssm_state_cache[slot_idx[:num_prefill]],
+        zeros,
     )
```
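For reference, the gating pattern in the diff can be exercised in isolation. The shapes and values below are made up purely for illustration; only the `torch.where` broadcasting pattern mirrors the code under review.

```python
import torch

num_prefill, num_heads, head_dim, state_dim = 3, 2, 4, 8
ssm_state_cache = torch.randn(16, num_heads, head_dim, state_dim)
slot_idx = torch.tensor([5, 0, 9])
use_initial_states = torch.tensor([True, False, True])

# Per-sequence gate: keep the cached state where True, start from zeros where False.
zeros = torch.zeros_like(ssm_state_cache[slot_idx[:num_prefill]])
initial_states = torch.where(
    use_initial_states[:num_prefill, None, None, None],
    ssm_state_cache[slot_idx[:num_prefill]],
    zeros,
)
print(initial_states.shape)  # torch.Size([3, 2, 4, 8])
assert torch.all(initial_states[1] == 0)
```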
173-200: Fake path includes use_initial_states; silence ARG001 if linted.

```diff
-):
+):
+    _ = use_initial_states
```

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)
114-144: prepare_metadata docstring still says 3-tuple; now returns 4.
Update the docstring to reflect (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).

```diff
-    """Prepare metadata for cached SSM transform.
-
-    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
-    """
+    """Prepare metadata for cached SSM transform.
+
+    Returns (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states).
+    """
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (7 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (8 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (3 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (8 hunks)
- tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (0 hunks)
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (5 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (5 hunks)
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
- tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
- tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
- tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
🧬 Code graph analysis (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
extract_op_args(407-444)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (2)
_torch_causal_conv1d_decode(83-133), _torch_causal_conv1d_prefill(49-80)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
extract_op_args(407-444)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
input_pos(297-298)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)
_torch_cached_ssm_decode(31-102)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)
_torch_ssm_prefill(70-162)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (4)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
extract_op_args(407-444)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)
_torch_ssm_prefill(70-162)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3)
input_pos(297-298), MHACallable(692-696), PrepareMetadataCallable(699-709)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (2)
get_cached_attention_op(227-228), get_prepare_metadata_op(231-233)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (3)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
cu_seqlens_to_chunk_indices_offsets(24-85)
tensorrt_llm/_torch/modules/mamba/ssd_combined.py (1)
mamba_chunk_scan_combined(183-252)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
extract_op_args(407-444)
🪛 Ruff (0.14.0)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
87-87: Unused noqa directive (non-enabled: E501)
Remove unused noqa directive
(RUF100)
150-150: Unused noqa directive (non-enabled: E501)
Remove unused noqa directive
(RUF100)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
199-199: Unpacked variable use_initial_states is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
221-221: Unused function argument: use_initial_states
(ARG001)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py
98-98: Unused noqa directive (non-enabled: E501)
Remove unused noqa directive
(RUF100)
170-170: Unused noqa directive (non-enabled: E501)
Remove unused noqa directive
(RUF100)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
187-187: Unused function argument: use_initial_states
(ARG001)
🔇 Additional comments (33)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1)
19-20: LGTM! Import paths updated correctly.
The relative import paths have been adjusted to reflect the new module structure with the `mamba` subpackage.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py (2)
87-90: LGTM! Import path updated to reflect the new mamba subpackage.
The reference to `_torch_causal_conv1d_decode` has been correctly updated to the new module path under `mamba.torch_backend_causal_conv`.
Note: The `# noqa: E501` directive on line 87 can be removed as it's unused per static analysis.
150-153: LGTM! Import path updated to reflect the new mamba subpackage.
The reference to `_torch_causal_conv1d_prefill` has been correctly updated to the new module path under `mamba.torch_backend_causal_conv`.
Note: The `# noqa: E501` directive on line 150 can be removed as it's unused per static analysis.

tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py (2)
103-121: LGTM! Operator renamed from `torch_cached_ssm_transform` to `torch_cached_ssm`.
The cached SSM operator call has been updated to use the new naming convention. The arguments and logic remain unchanged.

123-133: LGTM! Operator renamed from `torch_ssm_transform` to `torch_ssm`.
The uncached SSM operator call has been updated to use the new naming convention. The arguments and logic remain unchanged.
tests/integration/defs/accuracy/test_llm_api_autodeploy.py (2)
121-121: Backend changed from "torch-opt" to "torch-cudagraph" for NemotronH.
The compile backend has been updated for the Nemotron-H SSM model. Verify that this change is intentional and that "torch-cudagraph" is the correct backend for SSM models with chunked prefill enabled.
Note: TestLlama3_1_8B at line 61 still uses "torch-opt" - consider whether this inconsistency is intentional or if both should use the same backend.
143-145: Chunked prefill skip removed for SSM model testing.
The skip for chunked prefill has been commented out, enabling this test to run with `enable_chunked_prefill=True`. This aligns with the PR objective to enable chunked prefill for SSMs.
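For context, the sketch below shows roughly how such a run could be configured. The model path and other arguments are placeholders (the actual test wires things up through its own fixtures); the only detail carried over from the test is the `enable_chunked_prefill=True` flag, and the `tensorrt_llm.LLM` entry point is assumed here.

```python
# Hypothetical usage sketch, not the actual test code.
from tensorrt_llm import LLM  # assumes the standard LLM API entry point

llm = LLM(
    model="/path/to/nemotron-h-checkpoint",  # placeholder checkpoint path
    enable_chunked_prefill=True,             # the flag the updated test now exercises
)
```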
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py (5)

73-91: LGTM! Operator renamed from `torch_cached_ssm_transform` to `torch_cached_ssm`.
The cached SSM operator call has been updated to use the new naming convention in the generate-only test path.

98-100: LGTM! Helper function renamed to `_torch_cached_ssm_decode`.
The decode helper reference has been updated to match the new naming convention.
Note: The `# noqa: E501` directive on line 98 can be removed as it's unused per static analysis.
139-157: LGTM! Operator renamed from `torch_cached_ssm_transform` to `torch_cached_ssm`.
The cached SSM operator call has been updated to use the new naming convention in the context flattened test path.

170-172: LGTM! Helper function renamed to `_torch_ssm_prefill`.
The prefill helper reference has been updated to match the new naming convention under the mamba subpackage.
Note: The `# noqa: E501` directive on line 170 can be removed as it's unused per static analysis.
203-206: LGTM! Metadata preparation updated to return a 4-tuple with `use_initial_states`.
The test correctly unpacks the expanded prepare_metadata return value that now includes `use_initial_states` as the 4th element, and validates the expected number of outputs.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py (2)
56-70: LGTM! Operator renamed from `torch_cached_ssm_transform` to `torch_cached_ssm`.
The torch reference operator call has been updated to use the new naming convention.

73-87: LGTM! Operator renamed from `triton_cached_ssm_transform` to `triton_cached_ssm`.
The Triton operator call has been updated to use the new naming convention in the generate-only test.
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py (1)
62-85: LGTM! `use_initial_states` parameter added to the CUDA cached causal conv path.
The new `use_initial_states` parameter has been correctly added to the operator call, aligning with the expanded signature for chunked prefill support.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py (1)
165-179: Op name alignment OK, but ensure consistency across backends.
This registers auto_deploy::torch_ssm. Verify all descriptors refer to this (see TorchBackendSSM below for a mismatch).
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (5)
21-23: Import path hygiene LGTM.
92-93: Fake path output shape updated — good.
Shapes/dtypes match the runtime path (bool).

106-107: New arg use_initial_states — verify schema and interface.
The arg is correctly threaded and used only in prefill. Ensure AttentionDescriptor.get_prepare_metadata_op arity is updated (it is below).

167-168: Correctly maps use_initial_states → has_initial_state.
Good propagation into causal_conv1d_fn.

263-265: prepare_metadata arity advertised as 4 — consistent with implementation.
Good.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (5)
8-8: New helper import (cu_seqlens_to_chunk_indices_offsets) — OK.
27-47: Custom op rename to auto_deploy::triton_cached_ssm looks good.
The signature includes use_initial_states; matches the descriptor wiring below.
120-124: Plumbing chunk_indices/offsets into mamba_chunk_scan_combined — LGTM.
224-229: Source/cached op bindings correct for Triton.
Returns torch.ops.auto_deploy.torch_ssm and triton_cached_ssm — consistent with registrations.

232-234: Descriptor metadata arity updated to 4 — good.
Matches torch_ssm_prepare_metadata outputs.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (7)
15-27: Import path cleanups — OK.
28-28: Directly using _torch_ssm_prefill from torch_mamba.py — OK.
Keep the internal helper separation clear in docs/comments.

31-43: Decode path math LGTM; compute is vectorized and caches updated in-place.
Small nit: consider .to(dtype) consistency for updated_ssm_state if the cache dtype differs; you already cast on scatter.

139-144: use_initial_states derivation looks correct.
input_pos > 0 matches "use cached initial states for nonzero starting positions".

161-180: Cached op (torch_cached_ssm) — prefill/decode split and cache updates LGTM.
Prefill uses _torch_ssm_prefill; outputs are cast back to input dtype before scatter — good.

319-322: Descriptor metadata arity updated to 4 — consistent.
Good alignment with torch_ssm_prepare_metadata outputs.

358-364: Constant extraction via extract_op_args — OK.
Matches the op schema (time_step_limit, chunk_size).
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
PR_Github #21789 [ run ] completed with state
/bot run
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
/bot run
PR_Github #21812 [ run ] triggered by Bot. Commit:
PR_Github #21812 [ run ] completed with state
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
PR_Github #21884 [ run ] triggered by Bot. Commit:
PR_Github #21884 [ run ] completed with state
lucaslie
left a comment
looks great. Just two minor suggestions
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
/bot run
PR_Github #21932 [ run ] triggered by Bot. Commit:
PR_Github #21931 [ run ] triggered by Bot. Commit:
PR_Github #21932 [ run ] completed with state
/bot run
PR_Github #21935 [ run ] triggered by Bot. Commit:
PR_Github #21931 [ run ] completed with state
PR_Github #21935 [ run ] completed with state
/bot --reuse-pipeline
GitHub Bot Help
Provide a user friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with pull request.

skip
Skip testing for latest commit on pull request.

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
/bot reuse-pipeline
PR_Github #21947 [ reuse-pipeline ] triggered by Bot. Commit:
PR_Github #21947 [ reuse-pipeline ] completed with state
…IDIA#8477) Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
…IDIA#8477) Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Signed-off-by: yufeiwu-nv <230315618+yufeiwu-nv@users.noreply.github.com>
`mamba` sub directory #8272