Commit 50d4e5b
[TRTLLM-8483][chore] Refine scheduler_config and peft_cache_config in create_py_executor (#8451)
Signed-off-by: leslie-fang25 <leslief@nvidia.com>
1 parent bac9e8c commit 50d4e5b

7 files changed: +47 −39 lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 4 additions & 3 deletions
@@ -3,14 +3,15 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from strenum import StrEnum
 from torch._prims_common import DeviceLikeType
 
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy
 
 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ContextChunkingPolicy
-from ....bindings.internal.batch_manager import CacheType, ContextChunkingConfig
+from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
 from ...pyexecutor.model_engine import ModelEngine
@@ -376,7 +377,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     if ad_config.enable_chunked_prefill:
         chunk_unit_size = ad_config.attn_page_size
         chunking_policy = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
-        ctx_chunk_config = ContextChunkingConfig(chunking_policy, chunk_unit_size)
+        ctx_chunk_config: Tuple[StrEnum, int] = (chunking_policy, chunk_unit_size)
     else:
         ctx_chunk_config = None
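
A rough caller-side sketch of the pattern this hunk introduces: context chunking is now described as a plain (policy, chunk_unit_size) tuple built from the llmapi ContextChunkingPolicy enum, and the pybind ContextChunkingConfig is constructed later inside BindMicroBatchScheduler (see scheduler.py below). The flag and page-size values are illustrative stand-ins for ad_config fields, not taken from the commit.

    # Sketch only: ad_config fields are replaced by illustrative literals.
    from typing import Optional, Tuple

    from strenum import StrEnum

    from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy

    enable_chunked_prefill = True   # stands in for ad_config.enable_chunked_prefill
    attn_page_size = 64             # stands in for ad_config.attn_page_size

    ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None
    if enable_chunked_prefill:
        # A plain tuple now crosses the Python layers; no pybind object is built here.
        ctx_chunk_config = (ContextChunkingPolicy.FIRST_COME_FIRST_SERVED,
                            attn_page_size)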

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 7 additions & 8 deletions
@@ -13,7 +13,8 @@
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
-                                          SamplerType, SparseAttentionConfig,
+                                          SamplerType, SchedulerConfig,
+                                          SparseAttentionConfig,
                                           SpeculativeConfig, TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
@@ -663,8 +664,8 @@ def create_py_executor_instance(
         max_batch_size: Optional[int] = None,
         max_beam_width: Optional[int] = None,
         max_num_tokens: Optional[int] = None,
-        peft_cache_config: Optional[trtllm.PeftCacheConfig] = None,
-        scheduler_config: Optional[trtllm.SchedulerConfig] = None,
+        peft_cache_config: Optional[PeftCacheConfig] = None,
+        scheduler_config: Optional[SchedulerConfig] = None,
         cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
@@ -728,16 +729,14 @@ def create_py_executor_instance(
         num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \
             len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
 
-        peft_cache_config_model = PeftCacheConfig.from_pybind(
-            peft_cache_config
-        ) if peft_cache_config is not None else PeftCacheConfig()
+        peft_cache_config_model = PeftCacheConfig(
+        ) if peft_cache_config is None else peft_cache_config
         if lora_config.max_loras is not None:
             peft_cache_config_model.num_device_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_loras
         if lora_config.max_cpu_loras is not None:
             peft_cache_config_model.num_host_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
-        peft_cache_config = peft_cache_config_model._to_pybind()
 
         from tensorrt_llm.bindings import WorldConfig
         world_config = WorldConfig(
@@ -748,7 +747,7 @@ def create_py_executor_instance(
             gpus_per_node=dist.mapping.gpus_per_node,
         )
         peft_cache_manager = PeftCacheManager(
-            peft_cache_config=peft_cache_config,
+            peft_cache_config=peft_cache_config_model,
             lora_config=lora_config,
             model_config=model_binding_config,
             world_config=world_config,
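
The LoRA sizing logic above now writes directly onto the llmapi-level PeftCacheConfig instead of converting to the pybind type in this function. A small worked example of that arithmetic, with illustrative model and LoRA sizes (none of these numbers come from the commit):

    # Illustrative sizing, mirroring the hunk above.
    max_lora_rank = 64
    num_hidden_layers = 32           # pretrained_config.num_hidden_layers
    num_target_modules = 7           # len(lora_target_modules + missing_qkv_modules)
    max_loras, max_cpu_loras = 4, 8

    num_lora_modules = num_hidden_layers * num_target_modules                   # 224
    num_device_module_layer = max_lora_rank * num_lora_modules * max_loras      # 57_344
    num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras    # 114_688

These values land on peft_cache_config_model, the Python-level config; the pybind conversion now happens later, inside PeftCacheManager, rather than in create_py_executor_instance.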

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 2 deletions
@@ -27,11 +27,12 @@
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            PeftCacheConfig, RequestStage,
-                                            RequestStats, SpecDecodingStats,
+                                            RequestStage, RequestStats,
+                                            SpecDecodingStats,
                                             StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
+from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 12 additions & 13 deletions
@@ -6,18 +6,18 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import chain
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
+from strenum import StrEnum
 
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version, mpi_disabled
-from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
-                                            ContextChunkingPolicy,
-                                            GuidedDecodingConfig)
-from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
-from tensorrt_llm.llmapi.llm_args import LoadFormat, PybindMirror, TorchLlmArgs
+from tensorrt_llm.bindings.executor import GuidedDecodingConfig
+from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
+                                          ContextChunkingPolicy, LoadFormat,
+                                          PybindMirror, TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -214,12 +214,11 @@ def create_py_executor(
     if pytorch_backend_config is None:
         pytorch_backend_config = PyTorchConfig()
 
-    scheduler_config = PybindMirror.maybe_to_pybind(llm_args.scheduler_config)
+    scheduler_config = llm_args.scheduler_config
 
-    peft_cache_config = None
-    if llm_args.peft_cache_config is not None:
-        peft_cache_config = PybindMirror.maybe_to_pybind(
-            llm_args.peft_cache_config)
+    # Since peft_cache_config may be subject to change, avoid these changes propagate back
+    # to llm_args.peft_cache_config
+    peft_cache_config = copy.deepcopy(llm_args.peft_cache_config)
 
     assert llm_args.kv_cache_config, "Expect llm_args.kv_cache_config is not None"
     kv_cache_config = llm_args.kv_cache_config
@@ -457,8 +456,8 @@ def drafting_loop_wrapper(model):
                            scheduler_config.context_chunking_policy is not None
                            else ContextChunkingPolicy.FIRST_COME_FIRST_SERVED)
         assert chunk_unit_size is not None, "chunk_unit_size must be set"
-        ctx_chunk_config = ContextChunkingConfig(chunking_policy,
-                                                 chunk_unit_size)
+        ctx_chunk_config: Tuple[StrEnum,
+                                int] = (chunking_policy, chunk_unit_size)
     else:
         ctx_chunk_config = None
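
The deepcopy above exists because the copied config is mutated later (for example, the LoRA resizing in create_py_executor_instance); copying keeps those writes from leaking back into the user-visible llm_args. A minimal sketch of the aliasing hazard, using a hypothetical stand-in class rather than the real PeftCacheConfig:

    import copy

    class _Cfg:                        # hypothetical stand-in for a config object
        def __init__(self):
            self.num_device_module_layer = 0

    user_args_cfg = _Cfg()

    aliased = user_args_cfg            # no copy: both names point at one object
    aliased.num_device_module_layer = 57_344
    assert user_args_cfg.num_device_module_layer == 57_344   # caller's args changed

    user_args_cfg = _Cfg()
    isolated = copy.deepcopy(user_args_cfg)   # the pattern used in the hunk above
    isolated.num_device_module_layer = 57_344
    assert user_args_cfg.num_device_module_layer == 0        # caller's args intact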

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 4 additions & 2 deletions
@@ -11,7 +11,8 @@
 import tensorrt_llm.bindings
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror
+from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
+                                          PybindMirror)
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
@@ -39,7 +40,6 @@
 KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
 RequestList = list[LlmRequest]
 PeftCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.PeftCacheManager
-PeftCacheConfig = tensorrt_llm.bindings.executor.PeftCacheConfig
 WorldConfig = tensorrt_llm.bindings.WorldConfig
 TempAttentionWindowInputs = tensorrt_llm.bindings.internal.batch_manager.TempAttentionWindowInputs
 BlocksPerWindow = Dict[int, Tuple[
@@ -1164,6 +1164,8 @@ def __init__(self,
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb
 
+        peft_cache_config = peft_cache_config._to_pybind()
+
         peft_cache_manager_config = _tb.PeftCacheManagerConfig(
             num_host_module_layer=peft_cache_config.num_host_module_layer,
             num_device_module_layer=peft_cache_config.num_device_module_layer,
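
With this change the Python-level PeftCacheConfig travels all the way into PeftCacheManager, and the pybind object is produced only at the C++ boundary via _to_pybind(). A hedged sketch of that hand-off; the field value is illustrative, and running it requires the compiled tensorrt_llm bindings:

    from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

    # Python-side config, possibly resized for LoRA by create_py_executor_instance.
    peft_cache_config = PeftCacheConfig()
    peft_cache_config.num_device_module_layer = 57_344   # illustrative value

    # Inside PeftCacheManager.__init__ (per the hunk above) the conversion happens
    # once, right before individual fields feed _tb.PeftCacheManagerConfig(...).
    peft_cache_config_cpp = peft_cache_config._to_pybind()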

tensorrt_llm/_torch/pyexecutor/scheduler.py

Lines changed: 15 additions & 8 deletions
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from collections import namedtuple
-from typing import Optional
+from typing import Optional, Tuple
+
+from strenum import StrEnum
 
-from tensorrt_llm.bindings import executor as tb_executor
 from tensorrt_llm.bindings import internal as tb_internal
+from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
 
 from .llm_request import LlmRequest, LlmRequestState
 
@@ -74,8 +76,8 @@ def __init__(
         max_num_requests: int,
         kv_cache_manager,
         peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None,
-        scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
-        CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
+        scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.
+        GUARANTEED_NO_EVICT,
         two_step_lookahead: bool = False,
     ):
         super(BindCapacityScheduler, self).__init__()
@@ -84,7 +86,7 @@ def __init__(
 
         self.impl = tb_internal.algorithms.CapacityScheduler(
             max_num_requests=max_num_requests,
-            capacity_scheduler_policy=scheduler_policy,
+            capacity_scheduler_policy=scheduler_policy._to_pybind(),
             has_kv_cache_manager=kv_cache_manager is not None,
             two_step_lookahead=two_step_lookahead,
             no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
@@ -172,14 +174,19 @@ def __init__(
         self,
         max_batch_size: int,
         max_num_tokens: int = None,
-        ctx_chunk_config: Optional[
-            tb_internal.batch_manager.ContextChunkingConfig] = None,
+        ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None,
     ) -> None:
         super(BindMicroBatchScheduler, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
+
+        ctx_chunk_config_cpp = None
+        if ctx_chunk_config is not None:
+            ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig(
+                ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1])
+
         self.impl = tb_internal.algorithms.MicroBatchScheduler(
-            ctx_chunk_config, max_num_tokens)
+            ctx_chunk_config_cpp, max_num_tokens)
 
     def schedule(
         self, active_requests: RequestList, inflight_request_ids: set[int]
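
A rough usage sketch of the updated constructor: BindMicroBatchScheduler now receives the (policy, chunk_unit_size) tuple and builds the C++ ContextChunkingConfig itself. The batch size, token budget, and chunk size below are illustrative, and actually instantiating the scheduler still requires the compiled tensorrt_llm bindings:

    from tensorrt_llm._torch.pyexecutor.scheduler import BindMicroBatchScheduler
    from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy

    # Tuple built by the callers shown earlier (ad_executor.py, py_executor_creator.py).
    ctx_chunk_config = (ContextChunkingPolicy.FIRST_COME_FIRST_SERVED, 64)

    scheduler = BindMicroBatchScheduler(max_batch_size=8,
                                        max_num_tokens=8192,
                                        ctx_chunk_config=ctx_chunk_config)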

tests/unittest/_torch/executor/test_resource_manager.py

Lines changed: 2 additions & 3 deletions
@@ -13,14 +13,13 @@
 import tensorrt_llm.bindings
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
-                                                             PeftCacheConfig,
                                                              PeftCacheManager)
 from tensorrt_llm.bindings import LayerType
 from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 
@@ -234,7 +233,7 @@ def create_peft_cache_config(self) -> PeftCacheConfig:
             num_ensure_workers=mock_config.ensure_thread_count,
         )
 
-        return peft_cache_config
+        return PeftCacheConfig.from_pybind(peft_cache_config)
 
     def _create_request(self,
                         request_id,
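
The test helper now wraps the pybind config it builds in the llmapi PeftCacheConfig. The two conversion helpers visible in this commit form a round trip; a small hedged sketch using a default-constructed config, assuming the compiled tensorrt_llm bindings are importable:

    from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

    cfg = PeftCacheConfig()                            # llmapi (Python-level) config
    cfg_cpp = cfg._to_pybind()                         # pybind executor.PeftCacheConfig
    cfg_back = PeftCacheConfig.from_pybind(cfg_cpp)    # back to the llmapi type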
