
Commit 81a4b33

Tabrizian and mikeiovine authored and committed
[nvbug/5337601][fix] Fix disagg + speculative decoding (NVIDIA#5558)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Co-authored-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent af4eb9d commit 81a4b33

File tree

3 files changed (+68, -19 lines)


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 6 deletions
@@ -879,8 +879,9 @@ def _executor_loop(self):
 
                 self._pad_attention_dp_dummy_request()
 
-                if self.draft_model_engine is not None or self.drafter is not None:
-                    self._prepare_draft_requests()
+                if self.draft_model_engine is not None or hasattr(
+                        self, 'drafter') and self.drafter is not None:
+                    self._prepare_draft_requests(self.active_requests)
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )
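The rewritten condition leans on Python operator precedence: 'and' binds tighter than 'or', so it evaluates as "draft_model_engine is not None or (hasattr(self, 'drafter') and self.drafter is not None)", and the hasattr check keeps executor instances that never set a drafter attribute from raising AttributeError. A minimal sketch of the same guard pattern, built around a hypothetical Executor stub (only the condition itself mirrors the diff):

class Executor:
    # Hypothetical stub; only the guard below mirrors the change above.
    def __init__(self, draft_model_engine=None, drafter=None, set_drafter_attr=True):
        self.draft_model_engine = draft_model_engine
        if set_drafter_attr:
            self.drafter = drafter  # some construction paths may never set this attribute

    def should_prepare_drafts(self):
        # 'and' binds tighter than 'or', so this reads as:
        # draft_model_engine is not None or (hasattr(self, 'drafter') and self.drafter is not None)
        return self.draft_model_engine is not None or hasattr(
            self, 'drafter') and self.drafter is not None

assert Executor(draft_model_engine=object()).should_prepare_drafts()
assert Executor(drafter=object()).should_prepare_drafts()
assert not Executor(set_drafter_attr=False).should_prepare_drafts()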
@@ -969,12 +970,11 @@ def _executor_loop(self):
                         iter_stats=iter_stats,
                         iter_start_time=iter_start_time))
 
-    def _prepare_draft_requests(self):
+    def _prepare_draft_requests(self, requests):
         try:
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
-            for req in self.active_requests:
-                # TODO: enable draft tokens in context phase
+            for req in requests:
                 if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                      LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
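With the request list passed in explicitly, _prepare_draft_requests can now operate on any subset of requests rather than always reading self.active_requests; the state filter still skips everything that is neither generating nor waiting on disaggregated generation init. A simplified, self-contained sketch of that filter, using stand-in types for LlmRequestState and the request object (the real method also attaches draft tokens, which is omitted here):

from dataclasses import dataclass
from enum import Enum, auto

class LlmRequestState(Enum):
    # Stand-in for the real TensorRT-LLM enum; only the states used below.
    CONTEXT_INIT = auto()
    GENERATION_IN_PROGRESS = auto()
    DISAGG_GENERATION_INIT = auto()

@dataclass
class Request:
    state: LlmRequestState

def prepare_draft_requests(requests):
    prepared = []
    for req in requests:
        # Same filter as the diff: only requests that are actively generating
        # or waiting on disaggregated generation init get draft tokens.
        if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                             LlmRequestState.DISAGG_GENERATION_INIT):
            continue
        prepared.append(req)
    return prepared

reqs = [Request(LlmRequestState.CONTEXT_INIT),
        Request(LlmRequestState.DISAGG_GENERATION_INIT)]
assert len(prepare_draft_requests(reqs)) == 1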
@@ -1786,7 +1786,6 @@ def create_new_request(input_tokens):
                 # This is the first time the draft model is seeing this request.
                 # Prepare a context request. We discard the first token and take
                 # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                assert num_draft_tokens == 0
                 new_request = create_new_request(input_tokens)
                 draft_batch.context_requests.append(new_request)
             elif num_accepted_tokens == 0:
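The deleted assertion guarded the EAGLE convention described in the surrounding comment: the draft model's first context request drops the first prompt token and takes the token the target model just decoded. A toy illustration of that token shift (purely illustrative; the real create_new_request builds a full LlmRequest):

def eagle_context_tokens(prompt_tokens, newly_decoded_token):
    # EAGLE 2/3 convention: discard the first token, append the newly decoded one.
    return prompt_tokens[1:] + [newly_decoded_token]

assert eagle_context_tokens([101, 7, 8, 9], 42) == [7, 8, 9, 42]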

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 62 additions & 13 deletions
@@ -91,6 +91,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
 
@@ -103,18 +104,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+    with (MyThreadPoolExecutor(max_workers=16) as
+          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
+          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
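Concretely, for a context server whose config sets max_num_tokens, the refactored helper assembles its launch command roughly as follows; the model path and temp-file path below are made up for illustration, while the flags come straight from the diff:

trtllm_serve_path = "trtllm-serve"                      # assumed executable name
model_name = "/models/Llama-3.1-8B-Instruct"            # hypothetical model path
ctx_server_config = {"max_num_tokens": 13393 * 2}       # mirrors test_eagle3 below
ctx_server_config_path = "/tmp/ctx_server_config.yaml"  # hypothetical temp file

common_args = [
    trtllm_serve_path, model_name, "--host", "localhost", "--backend",
    "pytorch"
]
ctx_server_args = common_args + [
    "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
]
if "max_num_tokens" in ctx_server_config:
    ctx_server_args.append(
        f"--max_num_tokens={ctx_server_config['max_num_tokens']}")

# ctx_server_args now ends with:
# [..., "--port", "8001", "--extra_llm_api_options",
#  "/tmp/ctx_server_config.yaml", "--max_num_tokens=26786"]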
@@ -252,9 +257,53 @@ def test_ngram(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
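test_eagle3 hands launch_disaggregated_llm plain dicts for the context server, generation server, and disaggregated router; the helper presumably serializes them to the YAML files passed via --extra_llm_api_options and trtllm-serve disaggregated -c. A rough guess at that serialization step, assuming PyYAML and a temporary directory (the helper's actual internals are not part of this diff):

import tempfile
from pathlib import Path

import yaml  # PyYAML, assumed available in the test environment

def write_config(config: dict, directory: str, name: str) -> str:
    # Dump a config dict to <directory>/<name>.yaml and return the path.
    path = Path(directory) / f"{name}.yaml"
    path.write_text(yaml.safe_dump(config))
    return str(path)

with tempfile.TemporaryDirectory() as tmp:
    ctx_server_config_path = write_config({"disable_overlap_scheduler": True}, tmp,
                                          "ctx_server_config")
    gen_server_config_path = write_config({"disable_overlap_scheduler": False}, tmp,
                                          "gen_server_config")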

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
   - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
   - test_e2e.py::test_ptp_quickstart_advanced_bs1
 - condition:
     ranges:
