
Commit 518915b

[nvbug/5337601][fix] Fix disagg + speculative decoding (#5558)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Co-authored-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 5ac92bb commit 518915b

4 files changed: +103 -46 lines changed


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 17 additions & 10 deletions
@@ -869,8 +869,9 @@ def _executor_loop(self):
 
                 self._pad_attention_dp_dummy_request()
 
-                if self.draft_model_engine is not None or is_ngram:
-                    self._prepare_draft_requests()
+                if self.draft_model_engine is not None or is_ngram or hasattr(
+                        self, 'drafter') and self.drafter is not None:
+                    self._prepare_draft_requests(self.active_requests)
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )
@@ -966,13 +967,13 @@ def _executor_loop(self):
                         iter_stats=iter_stats,
                         iter_start_time=iter_start_time))
 
-    def _prepare_draft_requests(self):
+    def _prepare_draft_requests(self, requests):
         try:
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
-            for req in self.active_requests:
-                # TODO: enable draft tokens in context phase
-                if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+            for req in requests:
+                if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
+                                     LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
                 req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_tokens
@@ -1528,9 +1529,16 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
         disagg_gen_init_to_prepare.generation_requests = []
         disagg_gen_init_to_prepare.paused_requests = []
 
-        self.resource_manager.resource_managers[
-            ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
-                disagg_gen_init_to_prepare)
+        for resource_mgr_type in (
+                ResourceManagerType.KV_CACHE_MANAGER,
+                ResourceManagerType.SEQ_SLOT_MANAGER,
+                ResourceManagerType.SPEC_RESOURCE_MANAGER,
+                ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
+            if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
+                    resource_mgr_type] is not None:
+                self.resource_manager.resource_managers[
+                    resource_mgr_type].prepare_resources(
+                        disagg_gen_init_to_prepare)
 
         # Trigger KV cache exchange for new disagg_gen_init_requests
         self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
@@ -1790,7 +1798,6 @@ def _prepare_draft_batch(
                 # This is the first time the draft model is seeing this request.
                 # Prepare a context request. We discard the first token and take
                 # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                assert num_draft_tokens == 0
                 new_request = LlmRequest(
                     request_id=request.py_request_id,
                     max_new_tokens=request.py_max_new_tokens,
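Note: the new loop in _prepare_disagg_gen_init only touches resource managers that are actually registered and non-None, so configurations without speculative decoding (no SPEC_RESOURCE_MANAGER or DRAFT_KV_CACHE_MANAGER) keep the old behavior. A minimal, self-contained sketch of that guarded fan-out, assuming a plain dict stands in for resource_manager.resource_managers and a stub manager stands in for the real ones (only the enum member names and prepare_resources() come from the diff; everything else is illustrative):

from enum import Enum, auto


class ResourceManagerType(Enum):
    # Only the members referenced by the new loop; the real enum lives in
    # the pyexecutor resource-manager module.
    KV_CACHE_MANAGER = auto()
    SEQ_SLOT_MANAGER = auto()
    SPEC_RESOURCE_MANAGER = auto()
    DRAFT_KV_CACHE_MANAGER = auto()


def prepare_disagg_gen_init_resources(resource_managers, batch):
    # Fan out to every manager that exists and is not None, instead of
    # preparing the KV cache manager alone as the old code did.
    for mgr_type in (ResourceManagerType.KV_CACHE_MANAGER,
                     ResourceManagerType.SEQ_SLOT_MANAGER,
                     ResourceManagerType.SPEC_RESOURCE_MANAGER,
                     ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
        manager = resource_managers.get(mgr_type)
        if manager is not None:
            manager.prepare_resources(batch)


class _StubManager:

    def prepare_resources(self, batch):
        print(f"prepared resources for {batch}")


# Draft KV cache manager registered but disabled: it is skipped.
prepare_disagg_gen_init_resources(
    {
        ResourceManagerType.KV_CACHE_MANAGER: _StubManager(),
        ResourceManagerType.DRAFT_KV_CACHE_MANAGER: None,
    },
    batch="disagg_gen_init_batch")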

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 23 additions & 23 deletions
@@ -307,30 +307,30 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
-                if self._handle_stop_criteria(request, new_token, num_tokens,
-                                              beam_idx):
-                    continue
-
-                # Accept draft tokens (if we have any) if and only if they match the new
-                # token exactly.
-                num_accepted = 0
-                new_tokens = [new_token]
-                for draft_token in request.py_draft_tokens:
-                    if draft_token != new_token:
-                        # Reject.
-                        break
-                    num_accepted += 1
-                    new_token = new_tokens_list[token_idx + num_accepted]
-                    num_tokens = request.add_new_token(new_token, beam_idx)
-                    new_tokens.append(num_tokens)  # `num_tokens`->`new_token`
-
-                    if self._handle_stop_criteria(request, new_token,
+                if not self._handle_stop_criteria(request, new_token,
                                                   num_tokens, beam_idx):
-                        break
-                handle_logits(request, new_tokens, num_accepted)
-                request.py_decoding_iter += 1
-                request.py_num_accepted_draft_tokens = num_accepted
-                request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
+
+                    # Accept draft tokens (if we have any) if and only if they match the new
+                    # token exactly.
+                    num_accepted = 0
+                    new_tokens = [new_token]
+                    for draft_token in request.py_draft_tokens:
+                        if draft_token != new_token:
+                            # Reject.
+                            break
+                        num_accepted += 1
+                        new_token = new_tokens_list[token_idx + num_accepted]
+                        num_tokens = request.add_new_token(new_token, beam_idx)
+                        new_tokens.append(
+                            num_tokens)  # `num_tokens`->`new_token`
+
+                        if self._handle_stop_criteria(request, new_token,
+                                                      num_tokens, beam_idx):
+                            break
+                    handle_logits(request, new_tokens, num_accepted)
+                    request.py_decoding_iter += 1
+                    request.py_num_accepted_draft_tokens = num_accepted
+                    request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
             advance_idx(len(request.py_draft_tokens) + 1)
 
         for request in generation_requests:
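The restructuring above (replacing the early continue with an "if not stop" block) does not change the acceptance rule itself: as the in-code comment says, draft tokens are accepted if and only if they match the token the target model actually samples, and the first mismatch rejects the rest. A standalone sketch of that rule over plain Python lists (the helper name and flat-list interface are illustrative, not part of sampler.py):

def accept_draft_tokens(draft_tokens, sampled_tokens):
    """Greedy acceptance: sampled_tokens[i] is the token the target model
    samples after the first i draft tokens have been accepted. Returns
    (num_accepted, tokens_to_append)."""
    new_tokens = [sampled_tokens[0]]
    num_accepted = 0
    for draft_token in draft_tokens:
        if draft_token != new_tokens[-1]:
            break  # first mismatch rejects this and all later draft tokens
        num_accepted += 1
        new_tokens.append(sampled_tokens[num_accepted])
    return num_accepted, new_tokens


# Two of three draft tokens match what the target model sampled.
assert accept_draft_tokens([5, 7, 9], [5, 7, 8, 4]) == (2, [5, 7, 8])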

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 62 additions & 13 deletions
@@ -92,6 +92,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
 
@@ -104,18 +105,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+    with (MyThreadPoolExecutor(max_workers=16) as
+          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
+          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
@@ -209,9 +214,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ l0_dgx_h100:
   - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
   - test_e2e.py::test_ptp_quickstart_advanced_bs1
 - condition:
     ranges:
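The new entry is a pytest node id, so outside the test-db machinery the same case can be selected directly. A rough sketch, assuming it is run from the repository root and that the path prefix matches the test file shown above:

import subprocess

# Hypothetical direct invocation of the newly listed case; the test-db entry
# above is what the l0_dgx_h100 pipeline actually consumes.
subprocess.run([
    "pytest",
    "tests/integration/defs/accuracy/test_disaggregated_serving.py"
    "::TestLlama3_1_8BInstruct::test_eagle3[False]",
], check=True)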
