@@ -91,6 +91,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
 
@@ -103,18 +104,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+    with (MyThreadPoolExecutor(max_workers=16) as
+          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
+          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
@@ -252,9 +257,53 @@ def test_ngram(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"