[TRTLLM-8535][feat] Support DeepSeek V3.2 with FP8 + BF16 KV cache/NVFP4 + BF16 KV cache #8405
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
📝 Walkthrough
This PR adds DeepSeek Sparse Attention (DSA) support via FlashMLA integration, introduces DeepSeek-V3.2 model configuration, and extends the build system with new sparse attention backend options. Core changes include FlashMLA submodule integration, new DSA attention kernels and cache management, DeepSeek-V3.2 config classes, and corresponding CLI/API updates across examples and llmapi.
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant LLM
participant ModelEngine
participant Executor
participant AttentionBackend
participant DSAIndexer
participant FlashMLA
User->>LLM: initialize(model_id, sparse_attn_config=DSA)
activate LLM
LLM->>ModelEngine: create with sparse_attention_config
activate ModelEngine
ModelEngine->>ModelEngine: load sparse_attention_config from model.model_config
ModelEngine->>AttentionBackend: get_attention_backend(..., sparse_attention_config=DSA)
deactivate ModelEngine
deactivate LLM
User->>LLM: generate(prompts)
activate LLM
LLM->>Executor: execute with attention backend
activate Executor
Note over Executor: Prefill phase
Executor->>DSAIndexer: prepare(sequences, metadata)
activate DSAIndexer
DSAIndexer->>DSAIndexer: split_prefill_chunks()
DSAIndexer->>DSAIndexer: build IndexerPrefillChunkMetadata
deactivate DSAIndexer
Note over Executor: Forward pass
Executor->>AttentionBackend: forward(hidden_states, attention_metadata)
activate AttentionBackend
AttentionBackend->>AttentionBackend: forward_impl_with_dsa()
AttentionBackend->>AttentionBackend: compute Q, K projections + RoPE
AttentionBackend->>AttentionBackend: FP8 quantize K
AttentionBackend->>DSAIndexer: sparse_attn_indexer (top-k index generation)
activate DSAIndexer
DSAIndexer->>FlashMLA: flash_mla_sparse_fwd(q, kv, indices)
activate FlashMLA
FlashMLA-->>DSAIndexer: output, max_logits, lse
deactivate FlashMLA
deactivate DSAIndexer
AttentionBackend-->>Executor: attention output
deactivate AttentionBackend
Executor-->>LLM: generated_tokens
deactivate Executor
LLM-->>User: results
deactivate LLM
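For readers mapping the diagram onto the LLM API, a minimal usage sketch is shown below. This is illustrative only: the checkpoint path is a placeholder, DSASparseAttentionConfig is constructed with defaults, and the speculative/KV-cache field names are assumptions based on this PR's description rather than verified values.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import DSASparseAttentionConfig, KvCacheConfig, MTPDecodingConfig

llm = LLM(
    model="<path-to-DeepSeek-V3.2-checkpoint>",  # placeholder, not a real path
    sparse_attention_config=DSASparseAttentionConfig(),  # selects the DSA backend
    kv_cache_config=KvCacheConfig(dtype="auto"),  # BF16 KV cache per this PR
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),  # MTP next_n=1 (assumed field name)
)
outputs = llm.generate(["Explain sparse attention in one sentence."])
print(outputs[0].outputs[0].text)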
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Rationale: This PR introduces significant new functionality spanning multiple interconnected areas: (1) new sparse attention backend (DSA) with multiple classes, Triton kernels, and cache management logic requiring understanding of both high-level design and low-level CUDA/quantization details; (2) new model configuration class and registry-based loading mechanism; (3) substantial changes to attention module implementation with multiple new methods; (4) refactoring of cache managers across sparse backends (Rocket, DSA); (5) executor/engine modifications threading sparse config throughout; (6) comprehensive test additions requiring validation of numerical correctness. While many changes follow repetitive patterns (e.g., adding similar parameters across multiple examples), the heterogeneous nature of logic changes (kernel implementation, cache management, config dispatch, attention forward paths) and the density of new public APIs and classes demand careful, multi-perspective review beyond pattern matching.
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 24
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (10)
examples/longbench/eval_longbench_v1.py (1)
1-11: Add required NVIDIA Apache-2.0 copyright header. Per coding guidelines, all Python source files must include the NVIDIA Apache-2.0 copyright header with the current year at the top of the file.
As per coding guidelines.
Add the copyright header before the shebang:
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 #!/usr/bin/env python3
 """
 LongBench v1 evaluation script with TensorRT-LLM and sparse attention.
tensorrt_llm/_torch/speculative/mtp.py (4)
74-76: Invalid Tensor.copy_ usage; use fill_/zero_. Tensor.copy_ expects a Tensor source. Passing a Python int will error at runtime when relaxed acceptance is enabled.
Apply:
-        self.mtp_relaxed_delta_pool[slot_id].copy_(0, non_blocking=True)
+        self.mtp_relaxed_delta_pool[slot_id].fill_(0)
 ...
-        self.mtp_relaxed_delta_pool[free_slot_id].copy_(0,
-                                                        non_blocking=True)
+        self.mtp_relaxed_delta_pool[free_slot_id].fill_(0)
Also applies to: 83-85
1221-1274: Bug: Always using mtp_layers[0] in loop; should index layer i. This reuses layer 0 for all steps, breaking multi‑layer MTP logic and accuracy.
Apply:
-            hidden_states = draft_model.mtp_layers[0](
+            hidden_states = draft_model.mtp_layers[i](
                 embed_tokens=draft_model.embed_tokens,
                 all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
                 **inputs)
 ...
-            hidden_states = draft_model.mtp_layers[0](
+            hidden_states = draft_model.mtp_layers[i](
                 embed_tokens=draft_model.embed_tokens,
                 all_rank_num_tokens=spec_metadata.subseq_all_rank_num_tokens,
                 **inputs)
 ...
-            logits = draft_model.mtp_layers[0].shared_head(
+            logits = draft_model.mtp_layers[i].shared_head(
                 padded_hidden_states, draft_model.lm_head, attn_metadata, True)
 ...
-            logits = draft_model.mtp_layers[0].shared_head(
+            logits = draft_model.mtp_layers[i].shared_head(
                 hidden_states[gather_ids], draft_model.lm_head, attn_metadata, True)
 ...
-        mapping_lm_head_tp = draft_model.mtp_layers[0].shared_head.mapping_lm_head_tp
+        mapping_lm_head_tp = draft_model.mtp_layers[i].shared_head.mapping_lm_head_tp
1265-1267: Fix error message condition. The error triggers when token_count > max_num_requests, but the message says “<”.
- f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported" + f"In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
890-915: Call update_for_spec_dec after mutating seq_lens/kv_lens in base MTPWorker. Eagle path updates via update_for_spec_dec; the base MTPWorker does not, which may desync DSASparse/FlashMLA metadata when used without Eagle.
Minimal fix:
 def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
                          attn_metadata: AttentionMetadata):
     attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
     ...
     attn_metadata.on_update()
+    # Ensure derived buffers are refreshed for spec-dec backends (DSA/FlashMLA).
+    if hasattr(attn_metadata, "update_for_spec_dec"):
+        attn_metadata.update_for_spec_dec()
     if hasattr(attn_metadata, 'kv_lens_cuda'):
         ...
+        if hasattr(attn_metadata, "update_for_spec_dec"):
+            attn_metadata.update_for_spec_dec()
Based on learnings.
tensorrt_llm/_torch/attention_backend/utils.py (1)
1-13: Fix forward-ref type to satisfy Ruff (F821). Import SparseAttentionConfig under TYPE_CHECKING to avoid undefined-name while keeping runtime clean.
Apply this diff:
+from typing import Optional, Type, TYPE_CHECKING
+if TYPE_CHECKING:
+    from .interface import SparseAttentionConfig
-from typing import Optional, Type
cpp/tensorrt_llm/CMakeLists.txt (1)
186-216: Potential link error when BUILD_FLASH_MLA=OFF. flash_mla_src is unconditionally in TRTLLM_LINK_LIBS; if BUILD_FLASH_MLA is OFF, target won't exist → link failure.
Apply this diff to remove it from the unconditional list:
 set(TRTLLM_LINK_LIBS
@@
-  flash_mla_src
@@
 )
Then append it conditionally after adding the subdir (see next comment).
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
1-27: Add NVIDIA Apache-2.0 header before third-party MIT notice. Per coding guidelines, prepend current-year NVIDIA Apache-2.0 header while keeping upstream MIT attribution.
examples/llm-api/llm_sparse_attention.py (1)
1-14: Add NVIDIA Apache-2.0 header at file top. Per coding guidelines, prepend the standard header.
tensorrt_llm/_torch/attention_backend/sparse/rocket.py (1)
552-559: RocketVanillaAttentionMetadata.prepare calls missing get_kt_block_offsets. RocketKVCacheManager no longer provides get_kt_block_offsets; this will raise at runtime. Mirror the TRTLLM path: copy into pinned host tensor, then device copy.
Apply:
 def prepare(self) -> None:
     super().prepare()
@@
-    if self.kv_cache_manager is not None:
-        # for kt cache
-        self.host_kt_cache_block_offsets = self.kv_cache_manager.get_kt_block_offsets(
-            self.request_ids)
-        self.kt_cache_block_offsets[:self.num_seqs].copy_(
-            self.host_kt_cache_block_offsets[:self.num_seqs],
-            non_blocking=True)
+    if self.kv_cache_manager is not None:
+        # for kt cache: copy offsets via BlockManager-backed API
+        if not hasattr(self, "host_kt_cache_block_offsets"):
+            self.host_kt_cache_block_offsets = torch.zeros_like(
+                self.kt_cache_block_offsets, device="cpu", pin_memory=True)
+        self.kv_cache_manager.copy_kt_block_offsets(
+            self.request_ids, self.host_kt_cache_block_offsets)
+        self.kt_cache_block_offsets[:self.num_seqs].copy_(
+            self.host_kt_cache_block_offsets[:self.num_seqs], non_blocking=True)
🧹 Nitpick comments (29)
tensorrt_llm/_torch/speculative/mtp.py (1)
1163-1171: torch.compile on helper that appends to Python list may induce graph breaks. update_draft_tokens mutates a Python list; torch.compile may retrace/graph-break. Safe but adds overhead.
Inline the append outside the compiled region, or return the token and append in Python.
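A generic illustration of the second option (standalone toy code, not the actual mtp.py implementation; all names here are invented):
import torch

draft_tokens = []

@torch.compile
def compute_draft_token(logits: torch.Tensor) -> torch.Tensor:
    # Tensor-only work stays inside the compiled region.
    return torch.argmax(logits, dim=-1)

logits = torch.randn(2, 32000)
# The Python-side list mutation happens outside the compiled graph.
draft_tokens.append(compute_draft_token(logits))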
tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py (5)
149-154: Remove unused parameters from reference helpers and update call sites. Drops kv_lora_rank, v_head_dim, device, and query_lens that are unused; simplifies signatures and silences Ruff ARG001.
-def calculate_reference_output_prefill_only(q_c, kv_c, k_pe, W_UK, W_UV, - rope_cos_sin, sequence_lengths, - num_heads, kv_lora_rank, - qk_nope_head_dim, qk_rope_head_dim, - v_head_dim, softmax_scale, device): +def calculate_reference_output_prefill_only(q_c, kv_c, k_pe, W_UK, W_UV, + rope_cos_sin, sequence_lengths, + num_heads, qk_nope_head_dim, + qk_rope_head_dim, softmax_scale): -def calculate_reference_output_generation(q_c, kv_c, k_pe, W_UK, W_UV, - kv_cache_lens, num_heads, - kv_lora_rank, qk_nope_head_dim, - qk_rope_head_dim, v_head_dim, - softmax_scale, device): +def calculate_reference_output_generation(q_c, kv_c, k_pe, W_UK, W_UV, + kv_cache_lens, num_heads, + qk_nope_head_dim, + qk_rope_head_dim, softmax_scale): -def calculate_reference_output_mixed(q_ctx, q_gen, kv_c_all, k_pe_all, W_UK, - W_UV, rope_cos_sin, ctx_indices, - gen_indices, seq_lens, query_lens, - num_heads, kv_lora_rank, qk_nope_head_dim, - qk_rope_head_dim, v_head_dim, - softmax_scale, device): +def calculate_reference_output_mixed(q_ctx, q_gen, kv_c_all, k_pe_all, W_UK, + W_UV, rope_cos_sin, ctx_indices, + gen_indices, seq_lens, num_heads, + qk_nope_head_dim, qk_rope_head_dim, + softmax_scale): - ctx_results = calculate_reference_output_prefill_only( - q_ctx, kv_c, k_pe, W_UK, W_UV, rope_cos_sin, - [seq_lens[i] - for i in ctx_indices], num_heads, kv_lora_rank, qk_nope_head_dim, - qk_rope_head_dim, v_head_dim, softmax_scale, device) + ctx_results = calculate_reference_output_prefill_only( + q_ctx, kv_c, k_pe, W_UK, W_UV, rope_cos_sin, + [seq_lens[i] for i in ctx_indices], + num_heads, qk_nope_head_dim, qk_rope_head_dim, softmax_scale) - gen_results = calculate_reference_output_generation( - q_gen, kv_c, k_pe, W_UK, W_UV, [seq_lens[i] for i in gen_indices], - num_heads, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, - v_head_dim, softmax_scale, device) + gen_results = calculate_reference_output_generation( + q_gen, kv_c, k_pe, W_UK, W_UV, + [seq_lens[i] for i in gen_indices], + num_heads, qk_nope_head_dim, qk_rope_head_dim, softmax_scale) - reference_output = calculate_reference_output_mixed( + reference_output = calculate_reference_output_mixed( q_ctx=q_ctx_ref, q_gen=q_gen_ref, kv_c_all=all_compressed_kv, k_pe_all=all_k_pe_for_ref, W_UK=W_UK, W_UV=W_UV, rope_cos_sin=rope_cos_sin, ctx_indices=list(range(num_contexts)), gen_indices=list(range(num_contexts, len(batch_order))), seq_lens=[seq_lens[i] for i in batch_order], - query_lens=batch_query_lens, num_heads=num_heads, - kv_lora_rank=kv_lora_rank, qk_nope_head_dim=qk_nope_head_dim, qk_rope_head_dim=qk_rope_head_dim, - v_head_dim=v_head_dim, softmax_scale=softmax_scale, - device=device, )Also applies to: 191-196, 229-236, 253-258, 268-271, 797-816
294-301: Avoid zip(strict) for Py3.8; assert lengths explicitly. Project targets Python 3.8+; zip(strict=…) is 3.10+. Add an explicit length assert to catch mismatches.
ctx_indices = [ - i for i, (q, s) in enumerate(zip(query_lens, seq_lens)) if q == s + i for i, (q, s) in enumerate(zip(query_lens, seq_lens)) if q == s ] gen_indices = [ - i for i, (q, s) in enumerate(zip(query_lens, seq_lens)) if q < s + i for i, (q, s) in enumerate(zip(query_lens, seq_lens)) if q < s ] + assert len(query_lens) == len(seq_lens), "seq_lens and query_lens length mismatch"Based on static analysis hints and Python 3.8 constraint.
685-705: Remove dead helper and unused expected indices. local_to_global_indices is unused; expected_* variables are assigned but not used. Clean up to silence F841/B007 and reduce noise.
- def local_to_global_indices(local_indices, - req_indices, - cache_offset_start=0): - """ - Transform indexer's local indices to global indices. - """ - global_indices = local_indices.clone() - kv_offset = cache_offset_start - token_idx = 0 - - for req_idx in req_indices: - num_tokens = query_lens[req_idx] - # Add offset for this request's cache position - for local_pos in range(num_tokens): - # Only transform non-padding indices (>= 0) - mask = global_indices[token_idx] >= 0 - global_indices[token_idx][mask] += kv_offset - token_idx += 1 - kv_offset += seq_lens[req_idx] - return global_indices - - # Create expected global indices (sorted) for validation (not used but can be used for validation) - expected_ctx_indices = create_causal_indices(ctx_indices, - cache_offset_start=0) + # Optionally validate indexer order/shape here if needed in future. - # Create expected global indices (sorted) for validation (not used but can be used for validation) - expected_gen_indices = create_causal_indices(gen_indices, - cache_offset_start=0)Also applies to: 715-717, 737-740
442-443: Remove f-strings without placeholders. Silence F541 and minor style nit.
- print(f" Testing single layer (baseline)") + print(" Testing single layer (baseline)") ... - print(f" Allocating and pre-populating cache...") + print(" Allocating and pre-populating cache...") ... - print(f" ✓ KV cache allocated and pre-populated") + print(" ✓ KV cache allocated and pre-populated")Also applies to: 505-505, 594-594
381-381: Remove no-op statements. sum(seq_lens) and list(range(...)) are unused.
- sum(seq_lens) ... - list(range(batch_spec.batch_size))Also applies to: 498-498
examples/llm-api/quickstart_advanced.py (1)
88-89: tokens_per_block CLI wiring — consider auto‑detect. Default 32 is fine; optionally, if user omits it, read tokens_per_block from the engine/model config to prevent capacity mismatch across models.
I can draft a small helper to query ModelConfigCpp and default accordingly if you want.
Also applies to: 179-184
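For illustration, such a helper could look like the sketch below; the argument and attribute names are assumptions, not the actual quickstart_advanced.py or ModelConfigCpp interface:
def resolve_tokens_per_block(cli_value, model_config, default=32):
    """Prefer an explicit CLI value, then the model config, then the default."""
    if cli_value is not None:
        return cli_value
    return getattr(model_config, "tokens_per_block", default)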
tensorrt_llm/_torch/attention_backend/trtllm.py (2)
1384-1390: Consume unused kwargs to satisfy lint (ARG002). Silence warnings and keep forward‑compat signature.
Apply this diff:
 def load_paged_kv_cache_for_mla(
@@
-    **kwargs,
+    **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@
-    assert out_dtype in [torch.float16, torch.bfloat16, torch.float32]
+    assert out_dtype in [torch.float16, torch.bfloat16, torch.float32]
+    if kwargs:
+        logger.debug("load_paged_kv_cache_for_mla: unused kwargs=%s", list(kwargs.keys()))
1472-1478: Consume unused kwargs to satisfy lint (ARG002). Same rationale as above.
Apply this diff:
 def mla_rope_append_paged_kv_assign_q(
@@
-    **kwargs,
+    **kwargs,
 ) -> None:
@@
-    assert self.is_mla_enable and self.mla_params is not None
+    assert self.is_mla_enable and self.mla_params is not None
+    if kwargs:
+        logger.debug("mla_rope_append_paged_kv_assign_q: unused kwargs=%s", list(kwargs.keys()))
tensorrt_llm/_torch/model_config.py (2)
14-21: Duplicate logger import. Remove the module‑level import; keep the logger instance import.
Apply this diff:
-from tensorrt_llm import logger
@@
-from tensorrt_llm.logger import logger
+from tensorrt_llm.logger import logger
420-468: from_pretrained: registry path — defensiveness. Good use of PretrainedConfig.get_config_dict and fallback. Consider passing through trust_remote_code consistently; also avoid mutating kwargs in place by using a local var for sparse_attention_config (optional).
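A tiny sketch of the "local variable instead of in-place mutation" idea (purely illustrative; the function and key names are assumptions, not the real from_pretrained signature):
def split_sparse_config(**kwargs):
    # Read the value locally instead of popping it out of the caller's kwargs.
    sparse_attention_config = kwargs.get("sparse_attention_config")
    remaining = {k: v for k, v in kwargs.items() if k != "sparse_attention_config"}
    return sparse_attention_config, remaining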
tensorrt_llm/_torch/pyexecutor/_util.py (1)
48-54: Sparse-attn priority may preempt hybrid managers. Changing selection to “if sparse_attn_config is not None: … return sparse manager” puts sparse attention ahead of Nemotron/Qwen3‑Next hybrid paths. If any hybrid model is constructed with a non‑None sparse config, it will no longer select MambaHybridCacheManager.
- If that’s intended: add a short comment documenting precedence.
- If not: reorder checks or gate sparse path to non‑hybrid models.
Would you confirm whether hybrid configs are guaranteed to set sparse_attn_config=None?
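If the reorder option is chosen, a minimal sketch of the intended precedence could look like this; the helper names come from the modules referenced in this PR, but their exact signatures and constructor arguments are assumptions:
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, is_qwen3_next
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.attention_backend.sparse.utils import get_sparse_attn_kv_cache_manager

def select_kv_cache_manager(model_config, sparse_attn_config, **manager_kwargs):
    # Hybrid models keep their dedicated manager even if a sparse config is present.
    if is_nemotron_hybrid(model_config) or is_qwen3_next(model_config):
        return MambaHybridCacheManager(**manager_kwargs)
    # The sparse-attention path only applies to non-hybrid models.
    if sparse_attn_config is not None:
        return get_sparse_attn_kv_cache_manager(sparse_attn_config, **manager_kwargs)
    return KVCacheManager(**manager_kwargs)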
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
2291-2308: Remove unused fp8kv param for DeepSeek‑V3.2 baseline. Baseline runs are FP8 GEMM with BF16 KV; the fp8kv parameter is always False and only adds noise.
Apply:
@@
-    @pytest.mark.parametrize(
-        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size,mtp_nextn,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
@@
-    def test_baseline_fp8gemm_bf16_kv(self, tp_size, pp_size, ep_size,
-                                      mtp_nextn, fp8kv, attention_dp,
+    def test_baseline_fp8gemm_bf16_kv(self, tp_size, pp_size, ep_size,
+                                      mtp_nextn, attention_dp,
                                       cuda_graph, overlap_scheduler,
                                       max_batch_size, moe_backend):
@@
-        if fp8kv:
-            kv_cache_config.dtype = "fp8"
+        # BF16 KV cache baseline (do not set fp8 KV here)
Optionally, use
self.MODEL_PATH in the LLM(...) call to avoid repeating the path literal.
Also applies to: 2332-2334
tensorrt_llm/_torch/attention_backend/sparse/kernel.py (1)
373-411: Type/style nits in Python wrapper.
- Use Optional[int] for stride_factor typing.
- Avoid unused var warning for num_requests.
- Replace ambiguous × with 'x'.
-def triton_convert_req_index_to_global_index(
+def triton_convert_req_index_to_global_index(
     req_id: torch.Tensor,  # int32 [num_tokens]
-    block_table: torch.
-    Tensor,  # int32 [num_requests, max_num_blocks_per_req]
+    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
     token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
     BLOCK_SIZE: int,
     NUM_TOPK_TOKENS: int = 2048,
     BLOCK_N: int = 128,  # tile width along columns
-    stride_factor:
-    int = None,  # for strided memory layout (with layer interleaving), defaults to BLOCK_SIZE
+    stride_factor: Optional[int] = None,  # for strided memory layout (with layer interleaving)
     layer_id: int = 0,  # for layer interleaving layout
 ):
@@
-    # Exact 2D grid: tokens × column tiles
+    # Exact 2D grid: tokens x column tiles
     grid = (num_tokens, tiles_per_row)
@@
-    num_tokens = req_id.shape[0]
-    num_requests, max_num_blocks_per_req = block_table.shape
+    num_tokens = req_id.shape[0]
+    _num_requests, max_num_blocks_per_req = block_table.shape
Also applies to: 427-429
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py (2)
600-606: Remove f-prefix from strings without placeholders. Cleanup f-strings that have no interpolation.
-print(f"\n=== Test Config ===") +print("\n=== Test Config ===") @@ -print(f"\n=== Phase 1: Context (variable tokens/seq) ===") +print("\n=== Phase 1: Context (variable tokens/seq) ===") @@ -print(f"\n=== Kernel Execution ===") +print("\n=== Kernel Execution ===") @@ -print(f"\n=== Reference Computation ===") +print("\n=== Reference Computation ===") @@ -print(f"\n=== Validation ===") +print("\n=== Validation ===") @@ - print(f"✅ test_split_prefill_chunks passed") + print("✅ test_split_prefill_chunks passed") @@ - print( + print( f" Request {req_idx} (len={seq_lens_list[req_idx]}): " f"{len(req_mismatches)} mismatches") @@ - print(f"\n=== Chunked Path ===") + print("\n=== Chunked Path ===") @@ - print(f"\n=== Non-chunked Baseline ===") + print("\n=== Non-chunked Baseline ===") @@ - print(f"\n=== Validation ===") + print("\n=== Validation ===")Also applies to: 640-660, 687-697, 699-706, 743-747, 856-857, 929-935, 973-981, 1010-1040, 1043-1048
969-971: Silence unused variable in tests. k_fp8 isn’t used in this test path; prefix with underscore to avoid lint noise.
- k_fp8, k_scale = torch.ops.trtllm.fp8_quantize_1x128(k) + _k_fp8, k_scale = torch.ops.trtllm.fp8_quantize_1x128(k)tensorrt_llm/llmapi/llm_args.py (1)
174-193: Dispatch for sparse_attention_config looks correct; minor UX improvement.Consider a clearer error for missing/unknown algorithm and keep ‘algorithm’ out of subclass kwargs (you already strip it). LGTM otherwise.
- if algorithm is None: - raise ValueError(f"Sparse attention algorithm is required") + if algorithm is None: + raise ValueError("sparse_attention_config.algorithm is required (e.g., 'rocket' or 'dsa')")examples/llm-api/llm_sparse_attention.py (2)
47-71: CLI: add DSA to docstring and help strings.Update header comment to list DSA in Supported algorithms to avoid confusion.
191-195: DSA run path: propagate additional indexer fields when needed.If you later expose index_n_heads/index_head_dim/index_topk, thread them here via DSASparseAttentionConfig.
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)
1184-1188: Use explicit exception instead of assert for allocation exhaustionAsserts can be stripped; raise a clear RuntimeError for deterministic behavior.
def _allocate_blocks(self, block_count: int) -> list: - assert len(self.free_blocks) >= block_count, "Not enough blocks." + if len(self.free_blocks) < block_count: + raise RuntimeError(f"Not enough blocks: requested {block_count}, available {len(self.free_blocks)}") blocks = [self.free_blocks.popleft() for _ in range(block_count)] return blocks
1162-1167: free_resources lacks guard for unknown request_idfree_resources assumes request_id exists; add a fast-return guard to avoid KeyError on double-free/late cleanup.
def free_resources(self, request: LlmRequest): request_id = request.py_request_id - self._free_blocks(self.block_ids[request_id]) - del self.block_ids[request_id] - del self.num_sequences[request_id] + if request_id not in self.block_ids: + return + self._free_blocks(self.block_ids[request_id]) + self.block_ids.pop(request_id, None) + self.num_sequences.pop(request_id, None)cpp/tensorrt_llm/flash_mla/CMakeLists.txt (2)
147-151: Avoid global LTO disable; scope to targetSetting CMAKE_INTERPROCEDURAL_OPTIMIZATION globally can degrade the rest of the build.
-# Disable LTO before creating target (similar to DeepEP) Let CMake generate -# fatbinData for CUDA separable compilation -set(CMAKE_INTERPROCEDURAL_OPTIMIZATION FALSE) +# Disable LTO only for this target (similar to DeepEP) +# Let CMake generate fatbinData for CUDA separable compilation +set_property(TARGET flash_mla_cpp_tllm PROPERTY INTERPROCEDURAL_OPTIMIZATION FALSE)Note: move after pybind11_add_module so target exists.
83-87: Import rewrite may not cover all forms (“from flash_mla.cuda import …”)Single replace of “flash_mla.cuda” to “tensorrt_llm.flash_mla_cpp_tllm” might miss variations or create broken imports depending on module layout/install path.
- Add an additional replace to cover “from flash_mla.cuda import …” and “import flash_mla.cuda as …”.
- Ensure the extension is packaged under tensorrt_llm/ so “tensorrt_llm.flash_mla_cpp_tllm” actually resolves.
Example:
string(REGEX REPLACE "from[ ]+flash_mla[.]cuda[ ]+import" "from tensorrt_llm.flash_mla_cpp_tllm import" _content "${_content}") string(REGEX REPLACE "import[ ]+flash_mla[.]cuda[ ]+as" "import tensorrt_llm.flash_mla_cpp_tllm as" _content "${_content}")Also confirm install location of the built .so aligns with the new import path.
tensorrt_llm/_torch/attention_backend/sparse/rocket.py (1)
933-936: copy_kt_block_offsets returns None; update the signature or return the tensorThe wrapper method has no return; keep API void to avoid confusion.
- def copy_kt_block_offsets(self, request_ids: List[int], - block_offsets: torch.Tensor) -> torch.Tensor: + def copy_kt_block_offsets(self, request_ids: List[int], + block_offsets: torch.Tensor) -> None: self.kt_cache_manager.copy_block_offsets(request_ids, block_offsets)tensorrt_llm/evaluate/lm_eval.py (1)
66-83: Route through trtllm_apply_chat_template to pass “thinking” flag consistentlyDirectly calling tokenizer.apply_chat_template may not accept “thinking”; the TRT-LLM helper resolves chat template and forwards kwargs safely (tools, processor, exceptions).
- def apply_chat_template(self, - chat_history: List[Dict[str, str]], - add_generation_prompt: bool = True) -> str: + def apply_chat_template(self, + chat_history: List[Dict[str, str]], + add_generation_prompt: bool = True) -> str: @@ - return self.llm.tokenizer.apply_chat_template( - chat_history, - tokenize=False, - add_generation_prompt=add_generation_prompt, - **chat_template_kwargs, - ) + return trtllm_apply_chat_template( + model_type=getattr(self.llm, "_model_type", "auto"), + tokenizer=self.llm.tokenizer, + processor=getattr(self.llm, "input_processor", None), + conversation=chat_history, + add_generation_prompt=add_generation_prompt, + mm_placeholder_counts={}, # text-only + tools=None, + chat_template_kwargs=chat_template_kwargs, + )tensorrt_llm/_torch/modules/attention.py (2)
1635-1646: Unused args in forward_sparse_mla_kvcache_bf16compressed_kv and k_pe are unused; mark as unused to silence linters or remove from signature if not needed.
- def forward_sparse_mla_kvcache_bf16( - self, - q: torch.Tensor, - compressed_kv: torch.Tensor, - k_pe: torch.Tensor, + def forward_sparse_mla_kvcache_bf16( + self, + q: torch.Tensor, + _compressed_kv: torch.Tensor, + _k_pe: torch.Tensor,
1736-1742: Improve missing-kernel error message and early guardProvide concise actionable error.
- else: - raise RuntimeError( - "flash_mla_sparse_fwd not available. Please ensure FlashMLA module is built." - ) + else: + raise RuntimeError("FlashMLA sparse kernel missing: build and load flash_mla_cpp_tllm extension.")tensorrt_llm/_torch/attention_backend/sparse/dsa.py (2)
756-778: Rename unused loop variable to indicate it's intentionally unused. The loop variable req_id is not used within the loop body. Rename it to _ or _req_id to follow Python conventions for unused variables.
-        for req_idx, req_id in enumerate(request_ids):
+        for req_idx, _ in enumerate(request_ids):
             num_tokens = seq_lens[req_idx].item()
Based on coding guidelines.
1151-1160: Add docstring explaining stub implementation. The sparse_kv_predict method returns (None, None) with all parameters unused. Add a docstring explaining this is a stub for interface compatibility.
 def sparse_kv_predict(
     self,
     q: torch.Tensor,
     k: Optional[torch.Tensor],
     metadata: DSAtrtllmAttentionMetadata,
     hidden_states: Optional[torch.Tensor] = None,
     qr: Optional[torch.Tensor] = None,
     **kwargs,
 ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """Stub implementation for interface compatibility. DSA does not use sparse KV prediction."""
     return None, None
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (44)
.gitignore (1 hunks)
.gitmodules (1 hunks)
3rdparty/DeepGEMM (1 hunks)
3rdparty/flash-mla (1 hunks)
cpp/CMakeLists.txt (1 hunks)
cpp/tensorrt_llm/CMakeLists.txt (1 hunks)
cpp/tensorrt_llm/flash_mla/CMakeLists.txt (1 hunks)
cpp/tensorrt_llm/flash_mla/flash_mla_cpp_tllm.version (1 hunks)
examples/llm-api/llm_sparse_attention.py (7 hunks)
examples/llm-api/quickstart_advanced.py (2 hunks)
examples/longbench/eval_longbench_v1.py (5 hunks)
requirements.txt (1 hunks)
scripts/build_wheel.py (4 hunks)
setup.py (1 hunks)
tensorrt_llm/_torch/attention_backend/interface.py (2 hunks)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1 hunks)
tensorrt_llm/_torch/attention_backend/sparse/kernel.py (1 hunks)
tensorrt_llm/_torch/attention_backend/sparse/rocket.py (5 hunks)
tensorrt_llm/_torch/attention_backend/sparse/utils.py (3 hunks)
tensorrt_llm/_torch/attention_backend/trtllm.py (3 hunks)
tensorrt_llm/_torch/attention_backend/utils.py (4 hunks)
tensorrt_llm/_torch/configs/__init__.py (1 hunks)
tensorrt_llm/_torch/configs/deepseek_v3.py (1 hunks)
tensorrt_llm/_torch/model_config.py (3 hunks)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (3 hunks)
tensorrt_llm/_torch/modules/attention.py (11 hunks)
tensorrt_llm/_torch/modules/layer_norm.py (1 hunks)
tensorrt_llm/_torch/pyexecutor/_util.py (3 hunks)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2 hunks)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2 hunks)
tensorrt_llm/_torch/speculative/mtp.py (1 hunks)
tensorrt_llm/evaluate/lm_eval.py (9 hunks)
tensorrt_llm/llmapi/__init__.py (2 hunks)
tensorrt_llm/llmapi/llm_args.py (6 hunks)
tests/integration/defs/accuracy/references/gpqa_diamond.yaml (1 hunks)
tests/integration/defs/accuracy/references/gsm8k.yaml (1 hunks)
tests/integration/defs/accuracy/references/mmlu.yaml (1 hunks)
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
tests/integration/test_lists/qa/llm_function_core.txt (1 hunks)
tests/integration/test_lists/qa/llm_function_core_sanity.txt (1 hunks)
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py (1 hunks)
tests/unittest/_torch/attention/sparse/test_flash_mla.py (1 hunks)
tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/configs/__init__.pyexamples/llm-api/quickstart_advanced.pytensorrt_llm/_torch/configs/deepseek_v3.pytensorrt_llm/llmapi/__init__.pysetup.pytensorrt_llm/_torch/modules/layer_norm.pyscripts/build_wheel.pytests/integration/defs/accuracy/test_llm_api_pytorch.pytensorrt_llm/_torch/attention_backend/sparse/utils.pytests/unittest/_torch/attention/sparse/test_sparse_mla_forward.pytensorrt_llm/_torch/speculative/mtp.pytensorrt_llm/_torch/model_config.pytensorrt_llm/_torch/attention_backend/trtllm.pytensorrt_llm/_torch/attention_backend/interface.pytensorrt_llm/_torch/attention_backend/sparse/kernel.pytensorrt_llm/_torch/pyexecutor/py_executor_creator.pytensorrt_llm/llmapi/llm_args.pytensorrt_llm/_torch/attention_backend/utils.pytensorrt_llm/_torch/pyexecutor/_util.pytensorrt_llm/_torch/pyexecutor/model_engine.pyexamples/llm-api/llm_sparse_attention.pytensorrt_llm/_torch/attention_backend/sparse/rocket.pytests/unittest/_torch/attention/sparse/test_dsa_indexer.pytensorrt_llm/_torch/models/modeling_deepseekv3.pytests/unittest/_torch/attention/sparse/test_flash_mla.pytensorrt_llm/_torch/pyexecutor/resource_manager.pytensorrt_llm/_torch/attention_backend/sparse/dsa.pytensorrt_llm/evaluate/lm_eval.pyexamples/longbench/eval_longbench_v1.pytensorrt_llm/_torch/modules/attention.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/configs/__init__.pyexamples/llm-api/quickstart_advanced.pytensorrt_llm/_torch/configs/deepseek_v3.pytensorrt_llm/llmapi/__init__.pysetup.pytensorrt_llm/_torch/modules/layer_norm.pyscripts/build_wheel.pytests/integration/defs/accuracy/test_llm_api_pytorch.pytensorrt_llm/_torch/attention_backend/sparse/utils.pytests/unittest/_torch/attention/sparse/test_sparse_mla_forward.pytensorrt_llm/_torch/speculative/mtp.pytensorrt_llm/_torch/model_config.pytensorrt_llm/_torch/attention_backend/trtllm.pytensorrt_llm/_torch/attention_backend/interface.pytensorrt_llm/_torch/attention_backend/sparse/kernel.pytensorrt_llm/_torch/pyexecutor/py_executor_creator.pytensorrt_llm/llmapi/llm_args.pytensorrt_llm/_torch/attention_backend/utils.pytensorrt_llm/_torch/pyexecutor/_util.pytensorrt_llm/_torch/pyexecutor/model_engine.pyexamples/llm-api/llm_sparse_attention.pytensorrt_llm/_torch/attention_backend/sparse/rocket.pytests/unittest/_torch/attention/sparse/test_dsa_indexer.pytensorrt_llm/_torch/models/modeling_deepseekv3.pytests/unittest/_torch/attention/sparse/test_flash_mla.pytensorrt_llm/_torch/pyexecutor/resource_manager.pytensorrt_llm/_torch/attention_backend/sparse/dsa.pytensorrt_llm/evaluate/lm_eval.pyexamples/longbench/eval_longbench_v1.pytensorrt_llm/_torch/modules/attention.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/configs/__init__.pyexamples/llm-api/quickstart_advanced.pytensorrt_llm/_torch/configs/deepseek_v3.pytensorrt_llm/llmapi/__init__.pysetup.pytensorrt_llm/_torch/modules/layer_norm.pyscripts/build_wheel.pytests/integration/defs/accuracy/test_llm_api_pytorch.pytensorrt_llm/_torch/attention_backend/sparse/utils.pytests/unittest/_torch/attention/sparse/test_sparse_mla_forward.pytensorrt_llm/_torch/speculative/mtp.pytensorrt_llm/_torch/model_config.pytensorrt_llm/_torch/attention_backend/trtllm.pytensorrt_llm/_torch/attention_backend/interface.pytensorrt_llm/_torch/attention_backend/sparse/kernel.pytensorrt_llm/_torch/pyexecutor/py_executor_creator.pytensorrt_llm/llmapi/llm_args.pytensorrt_llm/_torch/attention_backend/utils.pytensorrt_llm/_torch/pyexecutor/_util.pytensorrt_llm/_torch/pyexecutor/model_engine.pyexamples/llm-api/llm_sparse_attention.pytensorrt_llm/_torch/attention_backend/sparse/rocket.pytests/unittest/_torch/attention/sparse/test_dsa_indexer.pytensorrt_llm/_torch/models/modeling_deepseekv3.pytests/unittest/_torch/attention/sparse/test_flash_mla.pytensorrt_llm/_torch/pyexecutor/resource_manager.pytensorrt_llm/_torch/attention_backend/sparse/dsa.pytensorrt_llm/evaluate/lm_eval.pyexamples/longbench/eval_longbench_v1.pytensorrt_llm/_torch/modules/attention.py
🧠 Learnings (4)
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Applied to files:
examples/llm-api/quickstart_advanced.py
tensorrt_llm/_torch/attention_backend/sparse/rocket.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
PR: NVIDIA/TensorRT-LLM#7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.
Applied to files:
tests/integration/test_lists/qa/llm_function_core.txt
tests/integration/test_lists/qa/llm_function_core_sanity.txt
examples/longbench/eval_longbench_v1.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
.gitignore
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.
Applied to files:
tensorrt_llm/_torch/pyexecutor/resource_manager.py
🧬 Code graph analysis (23)
tensorrt_llm/_torch/configs/__init__.py (1)
tensorrt_llm/_torch/configs/deepseek_v3.py (1)
DeepseekV3Config(9-101)
tensorrt_llm/_torch/configs/deepseek_v3.py (2)
tensorrt_llm/models/modeling_utils.py (1)
PretrainedConfig(369-570)cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h (3)
intermediate_size(238-238)n_group(234-234)topk_group(236-236)
tensorrt_llm/llmapi/__init__.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
DSASparseAttentionConfig(228-247)
tests/integration/defs/accuracy/test_llm_api_pytorch.py (3)
tests/integration/defs/accuracy/accuracy_core.py (5)
LlmapiAccuracyTestHarness(846-857)MMLU(317-331)evaluate(184-247)evaluate(765-775)GSM8K(334-349)tests/integration/defs/conftest.py (2)
llm_models_root(79-93)get_sm_version(1889-1892)tensorrt_llm/llmapi/llm_args.py (4)
MoeConfig(250-284)KvCacheConfig(1224-1358)CudaGraphConfig(109-166)MTPDecodingConfig(784-828)
tensorrt_llm/_torch/attention_backend/sparse/utils.py (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (2)
DSACacheManager(1220-1433)DSATrtllmAttention(1099-1217)
tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py (8)
tensorrt_llm/_torch/attention_backend/interface.py (2)
PositionalEmbeddingParams(512-530)RopeParams(356-508)tensorrt_llm/_torch/attention_backend/sparse/dsa.py (5)
DSACacheManager(1220-1433)add_dummy_requests(1292-1325)prepare(393-490)prepare(678-783)mla_rope_append_paged_kv_assign_q(1162-1217)tensorrt_llm/_torch/attention_backend/utils.py (1)
get_attention_backend(15-37)tensorrt_llm/_torch/metadata.py (1)
KVCacheParams(9-31)tensorrt_llm/_torch/modules/attention.py (3)
MLA(648-1801)forward_context_dsa(1244-1262)forward_generation_dsa(1264-1282)tensorrt_llm/_utils.py (2)
str_dtype_to_binding(221-224)torch_dtype_to_str(230-231)tensorrt_llm/llmapi/llm_args.py (2)
KvCacheConfig(1224-1358)DSASparseAttentionConfig(228-247)tensorrt_llm/functional.py (5)
RopeEmbeddingUtils(4677-5227)chunk(3826-3861)sum(3253-3275)create_sinusoidal_positions_yarn(4892-4968)arange(1498-1569)
tensorrt_llm/_torch/speculative/mtp.py (3)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
attn_metadata(121-122)tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
update_for_spec_dec(492-518)tensorrt_llm/_torch/attention_backend/interface.py (1)
update_for_spec_dec(339-342)
tensorrt_llm/_torch/model_config.py (2)
tensorrt_llm/llmapi/llm_args.py (1)
DSASparseAttentionConfig(228-247)tensorrt_llm/models/automodel.py (1)
AutoConfig(10-49)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
sparse_attn_predict(1139-1149)tensorrt_llm/_torch/attention_backend/sparse/rocket.py (1)
sparse_attn_predict(210-287)
tensorrt_llm/_torch/attention_backend/interface.py (2)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
update_for_spec_dec(492-518)cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h (1)
hidden_size(201-201)
tensorrt_llm/llmapi/llm_args.py (4)
tensorrt_llm/builder.py (1)
from_dict(610-715)tensorrt_llm/_utils.py (1)
from_dict(805-816)tensorrt_llm/mapping.py (1)
from_dict(325-326)tensorrt_llm/models/modeling_utils.py (3)
from_dict(253-263)from_dict(325-334)from_dict(487-492)
tensorrt_llm/_torch/attention_backend/utils.py (1)
cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h (1)
hidden_size(201-201)
tensorrt_llm/_torch/pyexecutor/_util.py (4)
tensorrt_llm/_torch/attention_backend/sparse/utils.py (1)
get_sparse_attn_kv_cache_manager(6-15)tensorrt_llm/_torch/pyexecutor/config_utils.py (2)
is_nemotron_hybrid(1-6)is_qwen3_next(16-20)tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (1)
MambaHybridCacheManager(167-246)tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
KVCacheManager(144-1069)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py (1)
model(162-165)
examples/llm-api/llm_sparse_attention.py (1)
tensorrt_llm/llmapi/llm_args.py (5)
CudaGraphConfig(109-166)DSASparseAttentionConfig(228-247)KvCacheConfig(1224-1358)MoeConfig(250-284)RocketSparseAttentionConfig(205-225)
tensorrt_llm/_torch/attention_backend/sparse/rocket.py (1)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (9)
BlockManager(1118-1190)add_tokens(1134-1146)copy_block_offsets(1148-1156)rewind_cache(1168-1182)free_resources(80-81)free_resources(529-531)free_resources(1162-1166)free_resources(1220-1223)free_resources(1345-1346)
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py (4)
tensorrt_llm/_torch/attention_backend/interface.py (11)
PositionalEmbeddingParams(512-530)RopeParams(356-508)MLAParams(643-651)forward(600-623)num_contexts(196-197)num_contexts(200-203)seq_lens(168-169)seq_lens(172-193)num_ctx_tokens(264-265)num_tokens(268-269)prepare(271-274)tensorrt_llm/_torch/attention_backend/sparse/dsa.py (11)
DSACacheManager(1220-1433)Indexer(521-1096)compute_cu_seqlen_kv_bounds_nocache(186-228)split_prefill_chunks(109-183)forward(1050-1096)copy_indexer_k_cache_offsets(1327-1331)add_dummy_requests(1292-1325)prepare(393-490)prepare(678-783)_update_k_cache(785-833)get_indexer_k_cache_buffers(1333-1338)tensorrt_llm/llmapi/llm_args.py (3)
KvCacheConfig(1224-1358)world_size(337-349)world_size(358-365)tensorrt_llm/mapping.py (1)
Mapping(348-507)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (3)
tensorrt_llm/llmapi/llm_args.py (2)
quant_config(2577-2580)quant_config(2583-2584)tensorrt_llm/_torch/models/modeling_utils.py (2)
register_auto_model(617-623)load_weights(557-575)tensorrt_llm/_torch/models/modeling_speculative.py (3)
load_weights(301-316)load_weights(441-445)load_weights(598-603)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)
tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (2)
free_resources(139-143)free_resources(240-242)tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
LlmRequest(421-613)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (4)
tensorrt_llm/_torch/attention_backend/trtllm.py (4)
TrtllmAttention(1119-1558)TrtllmAttentionMetadata(555-1116)tokens_per_block(609-613)get_local_layer_idx(1202-1206)tensorrt_llm/_torch/pyexecutor/resource_manager.py (12)
BlockManager(1118-1190)KVCacheManager(144-1069)get_unique_primary_pool(735-736)get_batch_cache_indices(691-705)add_tokens(1134-1146)copy_block_offsets(1148-1156)rewind_cache(1168-1182)free_resources(80-81)free_resources(529-531)free_resources(1162-1166)free_resources(1220-1223)free_resources(1345-1346)tensorrt_llm/_torch/attention_backend/sparse/kernel.py (1)
triton_convert_req_index_to_global_index(373-449)tensorrt_llm/quantization/utils/fp8_utils.py (1)
per_token_quant_and_transform(447-520)
tensorrt_llm/evaluate/lm_eval.py (2)
tensorrt_llm/llmapi/llm.py (2)
tokenizer(726-730)tokenizer(733-734)tensorrt_llm/inputs/utils.py (1)
apply_chat_template(408-447)
examples/longbench/eval_longbench_v1.py (2)
tensorrt_llm/llmapi/llm_args.py (4)
CudaGraphConfig(109-166)KvCacheConfig(1224-1358)MoeConfig(250-284)RocketSparseAttentionConfig(205-225)tensorrt_llm/_torch/distributed/communicator.py (1)
tp_size(63-64)
tensorrt_llm/_torch/modules/attention.py (7)
tensorrt_llm/_utils.py (8)
get_sm_version(733-735)is_sm_100f(739-742)nvtx_range(904-923)nvtx_range_debug(926-950)dtype(998-999)dtype(1006-1016)shape(1002-1003)shape(1019-1020)tensorrt_llm/_torch/attention_backend/sparse/dsa.py (3)
DSAtrtllmAttentionMetadata(254-518)transform_local_topk_and_prepare_pool_view(47-106)mla_rope_append_paged_kv_assign_q(1162-1217)tensorrt_llm/_torch/attention_backend/utils.py (1)
create_attention(40-100)tensorrt_llm/_torch/attention_backend/interface.py (9)
support_fused_rope(626-627)update_quant_config(592-598)AttentionMetadata(40-342)num_contexts(196-197)num_contexts(200-203)num_generations(206-207)num_generations(210-213)num_ctx_tokens(264-265)num_tokens(268-269)tensorrt_llm/_torch/modules/multi_stream_utils.py (1)
maybe_execute_in_parallel(35-74)tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
bmm_out(23-24)cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp (2)
fp8_block_scaling_bmm_out(267-326)fp8_block_scaling_bmm_out(267-268)
🪛 Ruff (0.14.0)
tensorrt_llm/_torch/configs/deepseek_v3.py
12-12: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py
151-151: Unused function argument: kv_lora_rank
(ARG001)
153-153: Unused function argument: v_head_dim
(ARG001)
193-193: Unused function argument: kv_lora_rank
(ARG001)
194-194: Unused function argument: v_head_dim
(ARG001)
195-195: Unused function argument: device
(ARG001)
231-231: Unused function argument: query_lens
(ARG001)
296-296: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
299-299: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
442-442: f-string without any placeholders
Remove extraneous f prefix
(F541)
505-505: f-string without any placeholders
Remove extraneous f prefix
(F541)
594-594: f-string without any placeholders
Remove extraneous f prefix
(F541)
698-698: Loop control variable local_pos not used within loop body
Rename unused local_pos to _local_pos
(B007)
715-715: Local variable expected_ctx_indices is assigned to but never used
Remove assignment to unused variable expected_ctx_indices
(F841)
738-738: Local variable expected_gen_indices is assigned to but never used
Remove assignment to unused variable expected_gen_indices
(F841)
tensorrt_llm/_torch/attention_backend/trtllm.py
1388-1388: Unused method argument: kwargs
(ARG002)
1477-1477: Unused method argument: kwargs
(ARG002)
tensorrt_llm/_torch/attention_backend/sparse/kernel.py
382-382: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
413-413: Unpacked variable num_requests is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
427-427: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
tensorrt_llm/_torch/attention_backend/utils.py
59-59: Undefined name SparseAttentionConfig
(F821)
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
35-35: Do not use bare except
(E722)
119-119: Unused method argument: head_dim
(ARG002)
133-133: Unused lambda argument: pos_ids
(ARG005)
600-600: f-string without any placeholders
Remove extraneous f prefix
(F541)
640-640: f-string without any placeholders
Remove extraneous f prefix
(F541)
687-687: f-string without any placeholders
Remove extraneous f prefix
(F541)
699-699: f-string without any placeholders
Remove extraneous f prefix
(F541)
743-743: f-string without any placeholders
Remove extraneous f prefix
(F541)
856-856: f-string without any placeholders
Remove extraneous f prefix
(F541)
907-908: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
929-929: f-string without any placeholders
Remove extraneous f prefix
(F541)
969-969: Unpacked variable k_fp8 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
973-973: f-string without any placeholders
Remove extraneous f prefix
(F541)
1010-1010: f-string without any placeholders
Remove extraneous f prefix
(F541)
1043-1043: f-string without any placeholders
Remove extraneous f prefix
(F541)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
75-75: Unpacked variable num_blocks is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
528-528: Undefined name SparseAttentionConfig
(F821)
756-756: Loop control variable req_id not used within loop body
Rename unused req_id to _req_id
(B007)
1114-1114: Undefined name SparseAttentionConfig
(F821)
1147-1147: Unused method argument: kwargs
(ARG002)
1153-1153: Unused method argument: q
(ARG002)
1154-1154: Unused method argument: k
(ARG002)
1155-1155: Unused method argument: metadata
(ARG002)
1156-1156: Unused method argument: hidden_states
(ARG002)
1157-1157: Unused method argument: qr
(ARG002)
1158-1158: Unused method argument: kwargs
(ARG002)
1168-1168: Unused method argument: kwargs
(ARG002)
1237-1237: Undefined name DecodingBaseConfig
(F821)
1242-1242: Undefined name SparseAttentionConfig
(F821)
1424-1424: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/modules/attention.py
1639-1639: Unused method argument: compressed_kv
(ARG002)
1640-1640: Unused method argument: k_pe
(ARG002)
1740-1742: Avoid specifying long messages outside the exception class
(TRY003)
Superjomn left a comment
LGTM on the llmapi changes.
Pull Request Overview
This PR adds DeepSeek-V3.2 support with FP8/NVFP4 quantization and BF16 KV cache using a new DSA (DeepSeek Sparse Attention) backend. The implementation includes MTP support (next_n=1), CUDA Graph support, and indexer chunked prefill for long-context requests.
Key changes:
- Introduces FlashMLA kernel integration for sparse attention operations
- Implements DSA sparse attention backend with indexer for token selection
- Adds DeepSeek-V3.2 model configuration and weight loading support
Reviewed Changes
Copilot reviewed 42 out of 44 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py | Comprehensive test suite for sparse MLA attention with BF16 KV cache |
| tests/unittest/_torch/attention/sparse/test_flash_mla.py | Basic FlashMLA kernel validation tests |
| tests/unittest/_torch/attention/sparse/test_dsa_indexer.py | Tests for indexer kernels, paged KV cache, and chunked prefill |
| tests/integration/test_lists/qa/*.txt | Added DeepSeek-V3.2 integration test cases |
| tests/integration/defs/accuracy/test_llm_api_pytorch.py | Integration tests for DeepSeek-V3.2 accuracy validation |
| tests/integration/defs/accuracy/references/*.yaml | Reference accuracy benchmarks for MMLU, GSM8K, GPQA |
| tensorrt_llm/llmapi/llm_args.py | DSASparseAttentionConfig class and sparse config dispatcher |
| tensorrt_llm/llmapi/__init__.py | Exported DSASparseAttentionConfig |
| tensorrt_llm/evaluate/lm_eval.py | Added enable_thinking flag for reasoning models |
| tensorrt_llm/_torch/speculative/mtp.py | Metadata update hooks for spec-dec |
| tensorrt_llm/_torch/pyexecutor/*.py | Sparse attention cache manager integration |
| tensorrt_llm/_torch/modules/layer_norm.py | Fixed LayerNorm dtype casting |
| tensorrt_llm/_torch/modules/attention.py | DSA forward pass implementation for MLA |
| tensorrt_llm/_torch/models/modeling_deepseekv3.py | DeepSeek-V3.2 model registration and weight loading |
| tensorrt_llm/_torch/model_config.py | DeepSeek-V3.2 config registry and sparse config integration |
| tensorrt_llm/_torch/configs/deepseek_v3.py | DeepSeek-V3.2 model configuration class |
| tensorrt_llm/_torch/attention_backend/sparse/dsa.py | Complete DSA backend implementation with indexer and cache manager |
| tensorrt_llm/_torch/attention_backend/sparse/kernel.py | Triton kernel for index conversion |
| tensorrt_llm/_torch/attention_backend/sparse/utils.py | DSA backend registration |
| tensorrt_llm/_torch/attention_backend/sparse/rocket.py | Refactored BlockManager extraction |
| tensorrt_llm/_torch/attention_backend/trtllm.py | Added kwargs support for sparse attention hooks |
| tensorrt_llm/_torch/attention_backend/utils.py | Updated create_attention signature |
| tensorrt_llm/_torch/attention_backend/interface.py | Added update_for_spec_dec hook and hidden_size to MLAParams |
| setup.py, scripts/build_wheel.py | FlashMLA module packaging |
| examples/longbench/eval_longbench_v1.py | Enhanced CLI with parallelism and CUDA graph options |
| examples/llm-api/*.py | DSA algorithm support and configuration options |
| cpp/tensorrt_llm/flash_mla/* | FlashMLA C++ extension build configuration |
| 3rdparty/flash-mla, 3rdparty/DeepGEMM, .gitmodules | Added FlashMLA submodule and updated DeepGEMM |
/bot run
PR_Github #21966 [ run ] triggered by Bot. Commit:
PR_Github #21966 [ run ] completed with state
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Problem: CI was failing when installing requirements-dev.txt because fast-hadamard-transform (a git package built from source) has broken setup.py that imports wheel and torch at build-time. Pip collects metadata for all packages before installation, causing it to run fast-hadamard-transform's setup.py before wheel/torch are installed, resulting in ModuleNotFoundError. Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
…weight loading and fix some rebase bugs. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
PR_Github #22421 [ kill ] triggered by Bot. Commit:
PR_Github #22421 [ kill ] completed with state
/bot run --disable-fail-fast
PR_Github #22426 [ run ] triggered by Bot. Commit:
PR_Github #22426 [ run ] completed with state
/bot run
PR_Github #22446 [ run ] triggered by Bot. Commit:
PR_Github #22446 [ run ] completed with state
…cache/NVFP4 + BF16 KV cache (NVIDIA#8405)" This reverts commit e47c787. Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
…FP4 + BF16 KV cache (NVIDIA#8405) Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com> Signed-off-by: Tracin <10434017+Tracin@users.noreply.github.com>
Co-authored with @lfr-0531 @Tracin
This PR adds:
TODOs:
Run command:
With indexer chunked prefill (2048):
Without indexer chunking:
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
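For example, `/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast` launches only the listed test stage with fail-fast disabled, using only the flags documented above (a `--stage-list` run does not update the GitHub check status).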
kill
Kill all running builds associated with the pull request.
skip
skip --comment COMMENT
Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.