@@ -92,6 +92,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
 
@@ -104,18 +105,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
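+    # Build separate CLI argument lists for the context and generation servers.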
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
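+    # Forward optional per-server max_num_tokens settings to trtllm-serve.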
114+ if "max_num_tokens" in ctx_server_config :
115+ ctx_server_args .append (
116+ f"--max_num_tokens={ ctx_server_config ['max_num_tokens' ]} " )
117+ if "max_num_tokens" in gen_server_config :
118+ gen_server_args .append (
119+ f"--max_num_tokens={ gen_server_config ['max_num_tokens' ]} " )
120+
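+    # Start the context server, generation server, and disaggregated server
+    # under one context manager so they are shut down together on exit.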
+    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
+          popen(ctx_server_args, env=env_ctx) as ctx_server,
+          popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
@@ -209,9 +214,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
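+        """Accuracy test for EAGLE3 speculative decoding with disaggregated
+        context and generation servers; overlap_scheduler toggles the
+        generation server's overlap scheduler."""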
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"