
Commit 71be18b

add tests defines

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Parent: 93ab23c

File tree: 8 files changed, +269 -235 lines

tensorrt_llm/llmapi/disagg_utils.py
Lines changed: 15 additions & 0 deletions

@@ -43,6 +43,21 @@ class ConditionalDisaggConfig():
     max_local_prefill_length: int = 0


+@dataclass
+class MinimalInstances:
+    context_servers: int = 1
+    generation_servers: int = 1
+
+
+@dataclass
+class DisaggClusterConfig:
+    cluster_uri: str
+    cluster_name: str = ""
+    minimal_instances: Optional[MinimalInstances] = None
+    heartbeat_interval: int = 5
+    inactive_timeout: int = 10
+
+
 @dataclass
 class DisaggServerConfig():
     server_configs: List[CtxGenServerConfig]
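
For orientation, a minimal sketch of how the new dataclass might be instantiated; the values below are illustrative, and only the fields visible in this diff are assumed:

from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
                                              MinimalInstances)

# Hypothetical values: heartbeat every 5 s, declare a worker dead after
# 10 s of silence, and require at least one context and one generation
# server before the cluster counts as ready.
config = DisaggClusterConfig(
    cluster_uri="http://localhost:18000",  # cluster-storage endpoint
    cluster_name="disagg-test",
    minimal_instances=MinimalInstances(context_servers=1,
                                       generation_servers=1),
    heartbeat_interval=5,
    inactive_timeout=10,
)

Note that inactive_timeout is deliberately larger than heartbeat_interval, so a single missed heartbeat does not immediately mark a worker inactive.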

tensorrt_llm/serve/auto_scaling.py
Lines changed: 10 additions & 7 deletions

@@ -105,9 +105,6 @@ async def get_worker_events(
         worker_events = []
         for event in events:
             try:
-                print(
-                    f"Processing event: {event.event_type} for key: {event.storage_item.key} value {event.storage_item.value}"
-                )
                 worker_info = self._parse_worker_info(event)
                 worker_events.append((worker_info, event.event_type))
             except Exception as e:

@@ -192,6 +189,11 @@ def __init__(self, role: ServerRole, host: str, port: int,
         self._last_heartbeat = 0
         self._worker_id = f"{role.name}-{host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"

+    def __del__(self):
+        if asyncio.get_event_loop():
+            asyncio.run_coroutine_threadsafe(self.deregister_worker(),
+                                             asyncio.get_event_loop())
+
     @property
     def worker_id(self) -> str:
         return self._worker_id

@@ -230,8 +232,8 @@ async def register_worker(self, validator=None, retry_interval=5):
             logger.warning(
                 f"Worker {self.worker_info.worker_id} registration failed, retry in {retry_interval} seconds"
             )
-            await asyncio.sleep(max(10, retry_interval))
-            return await self.register_worker(validator, retry_interval + 1)
+            await asyncio.sleep(retry_interval)
+            return await self.register_worker(validator, retry_interval)
         else:
             logger.info(
                 f"Worker {self.worker_info.worker_id} registration successful")

@@ -248,8 +250,9 @@ async def register_worker(self, validator=None, retry_interval=5):

     async def deregister_worker(self):
         self._stop = True
-        self._heartbeat_task.cancel()
-        self._heartbeat_task = None
+        if self._heartbeat_task:
+            self._heartbeat_task.cancel()
+            self._heartbeat_task = None
         await self._cluster_storage.stop()
         success = await self._cluster_storage.delete(self.worker_key)
         if not success:
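
Two behavioral notes on this file: deregister_worker now tolerates a missing heartbeat task, and the registration retry drops its growing backoff (previously a sleep of at least 10 s with the interval incremented per attempt) in favor of a fixed interval. Since register_worker retries by awaiting itself, a very long outage still deepens the recursion; an equivalent iterative formulation, shown here as a standalone sketch with hypothetical names rather than the module's API, avoids that:

import asyncio

async def retry_register(register_once, retry_interval=5):
    # register_once stands in for a single registration attempt (a
    # zero-argument coroutine returning True on success); the fixed
    # sleep mirrors the updated register_worker behavior.
    while not await register_once():
        await asyncio.sleep(retry_interval)
    return True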

tensorrt_llm/serve/cluster_storage.py
Lines changed: 1 addition & 7 deletions

@@ -99,18 +99,12 @@ async def get_prefix(self,
 def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
     if cluster_uri.startswith("http"):
         return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs)
-    elif cluster_uri.startswith("etcd"):
-        from tensorrt_llm.serve.cluster_storage_etcd import Etcd3ClusterStorage
-        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
     raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")


 def create_cluster_storage_client(cluster_uri, cluster_name):
     if cluster_uri.startswith("http"):
         return HttpClusterStorageClient(cluster_uri, cluster_name)
-    elif cluster_uri.startswith("etcd"):
-        from tensorrt_llm.serve.cluster_storage_etcd import Etcd3ClusterStorage
-        return Etcd3ClusterStorage(cluster_uri, cluster_name)
     raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")


@@ -356,7 +350,7 @@ async def get_prefix(self,
                          keys_only: bool = False) -> Dict[str, str]:
         return await self._get("get_prefix",
                                key_prefix=key_prefix,
-                               keys_only=keys_only)
+                               keys_only=int(keys_only))

     async def delete(self, key: str) -> bool:
         try:
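
The int(keys_only) coercion is about transport: the flag travels as an HTTP query parameter, and booleans do not survive a URL round-trip cleanly; str(True) is "True", and some client stacks reject bool query values outright. A small illustration of the convention (parameter names echo the call above; the request plumbing is hypothetical):

# Client side: encode the flag as 0/1 so the query string is unambiguous.
params = {"key_prefix": "worker/", "keys_only": int(True)}  # -> keys_only=1

# Server side: query values arrive as strings. Decode via int(), since
# bool() of any non-empty string -- including "False" -- would be truthy.
keys_only = bool(int(params_received := "1"))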

tests/integration/test_lists/test-db/l0_h100.yml
Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,8 @@ l0_h100:
 - unittest/disaggregated/test_disagg_utils.py
 - unittest/disaggregated/test_router.py
 - unittest/disaggregated/test_remoteDictionary.py
+- unittest/disaggregated/test_cluster_manager_worker.py
+- unittest/disaggregated/test_cluster_storage.py
 - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
 - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa
 - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_chunked_prefill
tests/unittest/disaggregated/test_cluster_manager_worker.py (new file)
Lines changed: 227 additions & 0 deletions

@@ -0,0 +1,227 @@
+import asyncio
+import subprocess
+import tempfile
+import time
+
+import pytest
+
+from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
+                                              MinimalInstances, ServerRole)
+from tensorrt_llm.serve.auto_scaling import ClusterManager, ClusterWorker
+from tensorrt_llm.serve.cluster_storage import (WatchEventType,
+                                                create_cluster_storage,
+                                                create_cluster_storage_client)
+
+from .test_cluster_storage import http_server_storage, pytest_async_fixture
+
+INACTIVE_TIMEOUT = 4
+HEARTBEAT_INTERVAL = 2
+
+storage_types = ["http"]
+
+
+def get_uri(storage_type):
+    if storage_type == "http":
+        return "http://localhost:18000"
+    elif storage_type == "etcd":
+        return "etcd://localhost:2379"
+    else:
+        raise ValueError(f"Invalid storage type: {storage_type}")
+
+
+@pytest.fixture(scope="module")
+def config(request):
+    cluster_uri = get_uri(request.param)
+    return DisaggClusterConfig(cluster_uri=cluster_uri,
+                               cluster_name="test",
+                               minimal_instances=MinimalInstances(
+                                   context_servers=1, generation_servers=1),
+                               inactive_timeout=INACTIVE_TIMEOUT,
+                               heartbeat_interval=HEARTBEAT_INTERVAL)
+
+
+@pytest.fixture(scope="module")
+def storage_server(config):
+    if config.cluster_uri.startswith("http"):
+        port = 18000
+        server, cluster_storage = http_server_storage(port)
+        with server.run_in_thread():
+            yield cluster_storage, config.cluster_uri
+    elif config.cluster_uri.startswith("etcd"):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            etcd = subprocess.Popen(
+                ["etcd", "--data-dir", temp_dir, "--log-level", "debug"])
+            time.sleep(2)  # wait for etcd to start
+            yield create_cluster_storage(
+                config.cluster_uri, config.cluster_name), config.cluster_uri
+            etcd.kill()
+            etcd.wait()
+    else:
+        raise ValueError(f"Invalid cluster storage URI: {config.cluster_uri}")
+
+
+@pytest_async_fixture(scope="module")
+async def storage_client(storage_server):
+    _, cluster_uri = storage_server
+    return create_cluster_storage_client(cluster_uri, "test")
+
+
+@pytest_async_fixture(scope="module")
+async def cluster_manager(config, storage_server):
+    storage, cluster_uri = storage_server
+    manager = ClusterManager(config, storage)
+    await manager.start()
+    yield manager
+    await manager.stop()
+
+
+@pytest.mark.parametrize("config", storage_types, indirect=True)
+@pytest.mark.threadleak(enabled=False)
+@pytest.mark.asyncio(scope="module")
+async def test_init_workers_first(config, storage_server):
+    try:
+        # Init workers before initializing the manager, so the manager
+        # should be able to get the pre-registered workers.
+        server, storage_uri = storage_server
+        storage_client = create_cluster_storage_client(storage_uri, "test")
+        ctx_worker = ClusterWorker(ServerRole.CONTEXT, "127.0.0.1", 8001,
+                                   config, storage_client)
+        gen_worker = ClusterWorker(ServerRole.GENERATION, "127.0.0.1", 8002,
+                                   config, storage_client)
+        await ctx_worker.register_worker()
+        await gen_worker.register_worker()
+
+        cluster_manager = ClusterManager(config, server)
+        await cluster_manager.start()
+        existing_workers = await cluster_manager.watch_workers(
+            get_existing_first=True)
+        assert set([worker.worker_id for worker in existing_workers]) == {
+            ctx_worker.worker_id,
+            gen_worker.worker_id,
+        }
+
+        assert await cluster_manager.is_ready() == True
+    finally:
+        await ctx_worker.deregister_worker()
+        await gen_worker.deregister_worker()
+
+
+@pytest.mark.parametrize("config", storage_types, indirect=True)
+@pytest.mark.threadleak(enabled=False)
+@pytest.mark.timeout(20)
+@pytest.mark.asyncio(scope="module")
+async def test_cluster_manager(cluster_manager, storage_client, config):
+    try:
+        assert cluster_manager.current_ctx_worker_num == 0
+        assert cluster_manager.current_gen_worker_num == 0
+        await cluster_manager.watch_workers()
+        try:
+            await asyncio.wait_for(cluster_manager.get_worker_events(),
+                                   timeout=1)
+        except asyncio.TimeoutError:
+            pass
+        assert await cluster_manager.is_ready() == False
+
+        ctx_worker = ClusterWorker(ServerRole.CONTEXT, "127.0.0.1", 8001,
+                                   config, storage_client)
+        await cluster_manager.watch_workers()
+        await ctx_worker.register_worker()
+        worker_events = await cluster_manager.get_worker_events()
+        assert worker_events == [(ctx_worker.worker_info, WatchEventType.SET)]
+        assert cluster_manager.current_ctx_worker_num == 1
+        assert cluster_manager.current_gen_worker_num == 0
+        assert await cluster_manager.is_ready() == False
+
+        gen_worker = ClusterWorker(ServerRole.GENERATION, "127.0.0.1", 8002,
+                                   config, storage_client)
+        await gen_worker.register_worker()
+        worker_events = await cluster_manager.get_worker_events()
+        assert worker_events == [(gen_worker.worker_info, WatchEventType.SET)]
+        assert cluster_manager.current_ctx_worker_num == 1
+        assert cluster_manager.current_gen_worker_num == 1
+        assert await cluster_manager.is_ready() == True
+
+        await ctx_worker.deregister_worker()
+        worker_events = await cluster_manager.get_worker_events()
+        assert worker_events == [(ctx_worker.worker_info,
+                                  WatchEventType.DELETE)]
+        assert cluster_manager.current_ctx_worker_num == 0
+        assert cluster_manager.current_gen_worker_num == 1
+        assert await cluster_manager.is_ready() == False
+
+        await gen_worker.deregister_worker()
+        worker_events = await cluster_manager.get_worker_events()
+        assert worker_events == [(gen_worker.worker_info,
+                                  WatchEventType.DELETE)]
+        assert cluster_manager.current_ctx_worker_num == 0
+        assert cluster_manager.current_gen_worker_num == 0
+        assert await cluster_manager.is_ready() == False
+    finally:
+        await ctx_worker.deregister_worker()
+        await gen_worker.deregister_worker()
+
+
+@pytest.mark.timeout(20)
+@pytest.mark.parametrize("config", storage_types, indirect=True)
+@pytest.mark.threadleak(enabled=False)
+@pytest.mark.asyncio(scope="module")
+async def test_cluster_worker(cluster_manager, storage_client, config):
+
+    async def wait_for_worker_events(expected_new_event_num,
+                                     expected_dead_event_num):
+        new_worker_ids = []
+        dead_workers_ids = []
+        while len(new_worker_ids) < expected_new_event_num or len(
+                dead_workers_ids) < expected_dead_event_num:
+            try:
+                worker_events = await asyncio.wait_for(
+                    cluster_manager.get_worker_events(), timeout=2)
+                new_workers = [
+                    worker_info.worker_id
+                    for worker_info, event_type in worker_events
+                    if event_type == WatchEventType.SET
+                ]
+                dead_workers = [
+                    worker_info.worker_id
+                    for worker_info, event_type in worker_events
+                    if event_type == WatchEventType.DELETE
+                ]
+                print(f"Worker events: {worker_events} {time.time()}")
+                new_worker_ids += new_workers
+                dead_workers_ids += dead_workers
+            except asyncio.TimeoutError:
+                pass
+        return new_worker_ids, dead_workers_ids
+
+    try:
+        await cluster_manager.start()
+        await cluster_manager.watch_workers()
+        ctx_worker = ClusterWorker(ServerRole.CONTEXT, "127.0.0.1", 8001,
+                                   config, storage_client)
+        gen_worker = ClusterWorker(ServerRole.GENERATION, "127.0.0.1", 8002,
+                                   config, storage_client)
+
+        keep_heartbeat = True
+        assert await ctx_worker.register_worker(
+            validator=lambda: keep_heartbeat)
+        assert await gen_worker.register_worker(
+            validator=lambda: keep_heartbeat)
+        worker_ids = set([ctx_worker.worker_id, gen_worker.worker_id])
+        new_worker_ids, dead_workers_ids = await wait_for_worker_events(2, 0)
+        assert set(new_worker_ids) == worker_ids
+        assert len(dead_workers_ids) == 0
+        assert await cluster_manager.is_ready() == True
+
+        await asyncio.sleep(config.inactive_timeout + 1)
+        assert await cluster_manager.is_ready() == True
+
+        # stop heartbeat, then we should see two workers deleted
+        keep_heartbeat = False
+        new_worker_ids, dead_workers_ids = await wait_for_worker_events(0, 2)
+        assert len(new_worker_ids) == 0
+        assert len(dead_workers_ids) == 2
+        assert set(dead_workers_ids) == worker_ids
+        assert await cluster_manager.is_ready() == False
+    finally:
+        await ctx_worker.deregister_worker()
+        await gen_worker.deregister_worker()
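
Taken together, the tests trace out the intended worker lifecycle: register, heartbeat while a validator callback stays truthy, then deregister (or be aged out after inactive_timeout). A condensed usage sketch, using only names exercised in the tests above; the endpoint value and the sleep standing in for real serving work are hypothetical:

import asyncio

from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
                                              MinimalInstances, ServerRole)
from tensorrt_llm.serve.auto_scaling import ClusterWorker
from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client


async def main():
    config = DisaggClusterConfig(cluster_uri="http://localhost:18000",
                                 cluster_name="test",
                                 minimal_instances=MinimalInstances(1, 1))
    client = create_cluster_storage_client(config.cluster_uri,
                                           config.cluster_name)
    worker = ClusterWorker(ServerRole.CONTEXT, "127.0.0.1", 8001, config,
                           client)
    # The validator is consulted while heartbeating; once it returns
    # False, heartbeats stop and the manager eventually reports the
    # worker as dead (see test_cluster_worker above).
    healthy = True
    await worker.register_worker(validator=lambda: healthy)
    try:
        await asyncio.sleep(60)  # stand-in for real serving work
    finally:
        await worker.deregister_worker()


asyncio.run(main())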

tests/unittest/serve/test_cluster_storage.py renamed to tests/unittest/disaggregated/test_cluster_storage.py
Lines changed: 14 additions & 10 deletions

@@ -96,19 +96,22 @@ async def test_expire(self, storage_server, storage_client):
     @timeout(5)
     @pytest_async_module
-    async def test_get_prefix(self, storage_server, storage_client):
+    async def test_get_keys(self, storage_server, storage_client):
         keys = [gen_key("test_key_unique") for _ in range(3)]
-        for key in keys:
+        values = [f"test_value{i}" for i in range(3)]
+        for key, value in zip(keys, values):
             assert await storage_client.set(key,
-                                            "test_value1",
+                                            value,
                                             overwrite_if_exists=True)

-        answer_keys = await storage_client.get_prefix("test_key_unique")
-        assert set(keys) == set(answer_keys)
-        answer_keys = await storage_client.get_prefix(keys[0])
-        assert answer_keys == [keys[0]]
-        answer_keys = await storage_client.get_prefix(keys[1])
-        assert answer_keys == [keys[1]]
+        answer_keys = await storage_client.get_prefix("test_key_unique",
+                                                      keys_only=False)
+        assert set(keys) == set(answer_keys.keys())
+        assert set(values) == set(answer_keys.values())
+        answer_keys = await storage_client.get_prefix(keys[0], keys_only=True)
+        assert answer_keys == {keys[0]: ""}
+        answer_keys = await storage_client.get_prefix(keys[1], keys_only=True)
+        assert answer_keys == {keys[1]: ""}

     @pytest_ignore_tleak
     @pytest_async_module

@@ -199,7 +202,8 @@ def storage_server(self):


 class TestEtcdClusterStorage(TestClusterStorage):
-    __test__ = True
+    # Disable this test until Etcd functionality is ready.
+    __test__ = False

     @pytest.fixture(scope="class")
     def storage_server(self):
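
The renamed test pins down the revised get_prefix contract: it now returns a Dict[str, str] rather than a list of keys, and keys_only=True preserves the keys while blanking the values. Schematically, as a fragment to run inside an async test; the "worker/" prefix is illustrative:

# Full mapping: prefix-matched keys with their stored values.
entries = await storage_client.get_prefix("worker/", keys_only=False)
# e.g. {"worker/ctx-1": "<info>", "worker/gen-1": "<info>"}

# Keys only: same keys, empty-string values.
names = await storage_client.get_prefix("worker/", keys_only=True)
# e.g. {"worker/ctx-1": "", "worker/gen-1": ""}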

tests/unittest/serve/__init__.py
Whitespace-only changes.
