
Commit a61f1ca

add test for WorkerBase
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
1 parent: 64af73e

1 file changed: +72 -0 lines changed
@@ -0,0 +1,72 @@
import os
import sys
from queue import Queue

# isort: off
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from utils.llm_data import llm_models_root
from tensorrt_llm.bindings import executor as tllm
# isort: on

from tensorrt_llm._torch.pyexecutor.config import update_executor_config
from tensorrt_llm.executor.request import GenerationRequest
from tensorrt_llm.executor.worker_base import WorkerBase
from tensorrt_llm.llmapi.llm_args import LlmArgs
from tensorrt_llm.sampling_params import SamplingParams

default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
model_path = llm_models_root() / default_model_name


class TestWorkerBase:

    def test_create_engine(self):
        # Constructing the worker from the model path is the test itself;
        # the context manager guarantees the engine is shut down on exit.
        with WorkerBase(engine=model_path) as worker:
            pass

    def test_submit_request(self):
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        with WorkerBase(engine=model_path) as worker:
            worker.submit(request)

    def test_await_responses(self):
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        with WorkerBase(engine=model_path) as worker:
            # Responses are delivered through a caller-provided queue.
            result_queue = Queue()
            worker.set_result_queue(result_queue)

            worker.submit(request)
            # Poll repeatedly; completed responses land in result_queue.
            for _ in range(10):
                worker.await_responses()

            assert result_queue.qsize() > 0

    def _create_executor_config(self):
        # Helper (not yet exercised by the tests above) that builds a
        # minimal executor config for the pytorch backend.
        llm_args = LlmArgs(model=model_path, cuda_graph_config=None)

        executor_config = tllm.ExecutorConfig(1)
        executor_config.max_batch_size = 1

        update_executor_config(
            executor_config,
            backend="pytorch",
            pytorch_backend_config=llm_args.get_pytorch_backend_config(),
            mapping=llm_args.parallel_config.to_mapping(),
            speculative_config=llm_args.speculative_config,
            hf_model_dir=model_path,
            max_input_len=20,
            max_seq_len=40,
            checkpoint_format=llm_args.checkpoint_format,
            checkpoint_loader=llm_args.checkpoint_loader,
        )

        return executor_config


if __name__ == "__main__":
    test_worker_base = TestWorkerBase()
    test_worker_base.test_create_engine()
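
Beyond the __main__ smoke run above, the whole class can also be driven through pytest. A minimal sketch, assuming the module is saved as test_worker_base.py (the actual file name is not shown in this commit) and that llm_models_root() resolves to a cache containing the TinyLlama checkpoint:

import pytest

# Run every test in the class; the file name here is hypothetical.
pytest.main(["-q", "test_worker_base.py::TestWorkerBase"])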
