Skip to content

Commit 97ce0ec

Browse files
authored
[TRTLLM-8436][feat] batched sampling and top-k logprobs improvements (#8398)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent d05079b commit 97ce0ec

File tree

8 files changed

+1162
-996
lines changed

8 files changed

+1162
-996
lines changed

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ extend_skip_glob = [
3535
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
3636
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
3737
"tests/unittest/_torch/sampler/test_torch_sampler.py",
38+
"tensorrt_llm/_torch/pyexecutor/sampler.py",
39+
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
3840
]
3941

4042
[tool.yapf]
@@ -67,6 +69,8 @@ ignore_patterns = [
6769
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
6870
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
6971
"tests/unittest/_torch/sampler/test_torch_sampler.py",
72+
"tensorrt_llm/_torch/pyexecutor/sampler.py",
73+
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
7074
]
7175

7276
[tool.codespell]
@@ -102,6 +106,8 @@ exclude = [
102106
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
103107
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
104108
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
109+
"tensorrt_llm/_torch/pyexecutor/sampler.py",
110+
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
105111
]
106112

107113

@@ -147,6 +153,8 @@ include = [
147153
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
148154
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
149155
"tests/unittest/_torch/sampler/test_torch_sampler.py",
156+
"tensorrt_llm/_torch/pyexecutor/sampler.py",
157+
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
150158
]
151159
exclude = [
152160
"**3rdparty/**",

tensorrt_llm/_torch/auto_deploy/shim/demollm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ....executor.result import CompletionOutput, GenerationResult
1414
from ....inputs.multimodal import MultimodalParams
1515
from ....sampling_params import SamplingParams
16-
from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
16+
from ...pyexecutor.sampling_utils import greedy_search_sampling_batch, top_k_sampling_batch
1717
from ..distributed import common as dist_ad
1818
from ..utils.logger import ad_logger
1919
from .ad_executor import ADEngine

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,13 +1204,12 @@ def _executor_loop(self):
12041204

12051205
self._kv_connector_terminate_requests()
12061206

1207-
if self.enable_iter_perf_stats:
1207+
if self.enable_iter_perf_stats and sample_state is not None:
12081208
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
12091209
'num_ctx_tokens']
12101210
self._process_iter_stats(
12111211
finished_requests, self.active_requests,
1212-
BatchState(sample_state=SampleState(
1213-
scheduled_requests=scheduled_batch),
1212+
BatchState(sample_state=sample_state,
12141213
iter_stats=iter_stats,
12151214
iter_start_time=iter_start_time))
12161215

0 commit comments

Comments
 (0)