Skip to content

Commit ddb8722

Browse files
committed
fix rpc unique addr
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
1 parent ee588a7 commit ddb8722

File tree

6 files changed

+73
-58
lines changed

6 files changed

+73
-58
lines changed
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from .rpc_client import RPCClient
22
from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest,
3-
RPCResponse, RPCStreamingError, RPCTimeout)
3+
RPCResponse, RPCStreamingError, RPCTimeout,
4+
get_unique_ipc_addr)
45
from .rpc_server import RPCServer, Server
56

67
__all__ = [
78
"RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout",
89
"RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse",
9-
"RPCParams"
10+
"RPCParams", "get_unique_ipc_addr"
1011
]

tensorrt_llm/executor/rpc/rpc_common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
1+
import os
12
import time
23
from dataclasses import dataclass
34
from typing import Any, Literal, NamedTuple, Optional
45

56

7+
def get_unique_ipc_addr() -> str:
8+
"""Generate a unique IPC address to avoid conflicts between tests.
9+
10+
Returns:
11+
A unique IPC address string in the format: ipc:///tmp/rpc_test_{pid}_{timestamp}
12+
"""
13+
pid = os.getpid()
14+
timestamp = int(time.time() * 1000000) # microseconds for better uniqueness
15+
return f"ipc:///tmp/rpc_test_{pid}_{timestamp}"
16+
17+
618
class RPCParams(NamedTuple):
719
""" Parameters for RPC calls. """
820

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import atexit
33
import json
4-
import os
54
import threading
65
from typing import Optional
76

@@ -16,6 +15,7 @@
1615
from .request import GenerationRequest
1716
from .result import GenerationResult
1817
from .rpc import RPCClient
18+
from .rpc.rpc_common import get_unique_ipc_addr
1919
from .rpc_worker import RpcWorker
2020
from .utils import (ErrorResponse, create_mpi_comm_session,
2121
get_spawn_proxy_process_env, is_llm_response)
@@ -45,7 +45,7 @@ def __init__(
4545
kv_connector_config: the kv cache connector config
4646
"""
4747
GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1
48-
self.rpc_addr = self.gen_uniq_rpc_addr()
48+
self.rpc_addr = get_unique_ipc_addr()
4949
self.rpc_client = RPCClient(self.rpc_addr)
5050

5151
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
@@ -368,8 +368,3 @@ def _create_mpi_session(self, model_world_size: int,
368368
else:
369369
print_colored_debug('using external mpi session ...\n', "yellow")
370370
self.mpi_session = mpi_session
371-
372-
@staticmethod
373-
def gen_uniq_rpc_addr() -> str:
374-
process_id = os.getpid()
375-
return f"ipc:///tmp/rpc-proxy-{process_id}-{GenerationExecutorRpcProxy.INSTANCE_COUNTER}"

tests/unittest/executor/test_rpc.py

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError,
77
RPCServer, RPCStreamingError, RPCTimeout)
8+
from tensorrt_llm.executor.rpc.rpc_common import get_unique_ipc_addr
89

910

1011
class RpcServerWrapper(RPCServer):
@@ -31,7 +32,8 @@ class App:
3132
def hello(self):
3233
print("hello")
3334

34-
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
35+
addr = get_unique_ipc_addr()
36+
with RpcServerWrapper(App(), addr=addr) as server:
3537
pass
3638

3739
def test_remote_call_without_arg(self):
@@ -42,8 +44,9 @@ def hello(self):
4244
print("hello")
4345
return "world"
4446

45-
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
46-
with RPCClient("ipc:///tmp/rpc_test") as client:
47+
addr = get_unique_ipc_addr()
48+
with RpcServerWrapper(App(), addr=addr) as server:
49+
with RPCClient(addr) as client:
4750
ret = client.hello().remote() # sync call
4851
assert ret == "world"
4952

@@ -55,8 +58,9 @@ def hello(self, name: str, location: str):
5558
print("hello")
5659
return f"hello {name} from {location}"
5760

58-
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
59-
with RPCClient("ipc:///tmp/rpc_test") as client:
61+
addr = get_unique_ipc_addr()
62+
with RpcServerWrapper(App(), addr=addr) as server:
63+
with RPCClient(addr) as client:
6064
ret = client.hello("app", "Marvel").remote()
6165
assert ret == "hello app from Marvel"
6266

@@ -68,8 +72,9 @@ def hello(self, name: str, location: str):
6872
print("hello")
6973
return f"hello {name} from {location}"
7074

71-
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
72-
with RPCClient("ipc:///tmp/rpc_test") as client:
75+
addr = get_unique_ipc_addr()
76+
with RpcServerWrapper(App(), addr=addr) as server:
77+
with RPCClient(addr) as client:
7378
ret = client.hello(name="app", location="Marvel").remote()
7479
assert ret == "hello app from Marvel"
7580

@@ -81,8 +86,9 @@ def hello(self, name: str, location: str):
8186
print("hello")
8287
return f"hello {name} from {location}"
8388

84-
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
85-
with RPCClient("ipc:///tmp/rpc_test") as client:
89+
addr = get_unique_ipc_addr()
90+
with RpcServerWrapper(App(), addr=addr) as server:
91+
with RPCClient(addr) as client:
8692
ret = client.hello(name="app", location="Marvel").remote()
8793
assert ret == "hello app from Marvel"
8894

@@ -91,8 +97,9 @@ def test_rpc_server_address(self):
9197
class App:
9298
pass
9399

94-
with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server:
95-
assert server.address == "ipc:///tmp/rpc_test"
100+
addr = get_unique_ipc_addr()
101+
with RpcServerWrapper(App(), addr=addr) as server:
102+
assert server.address == addr
96103

97104
def test_rpc_with_error(self):
98105

@@ -101,9 +108,9 @@ class App:
101108
def hello(self):
102109
raise ValueError("hello")
103110

104-
with RpcServerWrapper(App(),
105-
addr="ipc:///tmp/rpc_test_error") as server:
106-
with RPCClient("ipc:///tmp/rpc_test_error") as client:
111+
addr = get_unique_ipc_addr()
112+
with RpcServerWrapper(App(), addr=addr) as server:
113+
with RPCClient(addr) as client:
107114
with pytest.raises(RPCError):
108115
client.hello().remote()
109116

@@ -123,9 +130,9 @@ def send_task(self) -> None:
123130
def get_task_submitted(self) -> bool:
124131
return self.task_submitted
125132

126-
with RpcServerWrapper(App(),
127-
addr="ipc:///tmp/rpc_test_no_wait") as server:
128-
with RPCClient("ipc:///tmp/rpc_test_no_wait") as client:
133+
addr = get_unique_ipc_addr()
134+
with RpcServerWrapper(App(), addr=addr) as server:
135+
with RPCClient(addr) as client:
129136
client.send_task().remote(need_response=False)
130137
time.sleep(
131138
0.1
@@ -152,11 +159,12 @@ def divide_by_zero(self):
152159
def custom_exception(self):
153160
raise TestRpcError.CustomError("Custom error occurred")
154161

162+
addr = get_unique_ipc_addr()
155163
with RPCServer(App()) as server:
156-
server.bind("ipc:///tmp/rpc_test_error")
164+
server.bind(addr)
157165
server.start()
158166
time.sleep(0.1)
159-
with RPCClient("ipc:///tmp/rpc_test_error") as client:
167+
with RPCClient(addr) as client:
160168
# Test ValueError handling
161169
with pytest.raises(RPCError) as exc_info:
162170
client.hello().remote()
@@ -196,7 +204,7 @@ def task(self):
196204
time.sleep(10)
197205
return True
198206

199-
addr = "ipc:///tmp/rpc_test_cancelled"
207+
addr = get_unique_ipc_addr()
200208

201209
server = RPCServer(
202210
App(),
@@ -218,7 +226,6 @@ def task(self):
218226

219227
client.close()
220228

221-
@pytest.mark.skip(reason="https://nvbugs/5579234")
222229
def test_timeout_error(self):
223230
"""Test that requests that exceed timeout are handled with proper error."""
224231

@@ -229,13 +236,12 @@ def slow_method(self):
229236
time.sleep(2.0)
230237
return "completed"
231238

232-
with RpcServerWrapper(App(),
233-
addr="ipc:///tmp/rpc_test_timeout") as server:
239+
addr = get_unique_ipc_addr()
240+
with RpcServerWrapper(App(), addr=addr) as server:
234241
time.sleep(0.1)
235242

236243
# Create client with short timeout
237-
with RPCClient("ipc:///tmp/rpc_test_timeout",
238-
timeout=0.5) as client:
244+
with RPCClient(addr, timeout=0.5) as client:
239245
with pytest.raises(RPCError) as exc_info:
240246
client.slow_method().remote(timeout=0.5)
241247

@@ -252,11 +258,11 @@ class App:
252258
def existing_method(self):
253259
return "exists"
254260

255-
with RpcServerWrapper(App(),
256-
addr="ipc:///tmp/rpc_test_not_found") as server:
261+
addr = get_unique_ipc_addr()
262+
with RpcServerWrapper(App(), addr=addr) as server:
257263
time.sleep(0.1)
258264

259-
with RPCClient("ipc:///tmp/rpc_test_not_found") as client:
265+
with RPCClient(addr) as client:
260266
with pytest.raises(RPCError) as exc_info:
261267
client.non_existent_method().remote()
262268

@@ -272,11 +278,12 @@ class App:
272278
def hello(self):
273279
return "world"
274280

281+
addr = get_unique_ipc_addr()
275282
with RPCServer(App()) as server:
276-
server.bind("ipc:///tmp/rpc_test_shutdown")
283+
server.bind(addr)
277284
server.start()
278285
time.sleep(0.1)
279-
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
286+
with RPCClient(addr) as client:
280287
ret = client.hello().remote()
281288
assert ret == "world"
282289

@@ -298,11 +305,12 @@ def send_task(self) -> None:
298305
time.sleep(0.001)
299306
return None
300307

308+
addr = get_unique_ipc_addr()
301309
with RPCServer(App(), num_workers=10) as server:
302-
server.bind("ipc:///tmp/rpc_test_no_wait")
310+
server.bind(addr)
303311
server.start()
304312
time.sleep(0.1)
305-
with RPCClient("ipc:///tmp/rpc_test_no_wait") as client:
313+
with RPCClient(addr) as client:
306314
time_start = time.time()
307315
for i in range(100):
308316
client.send_task().remote(need_response=False)
@@ -329,7 +337,7 @@ def cal(self, n: int):
329337
return n * 2
330338

331339
with RPCServer(App(), async_run_task=async_run_task) as server:
332-
address = "ipc:///tmp/rpc_test" if use_ipc_addr else "tcp://127.0.0.1:*"
340+
address = get_unique_ipc_addr() if use_ipc_addr else "tcp://127.0.0.1:*"
333341

334342
server.bind(address)
335343
server.start()
@@ -428,10 +436,10 @@ class App:
428436
def quick_task(self, task_id: int):
429437
return f"quick_task_{task_id}"
430438

431-
with RpcServerWrapper(App(),
432-
addr="ipc:///tmp/rpc_test_shutdown") as server:
439+
addr = get_unique_ipc_addr()
440+
with RpcServerWrapper(App(), addr=addr) as server:
433441
time.sleep(0.1)
434-
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
442+
with RPCClient(addr) as client:
435443
client.quick_task(1).remote()
436444

437445
# repeated shutdown should not raise an error
@@ -446,12 +454,13 @@ def foo(self, delay: int):
446454
time.sleep(delay)
447455
return "foo"
448456

457+
addr = get_unique_ipc_addr()
449458
server = RPCServer(App())
450-
server.bind("ipc:///tmp/rpc_test_shutdown")
459+
server.bind(addr)
451460
server.start()
452461

453462
time.sleep(0.1)
454-
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
463+
with RPCClient(addr) as client:
455464
# This task should be continued after server shutdown
456465
res = client.foo(10).remote_future(timeout=12)
457466

@@ -639,21 +648,20 @@ def nested_function():
639648
yield nested_function
640649

641650
def test_unpickleable_error(self):
642-
with RpcServerWrapper(
643-
self.App(), addr="ipc:///tmp/rpc_test_pickle_error") as server:
644-
with RPCClient("ipc:///tmp/rpc_test_pickle_error") as client:
651+
addr = get_unique_ipc_addr()
652+
with RpcServerWrapper(self.App(), addr=addr) as server:
653+
with RPCClient(addr) as client:
645654
with pytest.raises(RPCError) as exc_info:
646655
client.unpickleable_return().remote()
647656

648657
assert "Failed to pickle response" in str(exc_info.value)
649658

650659
@pytest.mark.asyncio
651660
async def test_unpickleable_streaming_error(self):
652-
with RpcServerWrapper(self.App(),
653-
addr="ipc:///tmp/rpc_test_pickle_error_streaming",
661+
addr = get_unique_ipc_addr()
662+
with RpcServerWrapper(self.App(), addr=addr,
654663
async_run_task=True) as server:
655-
with RPCClient(
656-
"ipc:///tmp/rpc_test_pickle_error_streaming") as client:
664+
with RPCClient(addr) as client:
657665
with pytest.raises(RPCStreamingError) as exc_info:
658666
async for _ in client.unpickleable_streaming_return(
659667
).remote_streaming():

tests/unittest/executor/test_rpc_proxy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def create_proxy(self, tp_size: int):
5555

5656
return proxy
5757

58-
@pytest.mark.skip(reason="https://nvbugs/5579234")
5958
@pytest.mark.parametrize("num_reqs", [1, 10])
6059
def test_tp1(self, num_reqs):
6160
tokenizer = TransformersTokenizer.from_pretrained(model_path)

tests/unittest/executor/test_rpc_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from tensorrt_llm.executor.request import GenerationRequest
1212
from tensorrt_llm.executor.rpc import RPCClient
13-
from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
13+
from tensorrt_llm.executor.rpc.rpc_common import get_unique_ipc_addr
1414
from tensorrt_llm.executor.rpc_worker import RpcWorker
1515
from tensorrt_llm.llmapi.mpi_session import MpiPoolSession
1616
from tensorrt_llm.sampling_params import SamplingParams
@@ -42,7 +42,7 @@ def teardown_method(self):
4242
self.client.close()
4343

4444
def create_worker_pool(self):
45-
addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr()
45+
addr = get_unique_ipc_addr()
4646
mp_context = multiprocessing.get_context(
4747
'spawn') # spawn for CUDA context
4848
pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context)
@@ -211,7 +211,7 @@ def teardown_method(self):
211211

212212
def create_worker_session(self):
213213
session = MpiPoolSession(n_workers=2)
214-
addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr()
214+
addr = get_unique_ipc_addr()
215215
futures = session.submit(RpcWorker.main_task,
216216
engine=model_path,
217217
rpc_addr=addr,

0 commit comments

Comments
 (0)