
Commit 5b6a6ca

suyoggupta authored and dominicshanshan committed
[NVIDIA#8272][feat] Enable chunked prefill for SSMs in AutoDeploy (NVIDIA#8477)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
1 parent d9189f6 commit 5b6a6ca

File tree

17 files changed: +136 -90 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -8,3 +8,9 @@
 for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
     __all__.append(module_name)
     importlib.import_module(f"{__name__}.{module_name}")
+
+# Recursively import subpackages and modules so their side-effects (e.g.,
+# op registrations) are applied even when nested in subdirectories.
+for _, full_name, _ in pkgutil.walk_packages(__path__, prefix=f"{__name__}."):
+    __all__.append(full_name)
+    importlib.import_module(full_name)
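
For context, pkgutil.iter_modules only lists a package's immediate children, while pkgutil.walk_packages recurses into subpackages such as the new custom_ops/mamba/ directory; importing every discovered module is what triggers its op-registration side effects. A minimal sketch of the difference, using a hypothetical package name (not code from this commit):

import importlib
import pkgutil

import some_pkg  # hypothetical package with a nested subpackage, e.g. some_pkg/mamba/

# iter_modules: direct children only (it sees the subpackage, but not the modules inside it).
for _, name, _is_pkg in pkgutil.iter_modules(some_pkg.__path__):
    print("direct child:", name)

# walk_packages: recurses and yields fully qualified names (e.g. "some_pkg.mamba.ops"),
# so importing each one runs any torch.library registrations defined inside it.
for _, full_name, _is_pkg in pkgutil.walk_packages(some_pkg.__path__, prefix=f"{some_pkg.__name__}."):
    importlib.import_module(full_name)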

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/__init__.py

Whitespace-only changes.

tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py renamed to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

Lines changed: 11 additions & 7 deletions
@@ -18,8 +18,8 @@
 from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
 from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
-from ..utils.node_utils import extract_op_args
-from .attention_interface import (
+from ...utils.node_utils import extract_op_args
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,
@@ -74,8 +74,9 @@ def cuda_causal_conv_prepare_metadata(
     seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)
 
     slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
-
-    return (seq_len_sanitized, seq_start, slot_idx_sanitized)
+    # This is only used during prefill to determine if we should use the initial states from the cache.
+    use_initial_states = input_pos > 0
+    return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)
 
 
 @cuda_causal_conv_prepare_metadata.register_fake
@@ -88,6 +89,7 @@ def cuda_causal_conv_prepare_metadata_fake(
         torch.empty_like(seq_len_sanitized),
         torch.empty_like(seq_len_sanitized),
         torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
+        torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
     )
 
 
@@ -101,6 +103,7 @@ def _cuda_cached_causal_conv1d(
     seq_len: torch.Tensor,  # [num_seq]
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
+    use_initial_states: torch.Tensor,  # [num_seq]
     # CACHES
     conv_state_cache: torch.Tensor,  # [max_batch_size, c_in, k-1]
     # CONSTANTS
@@ -161,7 +164,7 @@ def _cuda_cached_causal_conv1d(
             dim=0,
         ).contiguous()
         cache_indices = slot_idx[:num_prefill].to(torch.int32).contiguous()
-        has_initial_state = torch.zeros(num_prefill, dtype=torch.bool, device=input.device)
+        has_initial_state = use_initial_states[:num_prefill].to(torch.bool)
 
         # Run varlen conv; updates conv_state_cache in-place per cache_indices
         y_varlen = causal_conv1d_fn(
@@ -215,6 +218,7 @@ def _cuda_cached_causal_conv1d_fake(
     seq_len: torch.Tensor,
     seq_start: torch.Tensor,
     slot_idx: torch.Tensor,
+    use_initial_states: torch.Tensor,  # [num_seq]
     # CACHES
     conv_state_cache: torch.Tensor,
     # CONSTANTS
@@ -256,8 +260,8 @@ def get_cached_attention_op(cls) -> MHACallable:
 
     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx)
-        return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 3
+        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
+        return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 4
 
     @classmethod
     def get_cache_initializers(
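
The essence of the chunked-prefill change is the new use_initial_states flag: a prefill chunk whose input_pos is greater than zero continues a sequence whose earlier chunk already populated the conv state cache, so the varlen kernel must start from that cached state instead of zeros. A minimal illustration with assumed values (not code from this commit):

import torch

# Assumed metadata for three prefill sequences; the last two are continuation
# chunks of longer prompts being processed in pieces (chunked prefill).
input_pos = torch.tensor([0, 512, 1024])

# Computed in cuda_causal_conv_prepare_metadata and threaded to the cached op.
use_initial_states = input_pos > 0  # tensor([False, True, True])

# In the prefill branch this replaces the old all-zeros flag, so the kernel reads
# the per-slot conv_state_cache entry as the initial state where the flag is True.
num_prefill = 3
has_initial_state = use_initial_states[:num_prefill].to(torch.bool)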

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py renamed to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py

Lines changed: 10 additions & 5 deletions
@@ -16,8 +16,8 @@
 from torch._ops import OpOverloadPacket
 from torch.fx import Node
 
-from ..utils.node_utils import extract_op_args
-from .attention_interface import (
+from ...utils.node_utils import extract_op_args
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,
@@ -160,8 +160,8 @@ def torch_causal_conv_prepare_metadata(
     seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)
 
     slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
-
-    return (seq_len_sanitized, seq_start, slot_idx_sanitized)
+    use_initial_states = input_pos > 0
+    return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)
 
 
 @torch_causal_conv_prepare_metadata.register_fake
@@ -174,6 +174,7 @@ def torch_causal_conv_prepare_metadata_fake(
         torch.empty_like(seq_len_sanitized),
         torch.empty_like(seq_len_sanitized),
         torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
+        torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
     )
 
 
@@ -187,6 +188,7 @@ def _torch_cached_causal_conv1d(
     seq_len: torch.Tensor,  # [num_seq]
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
+    use_initial_states: torch.Tensor,  # [num_seq]
     # CACHES
     conv_state_cache: torch.Tensor,  # [max_batch_size, c_in, k]
     # CONSTANTS
@@ -275,6 +277,7 @@ def _torch_cached_causal_conv1d_fake(
     seq_len: torch.Tensor,
     seq_start: torch.Tensor,
     slot_idx: torch.Tensor,
+    use_initial_states: torch.Tensor,  # [num_seq]
     # CACHES
     conv_state_cache: torch.Tensor,
     # CONSTANTS
@@ -317,8 +320,10 @@ def get_cached_attention_op(cls) -> MHACallable:
 
     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
+        # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
+        # reference implementation to support chunked prefill.
         # Returns (seq_len, seq_start, slot_idx)
-        return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 3
+        return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 4
 
     @classmethod
     def get_cache_initializers(
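
The integer in the (op, n) tuple returned by get_prepare_metadata_op is the number of metadata tensors the op yields, now 4 because use_initial_states is appended. A self-contained toy sketch of that contract (hypothetical stand-in, not the repo's descriptor):

from typing import Callable, Tuple

import torch

def toy_prepare_metadata(input_pos: torch.Tensor, slot_idx: torch.Tensor):
    # Mirrors the shape of the real metadata: per-sequence lengths, starts,
    # sanitized slot indices, and the new chunked-prefill flag.
    seq_len = torch.ones_like(input_pos)
    seq_start = torch.cumsum(seq_len, 0) - seq_len
    use_initial_states = input_pos > 0
    return seq_len, seq_start, slot_idx.to(torch.long), use_initial_states

def get_prepare_metadata_op() -> Tuple[Callable[..., tuple], int]:
    # The int tells callers how many tensors to expect from the op.
    return toy_prepare_metadata, 4

op, num_outputs = get_prepare_metadata_op()
outs = op(torch.tensor([0, 128]), torch.tensor([3, 7]))
assert len(outs) == num_outputs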

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py renamed to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py

Lines changed: 28 additions & 15 deletions
@@ -12,8 +12,8 @@
 from torch._ops import OpOverloadPacket
 from torch.fx import Node
 
-from ..utils.node_utils import extract_op_args
-from .attention_interface import (
+from ...utils.node_utils import extract_op_args
+from ..attention_interface import (
     AttentionDescriptor,
     AttentionLayout,
     AttentionRegistry,
@@ -25,10 +25,10 @@
     PrepareMetadataCallable,
     SequenceInfo,
 )
-from .torch_mamba import _torch_ssm_transform_prefill
+from .torch_mamba import _torch_ssm_prefill
 
 
-def _torch_cached_ssm_transform_decode(
+def _torch_cached_ssm_decode(
     hidden_states: torch.Tensor,
     A: torch.Tensor,
     B: torch.Tensor,
@@ -135,8 +135,10 @@ def _torch_ssm_prepare_metadata(
 
     # Truncate slot indices to match active sequences
     slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
-
-    return (seq_len_sanitized, seq_start, slot_idx_sanitized)
+    # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
+    # reference implementation to support chunked prefill.
+    use_initial_states = input_pos > 0
+    return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)
 
 
 @_torch_ssm_prepare_metadata.register_fake
@@ -150,11 +152,12 @@ def _torch_ssm_prepare_metadata_fake(
         torch.empty_like(seq_len_sanitized),
         torch.empty_like(seq_len_sanitized),
         torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
+        torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
     )
 
 
-@torch.library.custom_op("auto_deploy::torch_cached_ssm_transform", mutates_args={})
-def _torch_cached_ssm_transform(
+@torch.library.custom_op("auto_deploy::torch_cached_ssm", mutates_args={})
+def _torch_cached_ssm(
     # INPUTS (dense but may be flattened across sequences)
     hidden_states: torch.Tensor,  # [b, s, num_heads, head_dim]
     A: torch.Tensor,  # [num_heads]
@@ -167,6 +170,7 @@ def _torch_cached_ssm_transform(
     seq_len: torch.Tensor,  # [num_seq]
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
+    use_initial_states: torch.Tensor,  # [num_seq]
     # CACHES
     ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
     # CONSTANTS
@@ -188,7 +192,7 @@ def _torch_cached_ssm_transform(
        slot_idx_long = slot_idx.to(torch.long)
        ssm_batch = ssm_state_cache.index_select(dim=0, index=slot_idx_long)
 
-        y, updated_state = _torch_cached_ssm_transform_decode(
+        y, updated_state = _torch_cached_ssm_decode(
            hidden_states,
            A,
            B,
@@ -207,6 +211,14 @@ def _torch_cached_ssm_transform(
        # return in the same dtype as the input
        return y.to(hidden_states.dtype)
 
+    # Prefill
+    if any(use_initial_states):
+        # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
+        # reference implementation to support chunked prefill.
+        raise ValueError(
+            "torch mamba backend does not yet support chunked prefill "
+            "and can not correctly handle initial states."
+        )
    # Context/mixed phase (flattened sequences). Expect b == 1, but handle general b robustly.
    # We'll iterate over sequences defined by (seq_len, seq_start) and update state per slot.
    # Process across the flattened second dimension.
@@ -244,7 +256,7 @@ def _torch_cached_ssm_transform(
            dt_seq = dt_flat.index_select(0, idx_i).unsqueeze(0)
 
            # Run prefill and obtain final SSM state for this sequence
-            y_seq, ssm_state_i = _torch_ssm_transform_prefill(
+            y_seq, ssm_state_i = _torch_ssm_prefill(
                hs_seq, A, B_seq, C_seq, D, dt_seq, dt_bias, time_step_limit, chunk_size
            )
 
@@ -258,8 +270,8 @@ def _torch_cached_ssm_transform(
    return y
 
 
-@_torch_cached_ssm_transform.register_fake
-def _torch_cached_ssm_transform_fake(
+@_torch_cached_ssm.register_fake
+def _torch_cached_ssm_fake(
    # INPUTS
    hidden_states: torch.Tensor,
    A: torch.Tensor,
@@ -272,6 +284,7 @@ def _torch_cached_ssm_transform_fake(
    seq_len: torch.Tensor,
    seq_start: torch.Tensor,
    slot_idx: torch.Tensor,
+    use_initial_states: torch.Tensor,
    # CACHES
    ssm_state_cache: torch.Tensor,
    # CONSTANTS
@@ -304,16 +317,16 @@ def get_num_qkv_args(cls) -> int:
 
    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
-        return torch.ops.auto_deploy.torch_ssm_transform
+        return torch.ops.auto_deploy.torch_ssm
 
    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:
-        return torch.ops.auto_deploy.torch_cached_ssm_transform
+        return torch.ops.auto_deploy.torch_cached_ssm
 
    @classmethod
    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
        # Returns (seq_len, seq_start, slot_idx)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3
+        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
 
    @classmethod
    def get_cache_initializers(
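
The new guard makes the limitation explicit rather than silent: the torch reference SSM backend can only prefill from an empty state, so any batch carrying initial states (i.e. a continuation chunk) is rejected while issue 8170 is open. A small sketch of the check with assumed inputs (not code from this commit):

import torch

use_initial_states = torch.tensor([False, True])  # assumed metadata for two prefill sequences

# Same check as in _torch_cached_ssm: Python's any() iterates the 1-D bool tensor
# element by element; torch.any(use_initial_states).item() would be equivalent.
if any(use_initial_states):
    raise ValueError(
        "torch mamba backend does not yet support chunked prefill "
        "and can not correctly handle initial states."
    )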

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py renamed to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py

Lines changed: 6 additions & 8 deletions
@@ -67,7 +67,7 @@ def _segment_sum(input_tensor):
     return tensor_segsum
 
 
-def _torch_ssm_transform_prefill(
+def _torch_ssm_prefill(
     hidden_states: torch.Tensor,
     A: torch.Tensor,
     B: torch.Tensor,
@@ -162,8 +162,8 @@ def _torch_ssm_transform_prefill(
     return y, ssm_state
 
 
-@torch.library.custom_op("auto_deploy::torch_ssm_transform", mutates_args={})
-def _torch_ssm_transform(
+@torch.library.custom_op("auto_deploy::torch_ssm", mutates_args={})
+def _torch_ssm(
     hidden_states: torch.Tensor,
     A: torch.Tensor,
     B: torch.Tensor,
@@ -176,14 +176,12 @@ def _torch_ssm_transform(
     ],  # NOTE: `torch` custom ops do not like `Tuple` inputs. Using `List` is the suggested WAR.
     chunk_size: int,
 ) -> torch.Tensor:
-    y, _ = _torch_ssm_transform_prefill(
-        hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size
-    )
+    y, _ = _torch_ssm_prefill(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size)
     return y
 
 
-@_torch_ssm_transform.register_fake
-def _torch_ssm_transform_meta(
+@_torch_ssm.register_fake
+def _torch_ssm_meta(
     hidden_states: torch.Tensor,
     A: torch.Tensor,
     B: torch.Tensor,
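
For readers unfamiliar with the decorator pattern that these renames touch, here is a minimal, self-contained toy example (not from the repo) of torch.library.custom_op plus register_fake: the fake implementation only describes output shapes and dtypes for meta/FakeTensor tracing, which is why every fake above had to grow the same extra bool tensor as its real counterpart.

import torch

# Toy op registered under a made-up namespace, mirroring the pattern in torch_mamba.py.
@torch.library.custom_op("toy_ns::scale", mutates_args=())
def toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake (meta) implementation must match the real op's output structure.
@toy_scale.register_fake
def _toy_scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

y = torch.ops.toy_ns.scale(torch.randn(4), 2.0)  # dispatches through the registered op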
