@@ -69,28 +69,31 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref

-        def _create_engine():
-            device_id = self.global_rank % torch.cuda.device_count()
-            torch.cuda.set_device(device_id)
-
-            # Make sure C++ executor would use same devices/ranks as py_executor
-            global_rank = global_mpi_rank()
-            comm_ranks = mpi_comm().allgather(global_rank)
-            device_ids = mpi_comm().allgather(device_id)
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-
-            if isinstance(engine, Engine):
-                return tllm.Executor(engine.engine,
-                                     json.dumps(engine.config.to_dict(),
-                                                cls=ConfigEncoder),
-                                     tllm.ModelType.DECODER_ONLY,
-                                     executor_config=executor_config,
-                                     managed_weights=engine.managed_weights)
-
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
+    def create_engine(self, engine: Union[Path, Engine],
+                      executor_config: tllm.ExecutorConfig,
+                      lora_config: Optional[LoraConfig],
+                      garbage_collection_gen0_threshold: Optional[int]) -> None:
+        device_id = self.global_rank % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        # Make sure C++ executor would use same devices/ranks as py_executor
+        global_rank = global_mpi_rank()
+        comm_ranks = mpi_comm().allgather(global_rank)
+        device_ids = mpi_comm().allgather(device_id)
+        executor_config.parallel_config = tllm.ParallelConfig(
+            participant_ids=comm_ranks, device_ids=device_ids)
+
+        if isinstance(engine, Engine):
+            self.engine = tllm.Executor(engine.engine,
+                                        json.dumps(engine.config.to_dict(),
+                                                   cls=ConfigEncoder),
+                                        tllm.ModelType.DECODER_ONLY,
+                                        executor_config=executor_config,
+                                        managed_weights=engine.managed_weights)
+        elif not hasattr(executor_config, "backend"):
+            self.engine = tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                        executor_config)
+        else:
             args = {
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
@@ -109,10 +112,9 @@ def _create_engine():
             else:
                 raise ValueError(
                     f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
-
-        self.engine = _create_engine()
+            self.engine = create_executor(**args)

+        # LoRA setup
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
         self._runtime_model_config: Optional[ModelConfig] = None
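For reference, a minimal standalone sketch of the dispatch pattern this refactor hoists out of `__init__`: `create_engine` picks one of three construction paths depending on whether `engine` is an in-memory `Engine` object, or a path with no `backend` attribute on the config, or a path with a Python backend configured. The stub types here (`Engine`, `ExecutorConfig`, `Worker`) are hypothetical stand-ins for the real TensorRT-LLM classes, and the call site in `__init__` is assumed, not shown in the diff.

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union


# Hypothetical stand-ins for the real tllm types, for illustration only.
@dataclass
class Engine:
    engine: bytes = b""
    managed_weights: Optional[dict] = None


@dataclass
class ExecutorConfig:
    # The real config may or may not carry a `backend` attribute; the
    # dispatch below keys off its presence, mirroring the hasattr() check.
    hf_model_dir: Optional[Path] = None


class Worker:
    def __init__(self, engine: Union[Path, Engine],
                 executor_config: ExecutorConfig):
        # Assumed call site: __init__ delegates to the create_engine method
        # instead of calling an inner _create_engine() closure.
        self.engine = None
        self.create_engine(engine, executor_config)

    def create_engine(self, engine: Union[Path, Engine],
                      executor_config: ExecutorConfig) -> None:
        if isinstance(engine, Engine):
            # In-memory engine: hand over the serialized blob and weights.
            self.engine = ("cpp_executor_from_engine", engine.managed_weights)
        elif not hasattr(executor_config, "backend"):
            # Engine directory on disk, no Python backend requested.
            self.engine = ("cpp_executor_from_path", engine)
        else:
            # Python backend (e.g. "pytorch"): build through a factory.
            self.engine = ("python_executor", executor_config.backend)


if __name__ == "__main__":
    cfg = ExecutorConfig()
    print(Worker(Path("/tmp/engine_dir"), cfg).engine)
    # -> ('cpp_executor_from_path', PosixPath('/tmp/engine_dir'))

    cfg.backend = "pytorch"  # attached dynamically, so hasattr() now succeeds
    print(Worker(Path("/tmp/engine_dir"), cfg).engine)
    # -> ('python_executor', 'pytorch')

Turning the closure into a method also means the created executor is stored via `self.engine` assignments rather than return values, which is why the diff rewrites each `return tllm.Executor(...)` branch into an `if/elif/else` chain.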