@@ -69,34 +69,17 @@ def transform_local_topk_and_prepare_pool_view(
     """
     assert topk_indices.dtype == torch.int32

-    # Get KV cache pool: [num_blocks, 1, tokens_per_block, 1, head_dim]
-    kv_pool = kv_cache_manager.get_buffers(layer_idx)
-    num_blocks, _, tokens_per_block, _, head_dim = kv_pool.shape
-    assert kv_pool.shape[1] == 1 and kv_pool.shape[3] == 1
-
-    # Squeeze to [num_blocks, tokens_per_block, head_dim]
-    kv_pool = kv_pool.squeeze(1).squeeze(2)
-
-    # Auto-detect stride and prepare view
-    if kv_pool.is_contiguous():
-        stride_factor = tokens_per_block
-        kv_pool = kv_pool.view(-1, 1, head_dim)
-    else:
-        # Here we simply do
-        #   kv_pool = kv_pool.reshape(-1, 1, head_dim) to make it contiguous;
-        # however, keeping the strided layout and directly offsetting top-k tokens in the
-        # (layer-interleaved) pool MIGHT be more efficient (not benchmarked) since it is zero-copy.
-
-        # Strided layout: compute stride and create efficient view
-        block_stride = kv_pool.stride(0)
-        token_stride = kv_pool.stride(1)
-        assert token_stride == head_dim
-        stride_factor = block_stride // token_stride
-        view_size = (num_blocks - 1) * stride_factor + tokens_per_block
-        kv_pool = torch.as_strided(kv_pool,
-                                   size=(view_size, 1, head_dim),
-                                   stride=(token_stride, 0, 1),
-                                   storage_offset=kv_pool.storage_offset())
+    # Get the all-layer KV cache pool: [num_blocks, num_layers, kv_factor, blockSize]
+    all_layer_kv_pool = kv_cache_manager.get_unique_primary_pool(
+    )  # [num_blocks, num_layers, kv_factor, blockSize]
+    num_blocks, num_layers, _, _ = all_layer_kv_pool.shape
+    tokens_per_block = kv_cache_manager.tokens_per_block
+    head_dim = kv_cache_manager.head_dim
+    assert num_layers == kv_cache_manager.num_local_layers, "PP is not enabled yet for DS32"
+    assert all_layer_kv_pool.is_contiguous(
+    ), "all_layer_kv_pool should be contiguous"
+    all_layer_kv_pool = all_layer_kv_pool.squeeze(2).view(-1, 1, head_dim)
+    stride_factor = num_layers * tokens_per_block

     # Get block_table and request indices for this phase
     if is_generation:
@@ -114,12 +97,13 @@ def transform_local_topk_and_prepare_pool_view(
         req_idx,
         block_table,
         topk_indices,
-        BLOCK_SIZE=attn_metadata.tokens_per_block,
+        BLOCK_SIZE=tokens_per_block,
         NUM_TOPK_TOKENS=topk_indices.shape[1],
         stride_factor=stride_factor,
+        layer_id=layer_idx,
     )

-    return global_indices, kv_pool
+    return global_indices, all_layer_kv_pool


 def split_prefill_chunks(
@@ -1262,7 +1246,7 @@ def __init__(
         assert not kv_cache_config.enable_block_reuse, "DSA cache requires block reuse to be disabled in KV cache config"
         self.quant_block_size = 128
         self.index_head_dim = sparse_attn_config.index_head_dim
-        # Use a fixed tokens_per_block for indexer k cache
+        # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
         self.indexer_k_cache_tokens_per_block = 64

         super().__init__(
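
The net effect of the first hunk is that a top-k token is now addressed directly in the layer-interleaved primary pool rather than in a per-layer buffer, which is why the kernel now receives layer_id alongside stride_factor. The sketch below only illustrates the indexing arithmetic those two values imply; the Triton kernel itself is not part of this diff, and to_global_index and its parameters are hypothetical names, not identifiers from the PR.

def to_global_index(physical_block_id: int, layer_id: int, token_in_block: int,
                    num_layers: int, tokens_per_block: int) -> int:
    # Assumed mapping, not the kernel code: with the all-layer pool viewed as
    # [num_blocks * num_layers * tokens_per_block, 1, head_dim], each physical block
    # spans num_layers * tokens_per_block rows (= stride_factor), and each layer spans
    # tokens_per_block rows inside that block.
    stride_factor = num_layers * tokens_per_block
    return physical_block_id * stride_factor + layer_id * tokens_per_block + token_in_block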