
Commit cbfa09e

refine WorkerBase interface

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

1 parent 6c709ce

2 files changed: +33 −25 lines changed


tensorrt_llm/executor/worker.py

Lines changed: 6 additions & 0 deletions
@@ -68,6 +68,12 @@ def __init__(
         executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
             processor_batched=batched_logits_processor, replicate=False)
 
+        self.create_engine(
+            engine=engine,
+            executor_config=executor_config,
+            lora_config=lora_config,
+            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+
         self.await_response_thread = ManagedThread(
             self.await_response_task,
             error_queue=self._error_queue,
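
Note on the hunk above: engine construction is no longer an implicit side effect of WorkerBase.__init__ via a nested closure; the subclass now invokes it as an explicit step. A minimal sketch of the resulting flow, with toy stand-in classes (the names mirror the diff, but the bodies and simplified signatures here are assumptions, not the real TensorRT-LLM code):

# Sketch only: toy stand-ins for WorkerBase/Worker, not tensorrt_llm itself.
from pathlib import Path
from typing import Optional, Union


class WorkerBase:

    def create_engine(self, engine: Union[Path, object],
                      executor_config: object,
                      lora_config: Optional[object] = None,
                      garbage_collection_gen0_threshold: Optional[int] = None) -> None:
        # The real method selects a backend and assigns self.engine;
        # here we only record that the step ran.
        self.engine = ("created", engine, executor_config)


class Worker(WorkerBase):

    def __init__(self, engine, executor_config, lora_config=None,
                 garbage_collection_gen0_threshold=None):
        # The subclass decides *when* the engine is built, instead of a
        # nested _create_engine() running implicitly inside the base class.
        self.create_engine(
            engine=engine,
            executor_config=executor_config,
            lora_config=lora_config,
            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)


worker = Worker(engine=Path("/tmp/engine"), executor_config={})
assert worker.engine[0] == "created"

One practical upside of this shape: tests or alternative workers can subclass WorkerBase and defer or override engine creation without monkey-patching a closure.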

tensorrt_llm/executor/worker_base.py

Lines changed: 27 additions & 25 deletions
@@ -69,28 +69,31 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
 
-        def _create_engine():
-            device_id = self.global_rank % torch.cuda.device_count()
-            torch.cuda.set_device(device_id)
-
-            # Make sure C++ executor would use same devices/ranks as py_executor
-            global_rank = global_mpi_rank()
-            comm_ranks = mpi_comm().allgather(global_rank)
-            device_ids = mpi_comm().allgather(device_id)
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-
-            if isinstance(engine, Engine):
-                return tllm.Executor(engine.engine,
-                                     json.dumps(engine.config.to_dict(),
-                                                cls=ConfigEncoder),
-                                     tllm.ModelType.DECODER_ONLY,
-                                     executor_config=executor_config,
-                                     managed_weights=engine.managed_weights)
-
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
+    def create_engine(self, engine: Union[Path, Engine],
+                      executor_config: tllm.ExecutorConfig,
+                      lora_config: Optional[LoraConfig],
+                      garbage_collection_gen0_threshold: Optional[int]) -> None:
+        device_id = self.global_rank % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        # Make sure C++ executor would use same devices/ranks as py_executor
+        global_rank = global_mpi_rank()
+        comm_ranks = mpi_comm().allgather(global_rank)
+        device_ids = mpi_comm().allgather(device_id)
+        executor_config.parallel_config = tllm.ParallelConfig(
+            participant_ids=comm_ranks, device_ids=device_ids)
+
+        if isinstance(engine, Engine):
+            self.engine = tllm.Executor(engine.engine,
+                                        json.dumps(engine.config.to_dict(),
+                                                   cls=ConfigEncoder),
+                                        tllm.ModelType.DECODER_ONLY,
+                                        executor_config=executor_config,
+                                        managed_weights=engine.managed_weights)
+        elif not hasattr(executor_config, "backend"):
+            self.engine = tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                        executor_config)
+        else:
             args = {
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
@@ -109,10 +112,9 @@ def _create_engine():
             else:
                 raise ValueError(
                     f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
-
-        self.engine = _create_engine()
+            self.engine = create_executor(**args)
 
+        # LoRA setup
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
         self._runtime_model_config: Optional[ModelConfig] = None
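
A detail worth calling out in create_engine: before the C++ executor is built, every rank publishes its global MPI rank and its local CUDA device via allgather, so tllm.ParallelConfig pins the C++ side to exactly the devices the Python workers already occupy. A standalone sketch of that gathering step, assuming mpi4py (the commit itself goes through tensorrt_llm's mpi_comm()/global_mpi_rank() wrappers):

# Sketch, assuming mpi4py and 4 GPUs per node; run e.g. `mpirun -n 4 python sketch.py`.
from mpi4py import MPI

comm = MPI.COMM_WORLD
global_rank = comm.Get_rank()
gpus_per_node = 4                        # assumption for the example
device_id = global_rank % gpus_per_node  # mirrors global_rank % torch.cuda.device_count()

# allgather returns the same ordered list on every rank, so all participants
# agree on the rank -> device mapping handed to the C++ executor.
comm_ranks = comm.allgather(global_rank)  # e.g. [0, 1, 2, 3]
device_ids = comm.allgather(device_id)    # e.g. [0, 1, 2, 3]

# These two lists are what the diff feeds into
# tllm.ParallelConfig(participant_ids=comm_ranks, device_ids=device_ids).
print(global_rank, comm_ranks, device_ids)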

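When a prebuilt Engine object is passed, the diff serializes its config with json.dumps(engine.config.to_dict(), cls=ConfigEncoder) before handing it across the C++ boundary. The encoder itself is not part of this diff; a plausible minimal shape for such a custom encoder is sketched below (the class name and the str() fallback are assumptions, not the actual ConfigEncoder):

import json


class ConfigEncoderSketch(json.JSONEncoder):
    """Hypothetical stand-in: engine configs often carry values (enums,
    dtypes, Paths) that the default JSON encoder rejects."""

    def default(self, o):
        try:
            return super().default(o)  # raises TypeError for unknown types
        except TypeError:
            return str(o)              # assumed fallback: stringify


# Usage mirrors the diff's call shape:
print(json.dumps({"dtype": complex(3, 4)}, cls=ConfigEncoderSketch))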