@@ -169,17 +169,27 @@ def setup_llm(args, **kwargs):
169169 ) if args .spec_decode_algo is not None else None
170170
171171 if spec_decode_algo == 'MTP' :
172+
172173 if not args .use_one_model :
173- print (
174- "MTP only supports one model style spec decode; ignoring default use_one_model=False"
174+ print ("Running MTP eagle with two model style." )
175+ spec_config = EagleDecodingConfig (
176+ max_draft_len = args .spec_decode_max_draft_len ,
177+ speculative_model_dir = args .model_dir ,
178+ eagle3_one_model = args .use_one_model ,
179+ is_mtp_eagle = True ,
180+ use_relaxed_acceptance_for_thinking = args .
181+ use_relaxed_acceptance_for_thinking ,
182+ relaxed_topk = args .relaxed_topk ,
183+ relaxed_delta = args .relaxed_delta ,
175184 )
176-
177- spec_config = MTPDecodingConfig (
178- num_nextn_predict_layers = args .spec_decode_max_draft_len ,
179- use_relaxed_acceptance_for_thinking = args .
180- use_relaxed_acceptance_for_thinking ,
181- relaxed_topk = args .relaxed_topk ,
182- relaxed_delta = args .relaxed_delta )
185+ else :
186+ spec_config = MTPDecodingConfig (
187+ num_nextn_predict_layers = args .spec_decode_max_draft_len ,
188+ use_relaxed_acceptance_for_thinking = args .
189+ use_relaxed_acceptance_for_thinking ,
190+ relaxed_topk = args .relaxed_topk ,
191+ relaxed_delta = args .relaxed_delta ,
192+ mtp_eagle_one_model = args .use_one_model )
183193 elif spec_decode_algo == "EAGLE3" :
184194 spec_config = EagleDecodingConfig (
185195 max_draft_len = args .spec_decode_max_draft_len ,
0 commit comments