Skip to content

Commit 9ae0936

Browse files
committed
support mtp eagle with 2 models style
Signed-off-by: qgai <qgai@nvidia.com>
1 parent 1e72721 commit 9ae0936

File tree

11 files changed

+1075
-854
lines changed

11 files changed

+1075
-854
lines changed

examples/llm-api/quickstart_advanced.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,27 @@ def setup_llm(args, **kwargs):
169169
) if args.spec_decode_algo is not None else None
170170

171171
if spec_decode_algo == 'MTP':
172+
172173
if not args.use_one_model:
173-
print(
174-
"MTP only supports one model style spec decode; ignoring default use_one_model=False"
174+
print("Running MTP eagle with two model style.")
175+
spec_config = EagleDecodingConfig(
176+
max_draft_len=args.spec_decode_max_draft_len,
177+
speculative_model_dir=args.model_dir,
178+
eagle3_one_model=args.use_one_model,
179+
is_mtp_eagle=True,
180+
use_relaxed_acceptance_for_thinking=args.
181+
use_relaxed_acceptance_for_thinking,
182+
relaxed_topk=args.relaxed_topk,
183+
relaxed_delta=args.relaxed_delta,
175184
)
176-
177-
spec_config = MTPDecodingConfig(
178-
num_nextn_predict_layers=args.spec_decode_max_draft_len,
179-
use_relaxed_acceptance_for_thinking=args.
180-
use_relaxed_acceptance_for_thinking,
181-
relaxed_topk=args.relaxed_topk,
182-
relaxed_delta=args.relaxed_delta)
185+
else:
186+
spec_config = MTPDecodingConfig(
187+
num_nextn_predict_layers=args.spec_decode_max_draft_len,
188+
use_relaxed_acceptance_for_thinking=args.
189+
use_relaxed_acceptance_for_thinking,
190+
relaxed_topk=args.relaxed_topk,
191+
relaxed_delta=args.relaxed_delta,
192+
mtp_eagle_one_model=args.use_one_model)
183193
elif spec_decode_algo == "EAGLE3":
184194
spec_config = EagleDecodingConfig(
185195
max_draft_len=args.spec_decode_max_draft_len,

tensorrt_llm/_torch/models/modeling_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def from_config(
2020
"") # Strip the appended EAGLE3
2121
if hasattr(config.pretrained_config, "draft_vocab_size"):
2222
model_arch = "EAGLE3" + model_arch
23+
if model_arch == "DeepseekV3ForCausalLM" and config.spec_config.max_draft_len == 0:
24+
model_arch = "MTPDraftModelForCausalLM"
2325

2426
cls = MODEL_CLASS_MAPPING.get(model_arch)
2527
if cls is None:

0 commit comments

Comments (0)