rename GuidedWorker to CapturableGuidedDecoder
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
syuoni committed Sep 3, 2025
commit 88b3fd3f6109527fc75e6161c24a07ca8c4020f5
9 changes: 5 additions & 4 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -16,7 +16,7 @@
from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
-from ..pyexecutor.guided_decoder import GuidedWorker
+from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
@@ -466,7 +466,8 @@ def load_draft_weights(self,
weight_mapper=weight_mapper)
self.draft_model.load_weights_from_target_model(self)

-def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
-if hasattr(self.spec_worker, "set_guided_worker"):
-return self.spec_worker.set_guided_worker(guided_worker)
+def set_guided_decoder(self,
+guided_decoder: CapturableGuidedDecoder) -> bool:
+if hasattr(self.spec_worker, "set_guided_decoder"):
+return self.spec_worker.set_guided_decoder(guided_decoder)
return False
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -392,7 +392,7 @@ def init_disagg_gen_requests(self) -> None:
self._init_disagg_gen_requests(self.requests)


-class GuidedWorker(GuidedDecoder):
+class CapturableGuidedDecoder(GuidedDecoder):

def __init__(self,
guided_decoding_config: GuidedDecodingConfig,
@@ -401,6 +401,8 @@ def __init__(self,
max_num_draft_tokens: int = 0):
super().__init__(guided_decoding_config, max_num_sequences,
vocab_size_padded, max_num_draft_tokens)
+# self.requests should be accessed by normal host code;
+# self.requests_hostfunc should be accessed by hostfunc (CUDA callback).
self.requests_hostfunc: Optional[GuidedRequests] = None
self.queue = Queue()
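
For readers unfamiliar with the hostfunc pattern: below is a minimal sketch of how the two references and the queue could interact, assuming the queue is the hand-off point between normal host code and CUDA host callbacks. The class and method names are illustrative; this is not the actual implementation.

```python
from queue import Queue
from typing import Any, Optional


class _HostFuncStateSketch:
    """Illustrative only: keeps host-side and callback-side state separate."""

    def __init__(self) -> None:
        self.requests: Optional[Any] = None           # normal host code only
        self.requests_hostfunc: Optional[Any] = None  # CUDA host callbacks only
        self.queue: Queue = Queue()

    def add_batch(self, requests: Any) -> None:
        # Host side: remember the batch and publish it for the callback side.
        self.requests = requests
        self.queue.put(requests)

    def _hostfunc_body(self) -> None:
        # Callback side: take ownership of the batch queued for this step;
        # self.requests may already point at the next batch by now.
        self.requests_hostfunc = self.queue.get()
```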

23 changes: 12 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -58,7 +58,7 @@
from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
-from .guided_decoder import GuidedWorker
+from .guided_decoder import CapturableGuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
from .resource_manager import (BaseResourceManager, KVCacheManager,
@@ -413,7 +413,7 @@ def __init__(
self.without_logits = False
self.max_draft_len = 0

-self.guided_worker: Optional[GuidedWorker] = None
+self.guided_decoder: Optional[CapturableGuidedDecoder] = None

# This field is initialized lazily on the first forward pass.
# This is convenient because:
@@ -484,11 +484,12 @@ def set_lora_model_config(self,
dtype=torch_dtype_to_str(self.model.config.torch_dtype),
swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)

-def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
-if hasattr(self.model, "set_guided_worker"):
-success = self.model.set_guided_worker(guided_worker)
+def set_guided_decoder(self,
+guided_decoder: CapturableGuidedDecoder) -> bool:
+if hasattr(self.model, "set_guided_decoder"):
+success = self.model.set_guided_decoder(guided_decoder)
if success:
-self.guided_worker = guided_worker
+self.guided_decoder = guided_decoder
return success
return False
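
Taken together with the model-level setter above, the renamed API forms a three-level opt-in chain. A simplified sketch follows; the class names are shortened stand-ins and the bodies are reduced to the delegation logic visible in this diff.

```python
from typing import Any


class _SpecWorkerSketch:
    """Stands in for a one-model worker (e.g. MTPWorker, Eagle3 one-model)."""

    def set_guided_decoder(self, guided_decoder: Any) -> bool:
        self.guided_decoder = guided_decoder
        return True


class _SpecModelSketch:
    """Stands in for the speculative model that wraps a spec_worker."""

    def __init__(self, spec_worker: Any) -> None:
        self.spec_worker = spec_worker

    def set_guided_decoder(self, guided_decoder: Any) -> bool:
        # Only workers that implement the hook accept the decoder.
        if hasattr(self.spec_worker, "set_guided_decoder"):
            return self.spec_worker.set_guided_decoder(guided_decoder)
        return False


class _ModelEngineSketch:
    """Stands in for PyTorchModelEngine."""

    def __init__(self, model: Any) -> None:
        self.model = model
        self.guided_decoder = None

    def set_guided_decoder(self, guided_decoder: Any) -> bool:
        if hasattr(self.model, "set_guided_decoder"):
            if self.model.set_guided_decoder(guided_decoder):
                # Keep a reference only when the model actually took it.
                self.guided_decoder = guided_decoder
                return True
        return False
```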

@@ -1181,8 +1182,8 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
num_ctx_requests:num_seqs] += (
self.previous_kv_lens_offsets_cuda[:num_gen_requests])

-if self.guided_worker is not None:
-self.guided_worker.token_event.record()
+if self.guided_decoder is not None:
+self.guided_decoder.token_event.record()

return inputs
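
The token_event.record() call relies on a standard CUDA-event handshake. The following self-contained illustration uses torch.cuda.Event; the attribute name token_event comes from the decoder, but how it is consumed downstream is an assumption here, not part of this diff.

```python
import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.current_stream()
    token_event = torch.cuda.Event()

    # Producer side (analogous to _preprocess_inputs): mark the point on the
    # compute stream after which the new tokens are visible on device.
    token_event.record(compute_stream)

    # Consumer side (analogous to the guided decoder): either block the host...
    token_event.synchronize()
    # ...or make another stream wait without blocking the host.
    side_stream = torch.cuda.Stream()
    side_stream.wait_event(token_event)
```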

@@ -1271,9 +1272,9 @@ def _prepare_tp_inputs(
next_draft_tokens_device = new_tensors_device.next_draft_tokens # [batch, draft_len]

# Must be before the update of py_batch_idx
-if self.guided_worker is not None:
-self.guided_worker.add_batch(scheduled_requests,
-new_tokens=new_tokens_device)
+if self.guided_decoder is not None:
+self.guided_decoder.add_batch(scheduled_requests,
+new_tokens=new_tokens_device)

# if new_tensors_device exist, input_ids will only contain new context tokens
input_ids = [] # per sequence
14 changes: 9 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -33,7 +33,7 @@
create_py_executor_instance, instantiate_sampler, is_mla)
from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
-from .guided_decoder import GuidedDecoder, GuidedWorker
+from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
@@ -402,15 +402,19 @@ def drafting_loop_wrapper(model):
}
if spec_config is not None:
kwargs["max_num_draft_tokens"] = spec_config.max_draft_len

if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
):
+# GuidedDecoder is applicable to non-speculative decoding and two-model speculative decoding.
guided_decoder = GuidedDecoder(**kwargs)
-elif spec_config.spec_dec_mode.support_guided_worker():
-success = model_engine.set_guided_worker(
-GuidedWorker(**kwargs))
+elif spec_config.spec_dec_mode.support_capturable_guided_decoder(
+):
+# CapturableGuidedDecoder is applicable to one-model speculative decoding.
+success = model_engine.set_guided_decoder(
+CapturableGuidedDecoder(**kwargs))
if not success:
raise ValueError(
f"Failed to set guided worker for speculative decoding mode: {spec_config.spec_dec_mode.name}."
f"Failed to set guided decoder for speculative decoding mode: {spec_config.spec_dec_mode.name}."
)
else:
raise ValueError(
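
For orientation, the branching above can be read as a small factory. The restatement below is illustrative only: the helper name is invented, the final error message is generic (the original else-branch text is not shown in this hunk), and GuidedDecoder / CapturableGuidedDecoder are the classes imported above in this module.

```python
def _pick_guided_decoder(spec_config, model_engine, **kwargs):
    """Illustrative restatement of the guided-decoder selection logic."""
    if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder():
        # No speculation, or two-model speculation with a separate drafter:
        # the executor drives a plain GuidedDecoder.
        return GuidedDecoder(**kwargs)
    if spec_config.spec_dec_mode.support_capturable_guided_decoder():
        # One-model speculation (MTP, Eagle3 one-model): the decoder must be
        # capturable, so it is handed to the model engine's spec worker.
        if not model_engine.set_guided_decoder(CapturableGuidedDecoder(**kwargs)):
            raise ValueError(
                "Failed to set guided decoder for speculative decoding mode: "
                f"{spec_config.spec_dec_mode.name}.")
        return None
    raise ValueError(
        f"Guided decoding is not supported for mode: {spec_config.spec_dec_mode.name}.")
```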
25 changes: 13 additions & 12 deletions tensorrt_llm/_torch/speculative/eagle3.py
@@ -7,7 +7,7 @@
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import GuidedWorker
+from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import TorchSampler
@@ -276,7 +276,7 @@ def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
self.spec_config = spec_config
self.max_draft_len = self.spec_config.max_draft_len
self.mapping = mapping
-self.guided_worker: Optional[GuidedWorker] = None
+self.guided_decoder: Optional[CapturableGuidedDecoder] = None

# Skip torch.compile for now since current Torch is not compatible with Triton 3.4
# @torch.compile(options={"max-autotune": True})
@@ -288,8 +288,8 @@ def forward(self, input_ids, position_ids, hidden_states, logits,

raw_logits = logits

-if self.guided_worker is not None:
-self.guided_worker.execute(logits)
+if self.guided_decoder is not None:
+self.guided_decoder.execute(logits)

# Sample and accept tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
@@ -329,11 +329,11 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

-if self.guided_worker is not None:
+if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
-self.guided_worker.add_draft_batch(new_tokens,
-num_accepted_tokens,
-is_first_step=(i == 0))
+self.guided_decoder.add_draft_batch(new_tokens,
+num_accepted_tokens,
+is_first_step=(i == 0))

hidden_states, hidden_states_to_save = draft_model.model(**inputs)

@@ -347,9 +347,9 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
-if self.guided_worker is not None:
+if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
-self.guided_worker.execute_draft_batch(
+self.guided_decoder.execute_draft_batch(
logits,
d2t,
is_first_step=(i == 0),
@@ -524,6 +524,7 @@ def prepare_1st_drafter_inputs(
"spec_metadata": spec_metadata,
}

-def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
-self.guided_worker = guided_worker
+def set_guided_decoder(self,
+guided_decoder: CapturableGuidedDecoder) -> bool:
+self.guided_decoder = guided_decoder
return True
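
Both one-model workers touched by this commit (Eagle3 above, MTP below) drive the capturable decoder with the same call protocol inside their forward. Below is a condensed, hedged outline; the draft_steps iterable and its step attributes are stand-ins for the real drafter forward passes, and the loop structure is simplified.

```python
def _guided_forward_outline(guided_decoder, target_logits, draft_steps):
    """Illustrative only: the guided_decoder call names match this diff."""
    if guided_decoder is None:
        return
    # Constrain the target-model logits before draft-token sampling/acceptance.
    guided_decoder.execute(target_logits)
    num_steps = len(draft_steps)
    for i, step in enumerate(draft_steps):
        # Tokens fed into the drafter at this step, plus per-request acceptance
        # counts from the target verification.
        guided_decoder.add_draft_batch(step.new_tokens,
                                       step.num_accepted_tokens,
                                       is_first_step=(i == 0))
        # Drafter logits for this step; Eagle3 additionally passes a
        # draft-to-target vocab mapping (d2t), MTP does not.
        guided_decoder.execute_draft_batch(step.draft_logits,
                                           is_first_step=(i == 0),
                                           is_last_step=(i == num_steps - 1))
```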
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -60,9 +60,9 @@ def support_overlap_scheduler(self):
return self.is_mtp() or self.is_eagle3_one_model()

def support_guided_decoder(self):
-return self.has_spec_drafter()
+return self.is_none() or self.has_spec_drafter()

-def support_guided_worker(self):
+def support_capturable_guided_decoder(self):
return self.is_mtp() or self.is_eagle3_one_model()

def has_draft_model(self):
41 changes: 21 additions & 20 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -7,7 +7,7 @@
from ..attention_backend import AttentionMetadata
from ..distributed.ops import allgather
from ..model_config import ModelConfig
-from ..pyexecutor.guided_decoder import GuidedWorker
+from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
@@ -327,7 +327,7 @@ def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False
-self.guided_worker: Optional[GuidedWorker] = None
+self.guided_decoder: Optional[CapturableGuidedDecoder] = None

def forward(
self,
@@ -443,8 +443,8 @@ def forward(

raw_logits = logits

-if self.guided_worker is not None:
-self.guided_worker.execute(logits)
+if self.guided_decoder is not None:
+self.guided_decoder.execute(logits)

# Sample and verify draft tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
@@ -478,18 +478,18 @@ def forward(
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
for i, mtp_layer in enumerate(draft_model.mtp_layers):
-if self.guided_worker is not None:
+if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
-self.guided_worker.add_draft_batch(new_tokens,
-num_accepted_tokens,
-is_first_step=(i == 0))
+self.guided_decoder.add_draft_batch(new_tokens,
+num_accepted_tokens,
+is_first_step=(i == 0))

hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
attn_metadata).float()
-if self.guided_worker is not None:
-self.guided_worker.execute_draft_batch(
+if self.guided_decoder is not None:
+self.guided_decoder.execute_draft_batch(
logits,
is_first_step=(i == 0),
is_last_step=(i == len(draft_model.mtp_layers) - 1))
@@ -1125,8 +1125,8 @@ def draft_sampler(

return draft_tokens

-def set_guided_worker(self, guided_worker: GuidedWorker) -> bool:
-self.guided_worker = guided_worker
+def set_guided_decoder(self,
+guided_decoder: CapturableGuidedDecoder) -> bool:
+self.guided_decoder = guided_decoder
return True


@@ -1164,8 +1165,8 @@ def forward(

raw_logits = logits

-if self.guided_worker is not None:
-self.guided_worker.execute(logits)
+if self.guided_decoder is not None:
+self.guided_decoder.execute(logits)

# Sample and verify draft tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
@@ -1227,17 +1228,17 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

-if self.guided_worker is not None:
+if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
-self.guided_worker.add_draft_batch(new_tokens,
-num_accepted_tokens,
-is_first_step=(i == 0))
+self.guided_decoder.add_draft_batch(new_tokens,
+num_accepted_tokens,
+is_first_step=(i == 0))

logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
True)
-if self.guided_worker is not None:
-self.guided_worker.execute_draft_batch(
+if self.guided_decoder is not None:
+self.guided_decoder.execute_draft_batch(
logits,
is_first_step=(i == 0),
is_last_step=(i == self.mtp_num_modules - 1))