
Commit d5f10fb

Clean up and use the all-layer KV cache pool in the FlashMLA kernel
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
1 parent dfb99bc commit d5f10fb

File tree

7 files changed: +41 -67 lines changed


cpp/tensorrt_llm/flash_mla/CMakeLists.txt

Lines changed: 8 additions & 4 deletions
@@ -144,8 +144,8 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
       ${FLASH_MLA_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu)
 endif()

-# Disable LTO before creating target (matching DeepEP's approach) Let CMake
-# generate fatbinData for CUDA separable compilation
+# Disable LTO before creating target (similar to DeepEP) Let CMake generate
+# fatbinData for CUDA separable compilation
 set(CMAKE_INTERPROCEDURAL_OPTIMIZATION FALSE)

 pybind11_add_module(flash_mla_cpp_tllm ${FLASH_MLA_SOURCES})
@@ -174,7 +174,7 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
 endif()
 set_cuda_architectures(flash_mla_cpp_tllm ${FLASH_MLA_BUILD_ARCHS})

-# Compiler options matching FlashMLA setup.py
+# Copy of compiler options from FlashMLA setup.py
 target_compile_options(
   flash_mla_cpp_tllm
   PRIVATE
@@ -215,7 +215,11 @@ target_include_directories(
 # Link libraries (matching FlashMLA setup.py: cuda, cudart + torch)
 target_link_libraries(
   flash_mla_cpp_tllm PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIB}
-  CUDA::cuda_driver CUDA::cudart CUDA::cudart_static)
+  CUDA::cuda_driver CUDA::cudart)
+target_link_options(
+  flash_mla_cpp_tllm PRIVATE
+  -Wl,--version-script,${CMAKE_CURRENT_SOURCE_DIR}/flash_mla_cpp_tllm.version
+  -Wl,--no-undefined-version)

 # Link directories
 target_link_directories(

examples/llm-api/llm_sparse_attention.py

Lines changed: 5 additions & 16 deletions
@@ -64,18 +64,10 @@ def parse_arguments():
                         type=int,
                         default=2048,
                         help="The prompt budget for RocketKV.")
-    parser.add_argument('--index_n_heads',
+    parser.add_argument('--index_max_chunk_size',
                         type=int,
-                        default=64,
-                        help="The number of heads for the indexer.")
-    parser.add_argument('--index_head_dim',
-                        type=int,
-                        default=128,
-                        help="The dimension of the indexer heads.")
-    parser.add_argument('--index_topk',
-                        type=int,
-                        default=2048,
-                        help="The topk for the indexer.")
+                        default=32768,
+                        help="The maximum chunk size for the indexer.")
     parser.add_argument("--max_seq_len",
                         type=int,
                         default=8192,
@@ -198,11 +190,8 @@ def run_RocketKV(args):

 def run_DSA(args):
     sparse_attention_config = DSASparseAttentionConfig(
-        index_n_heads=args.index_n_heads,
-        index_head_dim=args.index_head_dim,
-        index_topk=args.index_topk,
-    )
-    run_llm(args, None)
+        indexer_max_chunk_size=args.index_max_chunk_size, )
+    run_llm(args, sparse_attention_config)


 def main():

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 15 additions & 31 deletions
@@ -69,34 +69,17 @@ def transform_local_topk_and_prepare_pool_view(
     """
     assert topk_indices.dtype == torch.int32

-    # Get KV cache pool: [num_blocks, 1, tokens_per_block, 1, head_dim]
-    kv_pool = kv_cache_manager.get_buffers(layer_idx)
-    num_blocks, _, tokens_per_block, _, head_dim = kv_pool.shape
-    assert kv_pool.shape[1] == 1 and kv_pool.shape[3] == 1
-
-    # Squeeze to [num_blocks, tokens_per_block, head_dim]
-    kv_pool = kv_pool.squeeze(1).squeeze(2)
-
-    # Auto-detect stride and prepare view
-    if kv_pool.is_contiguous():
-        stride_factor = tokens_per_block
-        kv_pool = kv_pool.view(-1, 1, head_dim)
-    else:
-        # Here we simply do:
-        # kv_pool = kv_pool.reshape(-1, 1, head_dim) to make it contiguous
-        # however, using strided layout and directly offset topk tokens in the
-        # (layer-interleaved) pool MIGHT be (not benchmarked) more efficient as its zero-copy.
-
-        # Strided layout: compute stride and create efficient view
-        block_stride = kv_pool.stride(0)
-        token_stride = kv_pool.stride(1)
-        assert token_stride == head_dim
-        stride_factor = block_stride // token_stride
-        view_size = (num_blocks - 1) * stride_factor + tokens_per_block
-        kv_pool = torch.as_strided(kv_pool,
-                                   size=(view_size, 1, head_dim),
-                                   stride=(token_stride, 0, 1),
-                                   storage_offset=kv_pool.storage_offset())
+    # Get all layer KV cache pool: [num_blocks, num_layers, kv_factor, blockSize]
+    all_layer_kv_pool = kv_cache_manager.get_unique_primary_pool(
+    )  # [num_blocks, num_layers, kv_factor, blockSize]
+    num_blocks, num_layers, _, _ = all_layer_kv_pool.shape
+    tokens_per_block = kv_cache_manager.tokens_per_block
+    head_dim = kv_cache_manager.head_dim
+    assert num_layers == kv_cache_manager.num_local_layers, "PP is not enabled yet for DS32"
+    assert all_layer_kv_pool.is_contiguous(
+    ), "all_layer_kv_pool should be contiguous"
+    all_layer_kv_pool = all_layer_kv_pool.squeeze(2).view(-1, 1, head_dim)
+    stride_factor = num_layers * tokens_per_block

     # Get block_table and request indices for this phase
     if is_generation:
@@ -114,12 +97,13 @@ def transform_local_topk_and_prepare_pool_view(
         req_idx,
         block_table,
         topk_indices,
-        BLOCK_SIZE=attn_metadata.tokens_per_block,
+        BLOCK_SIZE=tokens_per_block,
         NUM_TOPK_TOKENS=topk_indices.shape[1],
         stride_factor=stride_factor,
+        layer_id=layer_idx,
     )

-    return global_indices, kv_pool
+    return global_indices, all_layer_kv_pool


 def split_prefill_chunks(
@@ -1262,7 +1246,7 @@ def __init__(
         assert not kv_cache_config.enable_block_reuse, "DSA cache requires block reuse to be disabled in KV cache config"
         self.quant_block_size = 128
         self.index_head_dim = sparse_attn_config.index_head_dim
-        # Use a fixed tokens_per_block for indexer k cache
+        # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
        self.indexer_k_cache_tokens_per_block = 64

        super().__init__(
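
For reference, the flat view created above addresses the all-layer pool with stride_factor = num_layers * tokens_per_block, so a token's row is determined jointly by its block, its layer, and its in-block offset. Below is a minimal sketch with toy shapes (illustrative only, not part of the commit; the real tensor comes from kv_cache_manager.get_unique_primary_pool(), and kv_factor is assumed to be 1, matching the squeeze(2) above):

import torch

num_blocks, num_layers, kv_factor, tokens_per_block, head_dim = 4, 3, 1, 8, 16
block_size = tokens_per_block * head_dim  # "blockSize", the last pool dim

# Toy stand-in for the all-layer pool: [num_blocks, num_layers, kv_factor, blockSize]
pool = torch.arange(num_blocks * num_layers * kv_factor * block_size,
                    dtype=torch.float32).reshape(num_blocks, num_layers,
                                                 kv_factor, block_size)

# Same flattening as transform_local_topk_and_prepare_pool_view
flat = pool.squeeze(2).view(-1, 1, head_dim)
stride_factor = num_layers * tokens_per_block

# Row of token t in block b of layer l within the flat view
b, l, t = 2, 1, 5
row = b * stride_factor + l * tokens_per_block + t
expected = pool[b, l, 0].view(tokens_per_block, head_dim)[t]
assert torch.equal(flat[row, 0], expected)

The layer term l * tokens_per_block is exactly the extra in-block offset the Triton kernel now adds via layer_id (see kernel.py below).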

tensorrt_llm/_torch/attention_backend/sparse/kernel.py

Lines changed: 13 additions & 10 deletions
@@ -318,9 +318,9 @@ def _convert_req_index_to_global_index_kernel_with_stride_factor(
     # shapes (compile-time where possible)
     max_num_blocks_per_req: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    BLOCK_N: tl.constexpr,  # tile width along columns
+    BLOCK_N: tl.constexpr,  # tile width along columns # strides (in elements)
     stride_factor: tl.constexpr,  # for strided memory layout adjustment
-    # strides (in elements)
+    layer_id: tl.constexpr,  # for layer interleaving layout
     bt_stride0,
     bt_stride1,
     ti_stride0,
@@ -352,7 +352,7 @@ def _convert_req_index_to_global_index_kernel_with_stride_factor(

     # Compute block id and in-block offset
     block_id = tok // BLOCK_SIZE
-    inblock_off = tok % BLOCK_SIZE
+    inblock_off = tok % BLOCK_SIZE + layer_id * BLOCK_SIZE

     # Guard block_table access
     valid_block = block_id < max_num_blocks_per_req
@@ -371,14 +371,16 @@ def _convert_req_index_to_global_index_kernel_with_stride_factor(


 def triton_convert_req_index_to_global_index(
-    req_id: torch.Tensor,  # int32 [num_tokens]
-    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
-    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
-    BLOCK_SIZE: int = 64,
-    NUM_TOPK_TOKENS: int = 2048,
-    BLOCK_N: int = 128,  # tile width along columns
-    stride_factor:
+    req_id: torch.Tensor,  # int32 [num_tokens]
+    block_table: torch.
+    Tensor,  # int32 [num_requests, max_num_blocks_per_req]
+    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    BLOCK_SIZE: int,
+    NUM_TOPK_TOKENS: int = 2048,
+    BLOCK_N: int = 128,  # tile width along columns
+    stride_factor:
     int = None,  # for strided memory layout (with layer interleaving), defaults to BLOCK_SIZE
+    layer_id: int = 0,  # for layer interleaving layout
 ):
     """
     Convert request-local token indices to global KV cache pool indices.
@@ -436,6 +438,7 @@ def triton_convert_req_index_to_global_index(
         BLOCK_N,
         stride_factor,
         # strides
+        layer_id,
         bt_stride0,
         bt_stride1,
         ti_stride0,
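
To make the changed index math concrete, here is a hedged plain-PyTorch reference of the per-token conversion with the new layer_id term. The final combination with stride_factor is assumed from the parameter semantics above (the line that applies it lies outside this diff), the kernel's bounds guards are not reproduced, and the function name is illustrative:

import torch

def convert_req_index_to_global_index_ref(
        req_id: torch.Tensor,         # int32 [num_tokens]
        block_table: torch.Tensor,    # int32 [num_requests, max_num_blocks_per_req]
        token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        BLOCK_SIZE: int,
        stride_factor: int,
        layer_id: int = 0) -> torch.Tensor:
    # Request-local token index -> (block slot, offset inside that block)
    block_id = token_indices // BLOCK_SIZE
    inblock_off = token_indices % BLOCK_SIZE + layer_id * BLOCK_SIZE
    # Look up the pool block for each token via its request's block table row
    pool_block = block_table[req_id.long()].gather(1, block_id.long())
    # Assumed combine step: row in the flattened all-layer pool view
    return pool_block * stride_factor + inblock_off

At the call site in dsa.py above, stride_factor is num_layers * tokens_per_block and layer_id is layer_idx, so consecutive layers of the same block occupy adjacent tokens_per_block-row slices of the flat pool view.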

tensorrt_llm/_torch/modules/attention.py

Lines changed: 0 additions & 4 deletions
@@ -1530,7 +1530,6 @@ def forward_context(
         elif trtllm_attention.has_cached_kv_for_mla_context(attn_metadata):
             return self.forward_context_with_cached_kv(
                 q, latent_cache, attn_metadata, output)
-
         return self.forward_context_default(q, compressed_kv, k_pe,
                                             attn_metadata, output, latent_cache)

@@ -1601,9 +1600,6 @@ def forward_generation(
             out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
-            hidden_states=hidden_states,
-            qr=qr,
-            position_ids=position_ids,
         )
         fused_q = None

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 0 additions & 1 deletion
@@ -325,7 +325,6 @@ def _load_and_validate_config(
             if hasattr(config.pretrained_config, sub_config):
                 getattr(config.pretrained_config,
                         sub_config).num_hidden_layers = num_layers_override
-
         return config

     def _call_load_weights(self, load_method: Callable, weights, weight_mapper):

tensorrt_llm/quantization/utils/fp8_utils.py

Lines changed: 0 additions & 1 deletion
@@ -444,7 +444,6 @@ def _per_token_quant_and_transform_kernel(
     )


-# TODO: Add more comments and tests for this function for future reuse
 def per_token_quant_and_transform(
     input: torch.Tensor,
     quant_group_size: int = 128,
