Commit d4ddb84

add torch.distributed + rpc orchestrator
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
1 parent 5b5e596 commit d4ddb84

5 files changed: +408 additions, -14 deletions

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 23 additions & 8 deletions
@@ -476,14 +476,22 @@ def _get_cluster_info(self):
         if self.cluster_info is not None:
             return self.cluster_info
 
-        if ray.is_initialized():
+        is_ray_initialized = False
+        try:
+            if ray.is_initialized():
+                is_ray_initialized = True
+        except Exception:
+            pass
+
+        if is_ray_initialized:
             node_ip = ray.util.get_node_ip_address()
+            gpu_index = [int(id) for id in ray.get_gpu_ids()]
+            assert len(gpu_index) == 1
+            gpu_id = gpu_index[0]
         else:
-            raise RuntimeError("Ray is not initialized")
-
-        gpu_index = [int(id) for id in ray.get_gpu_ids()]
-
-        assert len(gpu_index) == 1
+            import socket
+            node_ip = socket.gethostbyname(socket.gethostname())
+            gpu_id = torch.cuda.current_device()
 
         # Gather node ip
         node_list = [None] * torch.distributed.get_world_size()
@@ -492,7 +500,7 @@ def _get_cluster_info(self):
 
         # Gather gpu index
         gpu_list = [None] * torch.distributed.get_world_size()
-        torch.distributed.all_gather_object(gpu_list, gpu_index[0])
+        torch.distributed.all_gather_object(gpu_list, gpu_id)
 
         # Gather rank
         rank_list = [None] * torch.distributed.get_world_size()
@@ -639,8 +647,15 @@ def allreduce(self,
                   obj: int | float | torch.Tensor,
                   op=torch.distributed.ReduceOp.SUM):
         is_base_type = isinstance(obj, int) or isinstance(obj, float)
+        device = torch.device(
+            "cuda") if dist.get_backend() == "nccl" else torch.device("cpu")
+
         if is_base_type:
-            obj = torch.tensor(obj)
+            obj = torch.tensor(obj, device=device)
+        elif isinstance(obj, torch.Tensor):
+            # Ensure tensor is on the correct device
+            if obj.device != device:
+                obj = obj.to(device)
 
         dist.all_reduce(obj, op=op)
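
The `_get_cluster_info` change lets the communicator fall back to the local hostname and the current CUDA device when Ray is not running, and `allreduce` now places scalars and tensors on the device the active backend expects (CUDA for NCCL, CPU otherwise). Below is a minimal, self-contained sketch of that pattern, not part of the patch: it uses a single-rank gloo group and an arbitrary TCP port so it runs without a GPU.

# Standalone sketch of the pattern above: pick the device that matches the
# process-group backend, gather per-rank metadata with all_gather_object,
# and put scalars on that device before all_reduce.
import socket

import torch
import torch.distributed as dist


def gather_cluster_info():
    # NCCL reduces CUDA tensors only; gloo works on CPU tensors.
    device = torch.device("cuda") if dist.get_backend() == "nccl" else torch.device("cpu")

    node_ip = socket.gethostbyname(socket.gethostname())
    gpu_id = torch.cuda.current_device() if torch.cuda.is_available() else 0

    world_size = dist.get_world_size()
    node_list = [None] * world_size
    gpu_list = [None] * world_size
    dist.all_gather_object(node_list, node_ip)  # gathers arbitrary picklable objects
    dist.all_gather_object(gpu_list, gpu_id)

    # Scalars must become tensors on the backend's device before all_reduce.
    total = torch.tensor(1, device=device)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return list(zip(node_list, gpu_list)), int(total)


if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29511",
                            rank=0, world_size=1)
    print(gather_cluster_info())
    dist.destroy_process_group()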

Lines changed: 158 additions & 0 deletions (new file)

@@ -0,0 +1,158 @@
import multiprocessing
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from tensorrt_llm._utils import get_free_port
from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.builder import Engine
from tensorrt_llm.executor.executor import GenerationExecutor
from tensorrt_llm.executor.postproc_worker import PostprocWorkerConfig
from tensorrt_llm.executor.rpc_proxy_mixin import RpcExecutorMixin
from tensorrt_llm.executor.rpc_torch_dist_worker import RpcTorchDistWorker
from tensorrt_llm.llmapi.llm_args import BaseLlmArgs
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import BatchedLogitsProcessor


class RpcTorchDistExecutor(RpcExecutorMixin, GenerationExecutor):
    def __init__(
        self,
        worker_kwargs: Dict,
        model_world_size: int,
        postproc_worker_config: PostprocWorkerConfig,
        is_llm_executor: bool,
    ):
        # Initialize GenerationExecutor
        super().__init__(
            num_postprocess_workers=postproc_worker_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_worker_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        self.world_size = model_world_size
        self.processes: List[multiprocessing.Process] = []

        # Setup RPC
        self.init_rpc_executor()

        # Determine master addr/port for torch.distributed
        self.master_addr = "127.0.0.1"
        self.master_port = str(get_free_port())

        logger.info(
            f"RpcTorchDistExecutor starting with {model_world_size} workers. "
            f"Master: {self.master_addr}:{self.master_port}"
        )

        # Spawn workers
        self.start_workers(worker_kwargs)

        # Setup engine (remote).
        # This triggers setup_engine on rank 0 via RPC, which broadcasts to the other ranks.
        try:
            logger.info("Setting up remote engine...")
            self.setup_engine_remote()
        except Exception as e:
            logger.error(f"Failed to setup remote engine: {e}")
            self.shutdown()
            raise e

        # Setup main loop for receiving results from RPC
        self.setup_mainloop()

    def start_workers(self, worker_kwargs: Dict):
        ctx = multiprocessing.get_context("spawn")

        for rank in range(self.world_size):
            p = ctx.Process(
                target=RpcTorchDistWorker.worker_main,
                args=(
                    rank,
                    self.world_size,
                    self.master_addr,
                    self.master_port,
                    self.rpc_addr,  # Passed to all ranks, but only used by rank 0
                    worker_kwargs,
                ),
                name=f"RpcTorchDistWorker-{rank}",
            )
            p.start()
            self.processes.append(p)

    def setup_engine_remote(self):
        # Call setup_engine on rank 0 via RPC.
        # We wait for the result to ensure everything is initialized.
        self.rpc_client.setup_engine().remote()

    def shutdown(self):
        if self.doing_shutdown:
            return
        self.doing_shutdown = True

        logger.info("Shutting down RpcTorchDistExecutor...")

        # RPC shutdown to rank 0
        try:
            if hasattr(self, "rpc_client") and self.rpc_client:
                # This tells rank 0 to shut down, which broadcasts shutdown to the other ranks
                self.rpc_client.shutdown().remote(need_response=False)
        except Exception as e:
            logger.warning(f"Error during RPC shutdown: {e}")

        # Cleanup RPC client
        if hasattr(self, "rpc_client") and self.rpc_client:
            self.rpc_client.close()

        # Join processes
        for p in self.processes:
            if p.is_alive():
                p.join(timeout=5)
                if p.is_alive():
                    logger.warning(f"Process {p.name} did not exit, terminating...")
                    p.terminate()

        super().shutdown()

    @classmethod
    def create(
        cls,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        model_world_size: int = 1,
        mpi_session: Optional[Any] = None,
        reuse_mpi_comm: bool = False,
        return_logits: bool = False,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
        **kwargs,
    ):
        postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()

        worker_kwargs = {
            "engine": engine,
            "executor_config": executor_config,
            "batched_logits_processor": batched_logits_processor,
            "hf_model_dir": hf_model_dir,
            "tokenizer": tokenizer,
            "llm_args": llm_args,
        }

        return cls(
            worker_kwargs=worker_kwargs,
            model_world_size=model_world_size,
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=is_llm_executor or False,
        )

    # Implement abstract methods from GenerationExecutor
    def submit(self, request):
        return super().submit(request)  # RpcExecutorMixin.submit

    def abort_request(self, request_id: int):
        # Forward to rank 0
        self.rpc_client.abort_request(request_id).remote(need_response=False)
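
The executor spawns one "spawn"-context process per rank through RpcTorchDistWorker.worker_main, with rank 0 hosting the RPC server. Below is a hypothetical driver sketch built only from the create() signature above; the model directory and llm_args are placeholders, and real usage would normally go through the LLM API rather than constructing the executor directly.

# Hypothetical driver sketch (not part of the patch): builds the executor via the
# create() classmethod above. RpcTorchDistExecutor is the class defined in the new
# file above (its module path is not shown in this diff); engine, hf_model_dir and
# llm_args here are placeholders.
from pathlib import Path

from tensorrt_llm.executor.postproc_worker import PostprocWorkerConfig


def run_two_rank_executor(llm_args, hf_model_dir: Path):
    executor = RpcTorchDistExecutor.create(
        engine=hf_model_dir,             # Path or Engine, per the signature above
        model_world_size=2,              # spawns 2 RpcTorchDistWorker processes
        postproc_worker_config=PostprocWorkerConfig(),
        is_llm_executor=True,
        hf_model_dir=hf_model_dir,
        llm_args=llm_args,
    )
    try:
        # submit()/abort_request() behave like any GenerationExecutor;
        # requests are forwarded to rank 0 over RPC.
        ...
    finally:
        executor.shutdown()
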
Lines changed: 149 additions & 0 deletions (new file)

@@ -0,0 +1,149 @@
import os
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist

from tensorrt_llm.executor.base_worker import BaseWorker
from tensorrt_llm.executor.rpc_worker_mixin import RpcWorkerMixin
from tensorrt_llm.logger import logger


class RpcTorchDistWorker(RpcWorkerMixin, BaseWorker):
    def __init__(
        self, rank: int, world_size: int, device_id: int, rpc_addr: Optional[str] = None, **kwargs
    ):
        # Initialize BaseWorker
        super().__init__(**kwargs)

        self.rank = rank
        self.global_rank = rank
        self.world_size = world_size
        self.device_id = device_id

        # Create a control group for worker orchestration.
        # Use Gloo for control messages since it doesn't require a GPU
        # and is robust.
        self.control_group = dist.new_group(backend="gloo")

        if self.rank == 0:
            if rpc_addr is None:
                raise ValueError("rpc_addr must be provided for rank 0")
            self.init_rpc_worker(self.rank, rpc_addr)
            self.start_rpc_server()

    def setup_engine(self):
        # Broadcast the command if rank 0
        if self.rank == 0:
            self._broadcast_command("setup_engine")

        # Ensure all ranks are synchronized before setting up the engine
        if dist.is_initialized():
            dist.barrier()

        super().setup_engine()

    def start(self):
        pass

    def shutdown(self):
        if self.doing_shutdown:
            return

        # Broadcast the command if rank 0
        if self.rank == 0:
            try:
                self._broadcast_command("shutdown")
            except Exception as e:
                logger.warning(f"Failed to broadcast shutdown command: {e}")

        super().shutdown()

        if self.rank == 0 and hasattr(self, "rpc_server") and self.rpc_server:
            self.rpc_server.shutdown()

    def _broadcast_command(self, command: str, args: Any = None):
        if not dist.is_initialized():
            return
        cmd_list = [command, args]
        try:
            dist.broadcast_object_list(cmd_list, src=0, group=self.control_group)
        except Exception as e:
            logger.error(f"Broadcast error: {e}")

    @classmethod
    def worker_main(
        cls,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: str,
        rpc_addr: Optional[str],
        worker_kwargs: Dict,
    ):
        # Setup environment
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["TLLM_DISABLE_MPI"] = "1"

        # Setup device
        if torch.cuda.is_available():
            device_id = rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
        else:
            device_id = 0

        # Initialize the process group: NCCL for GPU, Gloo for CPU
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

        logger.info(f"Worker {rank}/{world_size} initialized with backend {backend}")

        try:
            worker = cls(
                rank=rank,
                world_size=world_size,
                device_id=device_id,
                rpc_addr=rpc_addr,
                **worker_kwargs,
            )

            if rank == 0:
                # Rank 0 waits for RPCs.
                # The RPC server runs in a background thread started by start_rpc_server.
                # We wait on the shutdown event, which is set by shutdown() (called via RPC).
                worker.shutdown_event.wait()
            else:
                # Rank > 0 command loop
                while True:
                    cmd_list = [None, None]
                    try:
                        dist.broadcast_object_list(cmd_list, src=0, group=worker.control_group)
                    except Exception as e:
                        # If the broadcast fails (e.g. rank 0 died), exit the loop
                        logger.error(f"Rank {rank} broadcast receive error: {e}")
                        break

                    cmd, args = cmd_list
                    # logger.debug(f"Rank {rank} received command: {cmd}")

                    if cmd == "setup_engine":
                        worker.setup_engine()
                    elif cmd == "shutdown":
                        worker.shutdown()
                        break
                    elif cmd == "report_device_id":
                        # Optional: handle other commands if needed
                        pass
                    else:
                        logger.warning(f"Rank {rank} received unknown command: {cmd}")

        except Exception as e:
            logger.error(f"Worker {rank} failed with error: {e}")
            raise e
        finally:
            if dist.is_initialized():
                dist.destroy_process_group()
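
Non-zero ranks block on broadcast_object_list over the gloo control group and dispatch on the command string, while rank 0 serves RPCs and issues the broadcasts. The following standalone sketch of that command fan-out is not part of the patch; it is a plain torch.distributed program that can be launched with a hypothetical torchrun command such as torchrun --nproc_per_node=2 control_loop_sketch.py.

# Standalone sketch of the rank-0 command fan-out used by RpcTorchDistWorker:
# rank 0 broadcasts (command, args) pairs over a gloo group, the other ranks
# block on the broadcast and dispatch until they see "shutdown".
import torch.distributed as dist


def main():
    dist.init_process_group("gloo")  # torchrun supplies the rank/world-size env vars
    control_group = dist.new_group(backend="gloo")
    rank = dist.get_rank()

    if rank == 0:
        # Issue two commands; every rank must take part in each broadcast.
        for cmd in (["setup_engine", None], ["shutdown", None]):
            dist.broadcast_object_list(cmd, src=0, group=control_group)
    else:
        while True:
            cmd_list = [None, None]
            dist.broadcast_object_list(cmd_list, src=0, group=control_group)
            cmd, args = cmd_list
            print(f"rank {rank} got {cmd!r}")
            if cmd == "shutdown":
                break

    dist.destroy_process_group()


if __name__ == "__main__":
    main()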

tensorrt_llm/llmapi/llm_args.py

Lines changed: 7 additions & 6 deletions
@@ -1852,12 +1852,13 @@ class BaseLlmArgs(StrictBaseModel):
         description="Return perf metrics.",
         status="prototype")
 
-    orchestrator_type: Optional[Literal["rpc", "ray"]] = Field(
-        default=None,
-        description=
-        "The orchestrator type to use. Defaults to None, which uses MPI.",
-        status="prototype",
-    )
+    orchestrator_type: Optional[Literal[
+        "rpc", "ray", "rpc_torch_dist"]] = Field(
+            default=None,
+            description=
+            "The orchestrator type to use. Defaults to None, which uses MPI.",
+            status="prototype",
+        )
 
     _parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
     _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
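
The new "rpc_torch_dist" literal exposes the orchestrator through the prototype orchestrator_type knob. A hedged usage example, assuming the field is forwarded from LLM keyword arguments the same way the existing "rpc" and "ray" values are; the model path is a placeholder.

# Assumed usage: orchestrator_type passed through the LLM API into BaseLlmArgs,
# as with the existing "rpc"/"ray" values. The model path is a placeholder.
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/hf_model",
    tensor_parallel_size=2,
    orchestrator_type="rpc_torch_dist",  # prototype knob; default None uses MPI
)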
