
Commit 0f2f11f

[TRTLLM-6453][feat] Support chunked prefill on spec decode 2 model (#6104)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 9a99e6d, commit 0f2f11f

File tree

4 files changed: +71 -22 lines

tensorrt_llm/_torch/pyexecutor/llm_request.py
Lines changed: 1 addition & 0 deletions

@@ -303,6 +303,7 @@ def __init__(
         self.py_batch_idx = None
         self.py_rewind_len = 0
         self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
+        self.py_last_context_chunk = (None, None)
         self.py_last_draft_tokens = None
         self.py_num_accepted_draft_tokens = 0
         self.py_decoding_iter = 0
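For reference, a minimal sketch of what the new attribute represents, using a stand-in object rather than the real LlmRequest: py_last_context_chunk holds the (begin, end) token positions of the most recently processed context chunk, and the (None, None) default means no chunk has been processed yet.

from types import SimpleNamespace

# Stand-in for an LlmRequest; only the new attribute matters here.
req = SimpleNamespace(py_last_context_chunk=(None, None))

begin, end = req.py_last_context_chunk
if begin is None:
    print("no context chunk has been processed yet")

# After the executor records the first 64-token prefill chunk:
req.py_last_context_chunk = (0, 64)
begin, end = req.py_last_context_chunk
print(end - begin)  # 64 tokens were just prefilled by the target model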

tensorrt_llm/_torch/pyexecutor/py_executor.py
Lines changed: 4 additions & 0 deletions

@@ -1316,6 +1316,10 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):

         for request in scheduled_requests.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
+                request.py_last_context_chunk = (
+                    request.context_current_position,
+                    request.context_current_position +
+                    request.context_chunk_size)
                 request.move_to_next_context_chunk()
                 if request.context_remaining_length == 0:
                     request.state = LlmRequestState.GENERATION_IN_PROGRESS
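A standalone trace of the bookkeeping added here, assuming a 200-token prompt and a 64-token chunk budget (illustrative numbers, not taken from the change): the (begin, end) window is recorded before the request advances to its next chunk.

prompt_len, max_chunk = 200, 64
position = 0
while position < prompt_len:
    chunk_size = min(max_chunk, prompt_len - position)
    # Recorded as py_last_context_chunk before advancing, mirroring the diff above.
    last_context_chunk = (position, position + chunk_size)
    position += chunk_size  # roughly what move_to_next_context_chunk() does here
    print(last_context_chunk)
# Prints (0, 64), (64, 128), (128, 192), (192, 200); once nothing remains,
# the request transitions to GENERATION_IN_PROGRESS.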

tensorrt_llm/_torch/speculative/model_drafter.py
Lines changed: 39 additions & 7 deletions

@@ -92,10 +92,17 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
     def _create_context_request(self, request: LlmRequest,
                                 input_tokens: Any) -> LlmRequest:
         """Create a context request for first-time drafting."""
-        return self._create_draft_request(request.py_request_id,
-                                          request.py_max_new_tokens,
-                                          input_tokens, request.sampling_config,
-                                          request.return_perf_metrics)
+        new_request = self._create_draft_request(request.py_request_id,
+                                                 request.py_max_new_tokens,
+                                                 input_tokens,
+                                                 request.sampling_config,
+                                                 request.return_perf_metrics)
+
+        begin_compute, end_compute = request.py_last_context_chunk
+        if begin_compute is not None:
+            new_request.context_current_position = begin_compute
+            new_request.context_chunk_size = end_compute - begin_compute
+        return new_request

     def _create_generation_request(self, request: LlmRequest,
                                    input_tokens: Any) -> LlmRequest:
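A minimal sketch of the chunk-window copy above, using stand-in objects instead of real LlmRequest instances: the draft-model context request is pinned to the same token window the target model recorded for its last prefill chunk.

from types import SimpleNamespace

def align_draft_chunk(new_request, last_context_chunk):
    # Mirrors the logic above: only constrain the draft request once the
    # target has actually recorded a chunk window.
    begin_compute, end_compute = last_context_chunk
    if begin_compute is not None:
        new_request.context_current_position = begin_compute
        new_request.context_chunk_size = end_compute - begin_compute
    return new_request

draft_req = SimpleNamespace(context_current_position=0, context_chunk_size=0)
align_draft_chunk(draft_req, (64, 128))
print(draft_req.context_current_position, draft_req.context_chunk_size)  # 64 64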
@@ -110,10 +117,13 @@ def _create_generation_request(self, request: LlmRequest,
         new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
         return new_request

-    def _create_chunked_context_request(self, request: LlmRequest,
+    def _create_accepted_tokens_request(self, request: LlmRequest,
                                         input_tokens: Any,
                                         num_accepted_tokens: int) -> LlmRequest:
-        """Create a chunked context request when some tokens were accepted."""
+        """
+        Create a chunked context request for accepted tokens.
+        Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
+        """
         new_request = self._create_draft_request(request.py_request_id,
                                                  request.py_max_new_tokens,
                                                  input_tokens,

@@ -146,7 +156,7 @@ def _create_draft_request_for_request(

         # Tokens accepted - chunked context request
         else:
-            return self._create_chunked_context_request(request, input_tokens,
+            return self._create_accepted_tokens_request(request, input_tokens,
                                                         num_accepted_tokens)

     def _add_to_draft_batch(self, draft_batch: ScheduledRequests,
@@ -184,6 +194,22 @@ def _prepare_draft_batch(
         try:
             draft_batch = ScheduledRequests()

+            for request in scheduled_requests.context_requests:
+                if request.is_first_context_chunk:
+                    # Ignore requests which still need to be processed by the target model.
+                    continue
+
+                # We hit this path if we're doing chunked prefill. The target model processed
+                # a prefill chunk on the last iteration. Now, we need to fill in the KV cache
+                # for the draft model too.
+                all_tokens = request.get_tokens()[0]
+                input_tokens = get_draft_model_prompt(
+                    self.spec_config.spec_dec_mode, all_tokens)
+
+                new_request = self._create_context_request(
+                    request, input_tokens)
+                self._add_to_draft_batch(draft_batch, new_request, request)
+
             for request in scheduled_requests.generation_requests:
                 if request.py_draft_pages_allocated == 0:
                     # No space for draft tokens
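An illustrative timeline of this catch-up loop, assuming one target prefill chunk per iteration, a 200-token prompt, and 64-token chunks (assumptions for illustration, not part of the change): the draft model catches up on the chunk the target processed on the previous iteration, and a request still on its first chunk is skipped.

prompt_len, chunk = 200, 64
windows = [(i, min(i + chunk, prompt_len)) for i in range(0, prompt_len, chunk)]
for it, target_window in enumerate(windows):
    # None means the request is still on its first chunk, so the drafter skips it.
    draft_catch_up = windows[it - 1] if it > 0 else None
    print(f"iteration {it}: target prefills {target_window}, "
          f"draft catches up on {draft_catch_up}")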
@@ -273,6 +299,12 @@ def _process_decoded_tokens(
         new_requests = []
         for req in draft_batch.all_requests():
             target_model_req = req_id_to_old_request[req.py_request_id]
+            if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # This is a chunked prefill request and we have more prefill chunks
+                # to process. Defer adding draft tokens until the whole prompt is processed.
+                self.draft_seq_slot_manager.free_resources(req)
+                continue
+
             target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
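A minimal sketch of the deferral above, with a stand-in state enum rather than the real LlmRequestState: while the target request is still mid-prefill, the drafter frees the draft sequence slot and drops the decoded token instead of attaching it as a draft token.

from enum import Enum, auto

class State(Enum):  # stand-in for LlmRequestState
    CONTEXT_IN_PROGRESS = auto()
    GENERATION_IN_PROGRESS = auto()

def maybe_attach_draft_token(target_state, draft_tokens, new_token):
    if target_state is not State.GENERATION_IN_PROGRESS:
        # More prefill chunks remain; defer until the whole prompt is processed.
        return draft_tokens
    return draft_tokens + [new_token]

print(maybe_attach_draft_token(State.CONTEXT_IN_PROGRESS, [], 42))      # []
print(maybe_attach_draft_token(State.GENERATION_IN_PROGRESS, [], 42))   # [42]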

tests/unittest/_torch/speculative/test_eagle3.py
Lines changed: 27 additions & 15 deletions

@@ -14,21 +14,21 @@


 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
     [
-        [True, "TRTLLM", True, False, False],
-        [False, "TRTLLM", True, False, False],
-        [True, "TRTLLM", True, True, False],
-        [False, "TRTLLM", True, True, False],
-        [True, "FLASHINFER", True, False, False],
-        [False, "FLASHINFER", True, False, False],
-        [False, "TRTLLM", False, True, True],
-        [True, "TRTLLM", False, True, True],
+        [True, "TRTLLM", True, False, False, False],
+        [False, "TRTLLM", True, False, False, False],
+        [True, "FLASHINFER", True, False, False, False],
+        [False, "FLASHINFER", True, False, False, False],
+        [False, "TRTLLM", False, True, True, False],
+        [True, "TRTLLM", False, True, True, False],
+        [True, "TRTLLM", True, False, True, True],
+        [True, "TRTLLM", True, False, False, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
-                      use_one_model: bool):
+                      use_one_model: bool, enable_chunked_prefill: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:

@@ -59,7 +59,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         # that the draft model won't go above its max in warmup
         # in this test.
         max_seq_len=8192,
+        enable_chunked_prefill=enable_chunked_prefill,
     )
+    if enable_chunked_prefill:
+        # Use a small max_num_tokens so that the chunked prefill path gets exercised.
+        llm_common_config['max_num_tokens'] = 64

     spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,

@@ -71,7 +75,19 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

     # Acceptance rate tests
-    tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+    if enable_chunked_prefill:
+        # Use a long prompt for chunked prefill tests.
+        prompts = [
+            "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
+        ]
+        tok_ids = llm_spec.tokenizer.encode(prompts[0])
+    else:
+        prompts = [
+            "The capital of France is",
+            "The president of the United States is",
+        ]
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
     num_tokens = 0
     num_drafted = 0
     num_accepted = 0

@@ -88,10 +104,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     assert accept_rate > 0.15

     # Output tests
-    prompts = [
-        "The capital of France is",
-        "The president of the United States is",
-    ]
     sampling_params = SamplingParams(max_tokens=10, temperature=0)

     results_spec = llm_spec.generate(prompts, sampling_params)
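For context on the test parameters, a rough back-of-the-envelope check (the token count of the long prompt is an assumption; the real number depends on the tokenizer): with max_num_tokens=64, a prompt of a few hundred tokens is prefilled in several chunks, so the new draft-model catch-up path is exercised.

import math

approx_prompt_tokens = 250   # assumed length of the long prompt above
max_num_tokens = 64          # per-iteration token budget set in the test
num_prefill_chunks = math.ceil(approx_prompt_tokens / max_num_tokens)
print(num_prefill_chunks)    # 4 -> multiple chunks, so chunked prefill is hit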
