base: main
Try rpc replace ray #9431
Conversation
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

Commits:
- unwaive rpc tests
- simplify RPCServer shutdown: remove pending-request processing, shut down immediately
- fix streaming cancelled
- share event_loop between proxy and client
- refactor RpcClient by unifying event_loop; simplify
- refactor RPCServer by simplifying it
- add correctness tests
- fix worker
- refactor test_rpc_worker: focus on testing the RpcWorker APIs
- fix test_rpc_proxy.py: restore RPCClient with a dedicated background thread; the test_rpc_proxy.py tp1[1] passed
- add threaded remote_call test
- add more debugging print
- dedicated thread for fetch_responses; random hang when submit failed
- cleanup test_rpc.py
- fix race condition in zmq socket: the socket was used from event loops in two threads; unify the sending in the rpc_client's main loop thread
- add ipc TLLM_LLMAPI_ZMQ_DEBUG
- fix wait_for lost message; test passed, the race condition is resolved completely
- refine the pr
- add test_ipc.py
- fix tests
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
4a944bf to 0c1d4d3 (compare)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
0c1d4d3 to d4ddb84 (compare)
📝 Walkthrough

Major refactoring consolidates RPC-based execution paths while removing Ray-queue abstractions. Introduces RpcTorchDistExecutor for distributed torch workers with RPC coordination. Enhances RPC infrastructure with improved async/sync lifecycle management, thread safety in ZeroMQ IPC, and dedicated mixin classes for RPC executors and workers.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Main as Main Process
participant Executor as RpcTorchDistExecutor
participant Worker0 as Rank 0 Worker
participant WorkerN as Rank N Worker
participant ControlGroup as Gloo Control Group
Main->>Executor: create()
Executor->>Executor: init_rpc_executor()
Executor->>Executor: start_workers(world_size)
loop Each Worker
Executor->>Worker0: spawn RpcTorchDistWorker (rank=0)
Executor->>WorkerN: spawn RpcTorchDistWorker (rank=N)
end
Worker0->>Worker0: init_rpc_worker(rpc_addr)
Worker0->>Worker0: start_rpc_server()
WorkerN->>WorkerN: init_rpc_worker(rpc_addr)
Worker0->>ControlGroup: broadcast "setup_engine"
WorkerN->>ControlGroup: receive setup_engine
Worker0->>Worker0: setup_engine()
WorkerN->>WorkerN: setup_engine()
ControlGroup->>ControlGroup: barrier sync
Main->>Executor: submit(request)
Executor->>Executor: submit via RPC to rank 0
Worker0->>Worker0: execute request
Main->>Executor: shutdown()
Executor->>ControlGroup: broadcast "shutdown"
WorkerN->>ControlGroup: receive shutdown
Worker0->>Worker0: cleanup
WorkerN->>WorkerN: cleanup
Executor->>Executor: cleanup
```
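To make the control-group flow above concrete, here is a minimal, hypothetical sketch of the receive side of that handshake. The command strings and the `worker`/`control_group` objects are assumptions for illustration only; the PR's actual wiring lives in `rpc_torch_dist_worker.py`.

```python
# Hypothetical sketch of the broadcast-driven control loop shown in the diagram.
# Command names and the `worker`/`control_group` handles are assumptions.
import torch.distributed as dist


def control_loop(rank: int, worker, control_group) -> None:
    """Non-rank-0 workers block on broadcasts from rank 0 and act on them."""
    while True:
        cmd_list = [None]
        # Rank 0 fills cmd_list before broadcasting; other ranks receive it here.
        dist.broadcast_object_list(cmd_list, src=0, group=control_group)
        cmd = cmd_list[0]
        if cmd == "setup_engine":
            worker.setup_engine()
            dist.barrier(group=control_group)  # sync before serving requests
        elif cmd == "shutdown":
            worker.shutdown()
            break
```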
```mermaid
sequenceDiagram
participant Client as RPCClient
participant EventLoop as Event Loop Thread
participant Server as RPCServer
participant Socket as ZMQ Socket
Client->>Client: __init__
Client->>Client: setup_lazily()
Client->>Socket: connect()
Client->>EventLoop: start event loop
Client->>EventLoop: start_response_reader_eagerly()
activate EventLoop
EventLoop->>Socket: poll for responses
rect rgb(200, 230, 255)
Note over EventLoop: Response Reader Loop (async)
loop until shutdown
EventLoop->>Socket: receive response
EventLoop->>Client: update result queue
end
end
Client->>Client: submit(request)
Note over Client: Non-blocking submit via RPC
Server->>Socket: receive request
Server->>Server: process_request()
Server->>Socket: send response
EventLoop->>Client: route response to result
Client->>Client: shutdown()
EventLoop->>EventLoop: cancel tasks
deactivate EventLoop
Client->>Socket: close()
```
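The client-side flow can be pictured as a single I/O thread that owns the ZMQ socket and routes responses to per-request queues. The class and attribute names below are assumptions for illustration, not the PR's actual `RPCClient`:

```python
# A simplified, assumed client design: one background I/O thread owns the ZMQ
# socket (sends and receives) and routes responses to per-request result queues.
import queue
import threading
import uuid
from typing import Any, Dict

import zmq


class TinyRpcClient:

    def __init__(self, address: str):
        self._address = address
        self._requests = queue.Queue()          # main thread -> I/O thread
        self._pending: Dict[str, queue.Queue] = {}  # request_id -> result queue
        self._lock = threading.Lock()
        self._io_thread = threading.Thread(target=self._io_loop, daemon=True)
        self._io_thread.start()

    def submit(self, payload: Any) -> queue.Queue:
        """Queue a request and return the queue its response will arrive on."""
        request_id = uuid.uuid4().hex
        result_queue = queue.Queue()
        with self._lock:
            self._pending[request_id] = result_queue
        self._requests.put((request_id, payload))
        return result_queue

    def _io_loop(self) -> None:
        # The socket is created and used only on this thread, so there is no
        # cross-thread ZMQ access (the race the commit messages describe fixing).
        sock = zmq.Context.instance().socket(zmq.DEALER)
        sock.connect(self._address)
        poller = zmq.Poller()
        poller.register(sock, zmq.POLLIN)
        while True:
            # Drain outbound requests first.
            try:
                while True:
                    sock.send_pyobj(self._requests.get_nowait())
            except queue.Empty:
                pass
            # Then poll briefly for responses and route them by request id.
            if dict(poller.poll(timeout=10)).get(sock) == zmq.POLLIN:
                request_id, response = sock.recv_pyobj()
                with self._lock:
                    result_queue = self._pending.pop(request_id, None)
                if result_queue is not None:
                    result_queue.put(response)
```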
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Areas requiring extra attention:
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/llmapi/utils.py (1)
1-1: Add required NVIDIA copyright header at top of file

Per the TensorRT-LLM guidelines, all `.py` sources should start with an NVIDIA copyright header including the current year. This file currently starts directly with imports; please prepend the standard project header used elsewhere in the repo.

tests/unittest/executor/test_rpc.py (1)

1-10: Missing NVIDIA copyright header.

As per coding guidelines, test files should also contain the NVIDIA copyright header.
🧹 Nitpick comments (19)
tests/unittest/llmapi/test_llm_pytorch.py (1)
933-934: `num_requests` is unused in `test_llm_rpc` (Ruff ARG001).

The new parametrization currently has no effect on behavior: the test body ignores `num_requests`, so it just runs the same single-request check three times and triggers an unused-argument warning.

Consider either:

- Using `num_requests` to actually vary the test (e.g., loop `llm.generate(...)` `num_requests` times and assert each result), or
- If the goal is merely to re-run the same scenario multiple times, dropping the parametrization or renaming the argument to `_num_requests` and updating the decorator accordingly to silence the lint.

For example, to keep the parametrization but use it:

```diff
-@skip_ray
-@pytest.mark.parametrize("num_requests", [1, 5, 10])
-def test_llm_rpc(num_requests: int):
+@skip_ray
+@pytest.mark.parametrize("num_requests", [1, 5, 10])
+def test_llm_rpc(num_requests: int):
@@
-    res = llm.generate("Tell me a joke",
-                       sampling_params=SamplingParams(max_tokens=10,
-                                                      end_id=-1))
-    print(f"get result: {res}")
-
-    assert len(res.outputs) == 1
-    assert len(res.outputs[0].token_ids) == 10
+    for _ in range(num_requests):
+        res = llm.generate("Tell me a joke",
+                           sampling_params=SamplingParams(
+                               max_tokens=10, end_id=-1))
+        print(f"get result: {res}")
+
+        assert len(res.outputs) == 1
+        assert len(res.outputs[0].token_ids) == 10
```

tensorrt_llm/llmapi/utils.py (1)
cyanentry is consistent with the existing ANSI palette and integrates cleanly withprint_colored. If there’s any chancecolorcomes from unvalidated input, you may want to guard against unknown keys (to avoidKeyError) by usingcolors.get(color, "")or an explicit check before indexing; otherwise this is good as‑is.tensorrt_llm/executor/ipc.py (1)
246-286: Useraise ... from Nonefor the chained TimeoutError.When re-raising
asyncio.TimeoutErrorafter catchingzmq.Again, usefrom Noneto suppress the exception chain and clarify that this is an intentional conversion, not an error during exception handling.Apply this diff:
except zmq.Again: # No message available yet if asyncio.get_event_loop().time() >= deadline: - raise asyncio.TimeoutError() + raise asyncio.TimeoutError() from None # Short sleep to avoid busy-waiting await asyncio.sleep(0.01)tensorrt_llm/executor/rpc_torch_dist_worker.py (1)
119-142: Fix unused variable and improve exception handling.

The `args` variable is unpacked but never used. Also, when re-raising the exception at line 146, use bare `raise` to preserve the original traceback.

Apply this diff:

```diff
-            cmd, args = cmd_list
+            cmd, _args = cmd_list
             # logger.debug(f"Rank {rank} received command: {cmd}")
             if cmd == "setup_engine":
                 worker.setup_engine()
             elif cmd == "shutdown":
                 worker.shutdown()
                 break
             elif cmd == "report_device_id":
                 # Optional: handle other commands if needed
                 pass
             else:
                 logger.warning(f"Rank {rank} received unknown command: {cmd}")
     except Exception as e:
         logger.error(f"Worker {rank} failed with error: {e}")
-        raise e
+        raise
```
tensorrt_llm/executor/rpc/rpc_server.py (2)

256-260: Rename unused loop variable.

The loop variable `i` is not used within the loop body. Rename it to `_` to indicate it's intentionally unused.

Apply this diff:

```diff
         # Create worker tasks
-        for i in range(self._num_workers):
+        for _ in range(self._num_workers):
             task = asyncio.create_task(self._process_requests())
             self._worker_tasks.append(task)
```
727-728: Consider removing redundant import.

The `traceback` module is already imported at the top of the file (line 6). The local import inside `run_loop` is redundant.

Remove the redundant import:

```diff
                 else:
                     # This is an unexpected RuntimeError - log full details
-                    import traceback
                     logger.error(f"Event loop error: {error_str}")
                     logger.error(f"Traceback: {traceback.format_exc()}")
```

tensorrt_llm/executor/ray_gpu_worker.py (1)
170-195: Make `rpc_addr` non-optional in the API and clarify the error message

`rpc_addr` is now required (you raise on `None` and always start the RPC server), but the signature still marks it as `Optional[str] = None`, and the error text talks about "RPC mode enabled" even though RPC is unconditional.

Consider tightening and simplifying:

```diff
-        rpc_addr: Optional[str] = None,
+        rpc_addr: str,
@@
-        if rpc_addr is None:
-            raise RuntimeError(
-                "RPC mode enabled but no rpc_addr provided to RayGPUWorker")
-        self.init_rpc_worker(self.global_rank, rpc_addr)
+        if not rpc_addr:
+            raise RuntimeError("rpc_addr must be provided to RayGPUWorker")
+        self.init_rpc_worker(self.global_rank, rpc_addr)
```

This better reflects the new requirements and avoids a dead `None` default.

tests/unittest/executor/test_rpc_proxy.py (1)

25-47: Minor mismatch between helper usage and `worker_kwargs`

`create_fake_executor_config` returns `(llm_args, executor_config)`, but only `llm_args` is used and `worker_kwargs["executor_config"]` is hardcoded to `None` while the comment says "Create executor config with the correct tp_size".

Not a functional problem, but for clarity you could either:

- actually pass the returned `executor_config` (if/when needed), or
- drop `executor_config` from the helper's return and from `worker_kwargs` here if it's intentionally unused.

tensorrt_llm/_torch/distributed/communicator.py (1)

479-505: Avoid silently swallowing all exceptions in Ray detection

In `_get_cluster_info`, the Ray check is:

```python
is_ray_initialized = False
try:
    if ray.is_initialized():
        is_ray_initialized = True
except Exception:
    pass
```

Catching and ignoring all `Exception` types can mask real Ray/environment bugs and makes debugging harder.

Consider narrowing and/or logging, e.g.:

```diff
-        is_ray_initialized = False
-        try:
-            if ray.is_initialized():
-                is_ray_initialized = True
-        except Exception:
-            pass
+        is_ray_initialized = False
+        try:
+            if hasattr(ray, "is_initialized") and ray.is_initialized():
+                is_ray_initialized = True
+        except (RuntimeError, ValueError):
+            # Ray present but not usable; fall back to non-Ray path
+            logger.debug("Ray detected but not initialized; using non-Ray cluster info")
```

This still keeps the safe fallback while avoiding a blanket `except Exception: pass`.
except Exception: pass.tests/unittest/llmapi/test_rpc_torch_dist.py (1)
24-27: Minor cleanup: unused temp dir and skip decorator

- `self.model_dir` is created in `setUp` and removed in `tearDown` but never used to store anything; you can drop it unless you plan to use it for temporary artifacts.
- `self.skip_if_no_gpu` is a decorator object that is never applied; GPU gating is already correctly handled via the `@unittest.skipIf` decorators on the test methods.

These are minor, but cleaning them up will make the test intent clearer.
Also applies to: 55-58
tests/unittest/executor/test_rpc_worker.py (1)
43-56: Tests use fixed sleep which may cause flakiness.

Using `time.sleep(0.5)` for synchronization is fragile. Consider polling with a timeout instead to make tests more reliable across different environments. Additionally, the assertion `isinstance(responses, list)` is quite weak—consider verifying the response contains actual data related to the submitted request.

Consider a polling pattern:

```diff
-        # Sleep a bit to let the request start processing
-        time.sleep(0.5)
-
-        # Fetch responses with a timeout to prevent hanging
-        responses = asyncio.run(self.worker.fetch_responses_async(timeout=1.0))
-        assert isinstance(responses, list)
+        # Poll for responses with retry
+        max_attempts = 10
+        responses = []
+        for _ in range(max_attempts):
+            responses = asyncio.run(self.worker.fetch_responses_async(timeout=0.5))
+            if responses:
+                break
+            time.sleep(0.1)
+        assert isinstance(responses, list)
```

tensorrt_llm/executor/rpc/rpc_client.py (1)
140-152: Initialization uses fixed sleeps for synchronization.

The eager initialization at lines 144-152 calls `setup_lazily()` and `_start_response_reader_eagerly()`. Within `_start_response_reader_eagerly()` (line 438), there's a `time.sleep(0.2)`, and `_ensure_event_loop()` (line 535) also has a `time.sleep(0.2)`. While these help avoid race conditions, fixed sleeps may be insufficient on heavily loaded systems.

Consider using an event or condition to signal when components are ready, rather than fixed delays.
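A minimal sketch of the event-based alternative, with assumed names (`_reader_ready`, `_response_reader_main`) rather than the real `RPCClient` internals:

```python
# Hedged sketch: signal readiness with an Event instead of sleeping a fixed
# 0.2s. The class and method names here are illustrative assumptions.
import threading


class ReaderHandle:

    def __init__(self):
        self._reader_ready = threading.Event()
        self._thread = threading.Thread(target=self._response_reader_main,
                                        daemon=True)

    def start(self, timeout: float = 5.0) -> None:
        self._thread.start()
        # Block until the reader thread reports it is polling, or fail loudly
        # instead of silently proceeding on a heavily loaded machine.
        if not self._reader_ready.wait(timeout):
            raise RuntimeError("response reader did not become ready in time")

    def _response_reader_main(self) -> None:
        # ... create sockets / event loop here ...
        self._reader_ready.set()  # signal readiness once setup is done
        # ... then enter the receive loop ...
```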
tensorrt_llm/executor/rpc_proxy.py (1)
14-36: INSTANCE_COUNTER increment is not thread-safe.

The `INSTANCE_COUNTER += 1` operation on a class variable is not atomic in Python. If multiple threads create instances simultaneously, you could get incorrect counts.

If accurate counting is needed, consider using a lock or `itertools.count()`:

```diff
+import itertools
+
 class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
-    # NOTE: this is a global counter for the number of instances of this class
-    INSTANCE_COUNTER = 0
+    # NOTE: this is a global counter for the number of instances of this class
+    _instance_counter = itertools.count()

     def __init__(
         ...
     ):
-        GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1
+        self._instance_id = next(GenerationExecutorRpcProxy._instance_counter)
```

tests/unittest/executor/test_rpc.py (1)
884-965: Avoid daemon threads in test code.

Using `daemon=True` (line 939) for threads in tests can cause issues if the test fails or exits early—daemon threads are abruptly terminated without cleanup. Since you're already calling `thread.join(timeout=30)`, the daemon flag is unnecessary.

```diff
     for i in range(num_threads):
         thread = threading.Thread(target=remote_caller,
-                                  args=(i, ),
-                                  daemon=True)
+                                  args=(i, ))
         threads.append(thread)
```

tensorrt_llm/executor/rpc_worker_mixin.py (2)
101-107: Unused `timeout` parameter.

The `timeout` parameter in both `fetch_stats_async` and `fetch_kv_cache_events_async` is declared but never used. Either remove it or pass it to the underlying calls.

```diff
-    async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
+    async def fetch_stats_async(self) -> list:
         """Async version of fetch_stats using asyncio.to_thread."""
         return await asyncio.to_thread(self.fetch_stats)

-    async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list:
+    async def fetch_kv_cache_events_async(self) -> list:
         """Async version of fetch_kv_cache_events using asyncio.to_thread."""
         return await asyncio.to_thread(self.fetch_kv_cache_events)
```
133-151: Consider sleep after fetch instead of before.

The current pattern sleeps before fetching, which adds latency on the first iteration. Consider moving the sleep after processing:

```diff
     async def _generic_fetch_loop_async(
             self,
             fetch_method,
             serializer,
             method_name: str,
             timeout: Optional[float] = None
     ) -> AsyncGenerator[list, None]:
+        timeout = timeout or 0.1
         while not self.shutdown_event.is_set():
-            timeout = timeout or 0.1
-            await asyncio.sleep(timeout)
             data = await fetch_method()
             yield [serializer(item) for item in data]
+            await asyncio.sleep(timeout)
```

tensorrt_llm/executor/rpc_torch_dist_executor.py (2)
43-59: Minor formatting issue in log message and use bare `raise`.

- Missing space between worker count and "Master"
- Use bare `raise` instead of `raise e` to preserve the full traceback

```diff
         logger.info(
             f"RpcTorchDistExecutor starting with {model_world_size} workers."
-            f"Master: {self.master_addr}:{self.master_port}"
+            f" Master: {self.master_addr}:{self.master_port}"
         )
 ...
         except Exception as e:
             logger.error(f"Failed to setup remote engine: {e}")
             self.shutdown()
-            raise e
+            raise
```
117-150: Consider documenting intentionally unused parameters.

Several parameters (`mpi_session`, `reuse_mpi_comm`, `return_logits`, `kwargs`) are accepted but not used. If these are for API compatibility with other executor `create` methods, consider adding a comment or using an explicit `_` prefix to indicate they're intentionally ignored.

```diff
     @classmethod
     def create(
         cls,
         engine: Union[Path, Engine],
         executor_config: Optional[tllm.ExecutorConfig] = None,
         batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
         model_world_size: int = 1,
-        mpi_session: Optional[Any] = None,
-        reuse_mpi_comm: bool = False,
-        return_logits: bool = False,
+        mpi_session: Optional[Any] = None,  # Unused, for API compatibility
+        reuse_mpi_comm: bool = False,  # Unused, for API compatibility
+        return_logits: bool = False,  # Unused, for API compatibility
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         ...
```

tensorrt_llm/executor/rpc_proxy_mixin.py (1)
196-264: Use bare `raise` instead of `raise e`.

Per Python best practices, use bare `raise` to preserve the full traceback chain.

```diff
         except Exception as e:
             logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}")
-            raise e
+            raise
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (29)
- tensorrt_llm/_torch/distributed/communicator.py (3 hunks)
- tensorrt_llm/_utils.py (0 hunks)
- tensorrt_llm/executor/executor.py (0 hunks)
- tensorrt_llm/executor/ipc.py (9 hunks)
- tensorrt_llm/executor/ray_executor.py (4 hunks)
- tensorrt_llm/executor/ray_gpu_worker.py (2 hunks)
- tensorrt_llm/executor/result.py (3 hunks)
- tensorrt_llm/executor/rpc/rpc_client.py (16 hunks)
- tensorrt_llm/executor/rpc/rpc_common.py (3 hunks)
- tensorrt_llm/executor/rpc/rpc_server.py (10 hunks)
- tensorrt_llm/executor/rpc_proxy.py (4 hunks)
- tensorrt_llm/executor/rpc_proxy_mixin.py (1 hunks)
- tensorrt_llm/executor/rpc_torch_dist_executor.py (1 hunks)
- tensorrt_llm/executor/rpc_torch_dist_worker.py (1 hunks)
- tensorrt_llm/executor/rpc_worker.py (7 hunks)
- tensorrt_llm/executor/rpc_worker_mixin.py (1 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/utils.py (2 hunks)
- tests/integration/defs/examples/test_ray.py (1 hunks)
- tests/integration/test_lists/test-db/l0_h100.yml (1 hunks)
- tests/integration/test_lists/waives.txt (0 hunks)
- tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py (0 hunks)
- tests/unittest/executor/test_base_worker.py (1 hunks)
- tests/unittest/executor/test_ipc.py (1 hunks)
- tests/unittest/executor/test_rpc.py (21 hunks)
- tests/unittest/executor/test_rpc_proxy.py (4 hunks)
- tests/unittest/executor/test_rpc_worker.py (2 hunks)
- tests/unittest/llmapi/test_llm_pytorch.py (1 hunks)
- tests/unittest/llmapi/test_rpc_torch_dist.py (1 hunks)
💤 Files with no reviewable changes (4)
- tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py
- tensorrt_llm/_utils.py
- tensorrt_llm/executor/executor.py
- tests/integration/test_lists/waives.txt
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use `from package.subpackage import foo` and then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)
Python filenames should use snake_case (e.g., `some_file.py`)
Python class names should use PascalCase (e.g., `class SomeClass`)
Python function and method names should use snake_case (e.g., `def my_awesome_function():`)
Python local variable names should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile = ...`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL = ...`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT = ...`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., `self.x = 5` followed by `"""<type>: Description of 'x'"""`)
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic
Files:
tensorrt_llm/llmapi/llm_args.py
tests/unittest/llmapi/test_rpc_torch_dist.py
tests/unittest/executor/test_base_worker.py
tensorrt_llm/executor/ray_gpu_worker.py
tensorrt_llm/executor/rpc/rpc_client.py
tensorrt_llm/executor/result.py
tensorrt_llm/_torch/distributed/communicator.py
tests/unittest/executor/test_rpc.py
tests/unittest/executor/test_ipc.py
tensorrt_llm/executor/rpc_worker.py
tensorrt_llm/executor/ipc.py
tests/integration/defs/examples/test_ray.py
tests/unittest/executor/test_rpc_worker.py
tensorrt_llm/executor/rpc_proxy.py
tensorrt_llm/executor/rpc_torch_dist_worker.py
tensorrt_llm/executor/rpc/rpc_common.py
tensorrt_llm/executor/rpc/rpc_server.py
tests/unittest/llmapi/test_llm_pytorch.py
tensorrt_llm/executor/rpc_proxy_mixin.py
tensorrt_llm/executor/ray_executor.py
tensorrt_llm/llmapi/utils.py
tensorrt_llm/executor/rpc_worker_mixin.py
tensorrt_llm/executor/rpc_torch_dist_executor.py
tests/unittest/executor/test_rpc_proxy.py
**/*.{cpp,h,cu,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top
Files:
tensorrt_llm/llmapi/llm_args.py
tests/unittest/llmapi/test_rpc_torch_dist.py
tests/unittest/executor/test_base_worker.py
tensorrt_llm/executor/ray_gpu_worker.py
tensorrt_llm/executor/rpc/rpc_client.py
tensorrt_llm/executor/result.py
tensorrt_llm/_torch/distributed/communicator.py
tests/unittest/executor/test_rpc.py
tests/unittest/executor/test_ipc.py
tensorrt_llm/executor/rpc_worker.py
tensorrt_llm/executor/ipc.py
tests/integration/defs/examples/test_ray.py
tests/unittest/executor/test_rpc_worker.py
tensorrt_llm/executor/rpc_proxy.py
tensorrt_llm/executor/rpc_torch_dist_worker.py
tensorrt_llm/executor/rpc/rpc_common.py
tensorrt_llm/executor/rpc/rpc_server.py
tests/unittest/llmapi/test_llm_pytorch.py
tensorrt_llm/executor/rpc_proxy_mixin.py
tensorrt_llm/executor/ray_executor.py
tensorrt_llm/llmapi/utils.py
tensorrt_llm/executor/rpc_worker_mixin.py
tensorrt_llm/executor/rpc_torch_dist_executor.py
tests/unittest/executor/test_rpc_proxy.py
🧠 Learnings (14)
📓 Common learnings
Learnt from: venkywonka
Repo: NVIDIA/TensorRT-LLM PR: 6029
File: .github/pull_request_template.md:45-53
Timestamp: 2025-08-27T17:50:13.264Z
Learning: For PR templates in TensorRT-LLM, avoid suggesting changes that would increase developer overhead, such as converting plain bullets to mandatory checkboxes. The team prefers guidance-style bullets that don't require explicit interaction to reduce friction in the PR creation process.
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
tensorrt_llm/llmapi/llm_args.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/llmapi/test_rpc_torch_dist.py
tests/integration/test_lists/test-db/l0_h100.yml
tests/integration/defs/examples/test_ray.py
tests/unittest/executor/test_rpc_worker.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.
Applied to files:
tests/unittest/llmapi/test_rpc_torch_dist.py
tests/integration/test_lists/test-db/l0_h100.yml
📚 Learning: 2025-08-29T14:07:45.863Z
Learnt from: EmmaQiaoCh
Repo: NVIDIA/TensorRT-LLM PR: 7370
File: tests/unittest/trt/model_api/test_model_quantization.py:24-27
Timestamp: 2025-08-29T14:07:45.863Z
Learning: In TensorRT-LLM's CI infrastructure, pytest skip markers (pytest.mark.skip) are properly honored even when test files have __main__ blocks that call test functions directly. The testing system correctly skips tests without requiring modifications to the __main__ block execution pattern.
Applied to files:
tests/unittest/llmapi/test_rpc_torch_dist.py
tests/integration/test_lists/test-db/l0_h100.yml
tests/unittest/executor/test_rpc_worker.py
📚 Learning: 2025-08-26T09:49:04.956Z
Learnt from: pengbowang-nv
Repo: NVIDIA/TensorRT-LLM PR: 7192
File: tests/integration/test_lists/test-db/l0_dgx_b200.yml:56-72
Timestamp: 2025-08-26T09:49:04.956Z
Learning: In TensorRT-LLM test configuration files, the test scheduling system handles wildcard matching with special rules that prevent duplicate test execution even when the same tests appear in multiple yaml files with overlapping GPU wildcards (e.g., "*b200*" and "*gb200*").
Applied to files:
tests/unittest/llmapi/test_rpc_torch_dist.py
tests/integration/test_lists/test-db/l0_h100.yml
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tests/unittest/llmapi/test_rpc_torch_dist.py
📚 Learning: 2025-09-24T03:31:28.908Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7520
File: tensorrt_llm/_torch/pyexecutor/resource_manager.py:605-613
Timestamp: 2025-09-24T03:31:28.908Z
Learning: In TensorRT-LLM Ray orchestrator mode, ProcessGroups are initialized with both Gloo and NCCL backends (e.g., "cuda:nccl,cpu:gloo"), allowing PyTorch distributed to automatically route CPU tensors through Gloo and GPU tensors through NCCL. This eliminates the need for manual device placement when performing allreduce operations on base types.
Applied to files:
tensorrt_llm/executor/ray_gpu_worker.py
tensorrt_llm/_torch/distributed/communicator.py
tensorrt_llm/executor/rpc_torch_dist_worker.py
tensorrt_llm/executor/ray_executor.py
📚 Learning: 2025-09-16T09:30:09.716Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7763
File: cpp/tensorrt_llm/CMakeLists.txt:297-301
Timestamp: 2025-09-16T09:30:09.716Z
Learning: In the TensorRT-LLM project, NCCL libraries are loaded earlier by PyTorch libraries or the bindings library, so the main shared library doesn't need NCCL paths in its RPATH - the libraries will already be available in the process address space when needed.
Applied to files:
tensorrt_llm/executor/ray_gpu_worker.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/executor/ray_gpu_worker.py
tensorrt_llm/executor/result.py
tensorrt_llm/executor/ray_executor.py
📚 Learning: 2025-09-17T02:48:52.732Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7781
File: tests/integration/test_lists/waives.txt:313-313
Timestamp: 2025-09-17T02:48:52.732Z
Learning: In TensorRT-LLM, `tests/integration/test_lists/waives.txt` is specifically for waiving/skipping tests, while other test list files like those in `test-db/` and `qa/` directories are for different test execution contexts (pre-merge, post-merge, QA tests). The same test appearing in both waives.txt and execution list files is intentional - the test is part of test suites but will be skipped due to the waiver.
Applied to files:
tests/integration/test_lists/test-db/l0_h100.yml
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.
Applied to files:
tensorrt_llm/_torch/distributed/communicator.py
📚 Learning: 2025-10-13T19:45:03.518Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: tests/unittest/_torch/multi_gpu/test_nccl_device.py:138-149
Timestamp: 2025-10-13T19:45:03.518Z
Learning: In test_nccl_device.py, the NCCL device AllReduce implementation compares the entire residual tensor on each rank, unlike the UB implementation which compares per-rank chunks. The residual chunking calculations in the test are intentionally overridden to reflect this design difference.
Applied to files:
tensorrt_llm/_torch/distributed/communicator.py
📚 Learning: 2025-09-17T06:01:01.836Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7785
File: tests/integration/defs/perf/utils.py:321-333
Timestamp: 2025-09-17T06:01:01.836Z
Learning: In test infrastructure code for disaggregated serving tests, prefer logging errors and continuing execution rather than raising exceptions on timeout, to avoid disrupting test cleanup and causing cascading failures.
Applied to files:
tests/unittest/executor/test_rpc.py
🧬 Code graph analysis (15)
tests/unittest/llmapi/test_rpc_torch_dist.py (2)
- tensorrt_llm/executor/rpc_torch_dist_executor.py (2): RpcTorchDistExecutor (18-158), shutdown (88-115)
- tensorrt_llm/executor/rpc_torch_dist_worker.py (1): shutdown (49-63)

tensorrt_llm/executor/rpc/rpc_client.py (2)
- tensorrt_llm/executor/ipc.py (5): get (236-239), get (470-471), ZeroMqQueue (22-411), put (157-167), put (462-468)
- tensorrt_llm/llmapi/utils.py (7): get (417-447), get (500-517), AsyncQueue (363-447), sync_q (382-383), put (391-393), put (464-470), _SyncQueue (450-517)

tensorrt_llm/_torch/distributed/communicator.py (1)
- tensorrt_llm/_torch/auto_deploy/distributed/common.py (2): is_initialized (104-105), all_gather_object (65-69)

tests/unittest/executor/test_rpc.py (3)
- tensorrt_llm/executor/rpc/rpc_common.py (4): get_unique_ipc_addr (9-16), RPCCancelled (55-59), RPCError (33-48), RPCStreamingError (62-63)
- tensorrt_llm/executor/rpc/rpc_client.py (4): remote (52-57), remote_async (59-64), remote_future (66-71), remote_streaming (73-77)
- tensorrt_llm/executor/rpc/rpc_server.py (3): address (81-83), bind (91-114), shutdown (116-176)

tests/unittest/executor/test_ipc.py (1)
- tensorrt_llm/executor/ipc.py (15): ZeroMqQueue (22-411), address (474-475), put (157-167), put (462-468), get (236-239), get (470-471), close (288-294), close (486-495), poll (143-155), notify_with_retry (378-411), put_noblock (169-198), put_async (200-219), get_async (241-244), get_async_noblock (246-286), put_async_noblock (221-234)

tensorrt_llm/executor/rpc_worker.py (2)
- tensorrt_llm/_utils.py (2): mpi_comm (493-494), mpi_rank (527-534)
- tensorrt_llm/executor/rpc_worker_mixin.py (1): RpcWorkerMixin (12-151)

tensorrt_llm/executor/ipc.py (1)
- tensorrt_llm/llmapi/utils.py (3): get (417-447), get (500-517), logger_debug (106-120)

tests/integration/defs/examples/test_ray.py (1)
- tests/integration/defs/conftest.py (1): llm_venv (702-719)

tensorrt_llm/executor/rpc_torch_dist_worker.py (2)
- tensorrt_llm/executor/rpc_worker_mixin.py (3): RpcWorkerMixin (12-151), init_rpc_worker (28-38), start_rpc_server (40-46)
- tensorrt_llm/_torch/auto_deploy/distributed/common.py (1): broadcast_object_list (58-62)

tensorrt_llm/executor/rpc/rpc_server.py (2)
- tensorrt_llm/executor/ipc.py (5): ZeroMqQueue (22-411), address (474-475), get (236-239), get (470-471), put_async (200-219)
- tensorrt_llm/executor/rpc/rpc_common.py (6): RPCCancelled (55-59), RPCError (33-48), RPCRequest (67-82), RPCResponse (86-93), RPCStreamingError (62-63), RPCTimeout (51-52)

tensorrt_llm/executor/rpc_proxy_mixin.py (4)
- tensorrt_llm/llmapi/utils.py (13): AsyncQueue (363-447), _SyncQueue (450-517), put_nowait (401-402), put_nowait (472-474), loop (494-495), put (391-393), put (464-470), notify_many (483-491), full (385-386), full (497-498), get (417-447), get (500-517), EventLoopShutdownError (369-370)
- tensorrt_llm/executor/rpc/rpc_client.py (2): remote (52-57), remote_streaming (73-77)
- tensorrt_llm/executor/rpc/rpc_common.py (1): get_unique_ipc_addr (9-16)
- tensorrt_llm/executor/utils.py (1): is_llm_response (149-155)

tensorrt_llm/executor/ray_executor.py (7)
- tensorrt_llm/_utils.py (2): get_free_port (476-479), nvtx_range_debug (894-918)
- tensorrt_llm/executor/ray_gpu_worker.py (4): RayGPUWorker (157-301), submit (76-77), shutdown (105-106), shutdown (243-281)
- tensorrt_llm/executor/request.py (2): GenerationRequest (85-136), set_id (133-136)
- tensorrt_llm/executor/rpc_proxy_mixin.py (4): init_rpc_executor (30-38), setup_mainloop (40-78), _fetch_responses_loop_async (175-180), submit (80-97)
- tensorrt_llm/executor/rpc_proxy.py (2): setup_engine_remote (85-86), shutdown (95-132)
- tensorrt_llm/executor/base_worker.py (2): submit (594-626), shutdown (628-636)
- tensorrt_llm/executor/ipc.py (4): get (236-239), get (470-471), close (288-294), close (486-495)

tensorrt_llm/executor/rpc_worker_mixin.py (5)
- tensorrt_llm/_utils.py (1): nvtx_range_debug (894-918)
- tensorrt_llm/llmapi/utils.py (3): logger_debug (106-120), get (417-447), get (500-517)
- tensorrt_llm/executor/request.py (1): GenerationRequest (85-136)
- tensorrt_llm/executor/rpc/rpc_server.py (3): RPCServer (19-751), bind (91-114), start (673-751)
- tensorrt_llm/executor/base_worker.py (1): set_result_queue (304-307)

tensorrt_llm/executor/rpc_torch_dist_executor.py (4)
- tensorrt_llm/executor/executor.py (1): GenerationExecutor (78-584)
- tensorrt_llm/executor/rpc_proxy_mixin.py (2): init_rpc_executor (30-38), setup_mainloop (40-78)
- tensorrt_llm/executor/rpc_torch_dist_worker.py (5): RpcTorchDistWorker (12-149), shutdown (49-63), worker_main (75-149), start (46-47), setup_engine (35-44)
- tensorrt_llm/llmapi/llm_args.py (3): BaseLlmArgs (1619-2353), world_size (459-460), world_size (469-473)

tests/unittest/executor/test_rpc_proxy.py (1)
- tensorrt_llm/llmapi/utils.py (1): logger_debug (106-120)
🪛 Ruff (0.14.5)
tensorrt_llm/executor/ray_gpu_worker.py
192-193: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/executor/rpc/rpc_client.py
88-88: Docstring contains ambiguous ’ (RIGHT SINGLE QUOTATION MARK). Did you mean ``` (GRAVE ACCENT)?
(RUF002)
201-201: Do not catch blind exception: Exception
(BLE001)
342-342: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
342-342: Avoid specifying long messages outside the exception class
(TRY003)
401-401: Do not catch blind exception: Exception
(BLE001)
443-445: Abstract raise to an inner function
(TRY301)
443-445: Avoid specifying long messages outside the exception class
(TRY003)
470-470: Avoid specifying long messages outside the exception class
(TRY003)
576-576: Docstring contains ambiguous ’ (RIGHT SINGLE QUOTATION MARK). Did you mean ``` (GRAVE ACCENT)?
(RUF002)
578-578: Docstring contains ambiguous ’ (RIGHT SINGLE QUOTATION MARK). Did you mean ``` (GRAVE ACCENT)?
(RUF002)
583-583: Avoid specifying long messages outside the exception class
(TRY003)
598-598: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF003)
598-598: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF003)
692-692: Do not catch blind exception: Exception
(BLE001)
tensorrt_llm/_torch/distributed/communicator.py
483-484: try-except-pass detected, consider logging the exception
(S110)
483-483: Do not catch blind exception: Exception
(BLE001)
tests/unittest/executor/test_rpc.py
38-38: Local variable server is assigned to but never used
Remove assignment to unused variable server
(F841)
531-531: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
789-789: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
799-799: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
856-856: Do not catch blind exception: Exception
(BLE001)
858-858: Use explicit conversion flag
Replace with conversion flag
(RUF010)
928-928: Do not catch blind exception: Exception
(BLE001)
930-930: Use explicit conversion flag
Replace with conversion flag
(RUF010)
tensorrt_llm/executor/ipc.py
137-141: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
tensorrt_llm/executor/rpc_torch_dist_worker.py
31-31: Avoid specifying long messages outside the exception class
(TRY003)
57-57: Do not catch blind exception: Exception
(BLE001)
71-71: Do not catch blind exception: Exception
(BLE001)
125-125: Do not catch blind exception: Exception
(BLE001)
130-130: Unpacked variable args is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
146-146: Use raise without specifying exception name
Remove exception name
(TRY201)
tensorrt_llm/executor/rpc/rpc_server.py
243-243: Do not catch blind exception: Exception
(BLE001)
258-258: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
273-273: Do not catch blind exception: Exception
(BLE001)
346-346: Do not catch blind exception: Exception
(BLE001)
479-479: Do not catch blind exception: Exception
(BLE001)
538-539: Abstract raise to an inner function
(TRY301)
538-539: Avoid specifying long messages outside the exception class
(TRY003)
568-569: Abstract raise to an inner function
(TRY301)
568-569: Avoid specifying long messages outside the exception class
(TRY003)
620-620: Do not catch blind exception: Exception
(BLE001)
638-638: Consider moving this statement to an else block
(TRY300)
639-639: Do not catch blind exception: Exception
(BLE001)
667-667: Do not catch blind exception: Exception
(BLE001)
694-694: Do not catch blind exception: Exception
(BLE001)
730-730: Do not catch blind exception: Exception
(BLE001)
tests/unittest/llmapi/test_llm_pytorch.py
934-934: Unused function argument: num_requests
(ARG001)
tensorrt_llm/executor/rpc_proxy_mixin.py
264-264: Use raise without specifying exception name
Remove exception name
(TRY201)
tensorrt_llm/executor/ray_executor.py
217-217: f-string without any placeholders
Remove extraneous f prefix
(F541)
225-225: Do not catch blind exception: Exception
(BLE001)
243-243: Do not catch blind exception: Exception
(BLE001)
249-249: Do not catch blind exception: Exception
(BLE001)
tensorrt_llm/executor/rpc_worker_mixin.py
30-30: Avoid specifying long messages outside the exception class
(TRY003)
101-101: Unused method argument: timeout
(ARG002)
105-105: Unused method argument: timeout
(ARG002)
tensorrt_llm/executor/rpc_torch_dist_executor.py
59-59: Use raise without specifying exception name
Remove exception name
(TRY201)
100-100: Do not catch blind exception: Exception
(BLE001)
124-124: Unused class method argument: mpi_session
(ARG003)
125-125: Unused class method argument: reuse_mpi_comm
(ARG003)
126-126: Unused class method argument: return_logits
(ARG003)
132-132: Unused class method argument: kwargs
(ARG003)
tests/unittest/executor/test_rpc_proxy.py
65-65: f-string without any placeholders
Remove extraneous f prefix
(F541)
```python
        mode_str = "RPC mode" if self.use_rpc else "Ray queue mode"
        logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow")
        logger_debug(f"Shutting down RayExecutor", color="yellow")
```
Remove unnecessary f-string prefix.
The f-string has no placeholders, so the f prefix is unnecessary.
- logger_debug(f"Shutting down RayExecutor", color="yellow")
+ logger_debug("Shutting down RayExecutor", color="yellow")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| logger_debug(f"Shutting down RayExecutor", color="yellow") | |
| logger_debug("Shutting down RayExecutor", color="yellow") |
🧰 Tools
🪛 Ruff (0.14.5)
217-217: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_executor.py around line 217, the log call uses an
unnecessary f-string prefix for a string with no interpolation; remove the
leading "f" so the call uses a plain string (e.g., logger_debug("Shutting down
RayExecutor", color="yellow")) to avoid the redundant f-string.
```python
try:
    import ray
    pass
except ModuleNotFoundError:
    from tensorrt_llm import ray_stub as ray
    pass
```
🛠️ Refactor suggestion | 🟠 Major
Remove dead code: empty try-except block.
This try-except block appears to be a remnant from removed Ray imports. Empty pass blocks in try-except serve no purpose.
Remove the dead code:
```diff
-try:
-    pass
-except ModuleNotFoundError:
-    pass
```
🤖 Prompt for AI Agents
In tensorrt_llm/executor/result.py around lines 16 to 19, remove the dead empty
try/except block (try: pass except ModuleNotFoundError: pass) that was left
after removing Ray imports; simply delete these four lines so there is no no-op
try/except remaining and ensure there are no leftover unused imports or
references that depended on it.
```python
import asyncio
import atexit
import json
import threading
from typing import Callable, List, Optional

from .._utils import nvtx_range_debug
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue, _SyncQueue
from ..logger import logger
from .request import GenerationRequest
from .result import GenerationResult
from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr
from .utils import ErrorResponse, is_llm_response
```
Missing NVIDIA copyright header.
As per coding guidelines, this file should contain the NVIDIA copyright header.
🤖 Prompt for AI Agents
In tensorrt_llm/executor/rpc_proxy_mixin.py around lines 1 to 16, the file is
missing the required NVIDIA copyright header; add the standard NVIDIA copyright
and license header block (matching the project's other source files) at the very
top of the file, above all imports, updating the year/owner fields as
appropriate, and ensure formatting and wording exactly match the repository's
header template so lint/CI checks pass.
```python
import threading
from typing import Callable, List, Optional
from typing import Optional

from .._utils import nvtx_range_debug
from ..llmapi.mpi_session import MpiPoolSession, MpiSession
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue, _SyncQueue, logger_debug
from ..llmapi.utils import logger_debug
from ..logger import logger
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .request import GenerationRequest
from .result import GenerationResult
from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr
from .rpc_proxy_mixin import RpcExecutorMixin
from .rpc_worker import RpcWorker
from .utils import (ErrorResponse, create_mpi_comm_session,
                    get_spawn_proxy_process_env, is_llm_response)


class RpcExecutorMixin:
    """Mixin for executors that use RPC client for hot path communication.
    Provides:
    - RPC client initialization
    - Response handling loop
    - Main loop thread management
    - Shutdown logic for RPC components
    The inheriting class should call init_rpc_executor() to set up RPC client.
    """

    def init_rpc_executor(self):
        self.rpc_addr = get_unique_ipc_addr()
        self.rpc_client = RPCClient(self.rpc_addr)

        self._results = {}
        self._shutdown_event = threading.Event()
        self.main_loop_task_obj = None
        self.main_loop = None
        self.main_loop_thread = None

    def setup_mainloop(self,
                       tasks: Optional[List[Callable]] = None,
                       thread_name: str = "rpc_proxy_main_loop"):
        """Setup main loop thread with custom async tasks.
        Args:
            tasks: List of async coroutine functions to run.
            thread_name: Name for the main loop thread
        """
        if tasks is None:
            tasks = [
                self._fetch_responses_loop_async,
                self._fetch_stats_loop_async,
            ]
            # Only add kv_cache_events loop if it's enabled
            if self._iter_kv_events_result:
                tasks.append(self._fetch_kv_cache_events_loop_async)

        async def main_loop_task():
            await asyncio.gather(*[task() for task in tasks])

        def _run_main_loop_task():
            """Local method to run the main loop task."""
            self.main_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.main_loop)

            self.main_loop_task_obj = self.main_loop.create_task(
                main_loop_task())
            try:
                self.main_loop.run_until_complete(self.main_loop_task_obj)
            except asyncio.CancelledError:
                pass  # Task cancellation is expected during shutdown
            finally:
                self.main_loop.close()

        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
                                                 daemon=True,
                                                 name=thread_name)
        self.main_loop_thread.start()
        atexit.register(self.shutdown)

    def submit(self, request: GenerationRequest) -> GenerationResult:
        request.set_id(self._get_next_client_id())
        logprob_params = self._get_logprob_params(request)

        # submit is a fire-and-forget operation, don't need to wait for response
        with nvtx_range_debug("RPCExecutor.submit",
                              color="green",
                              category="Proxy"):
            self.rpc_client.submit(request).remote(need_response=False)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)
        self._results[request.id] = result

        return result

    def handle_responses(self, responses: list[GenerationResult]) -> bool:
        async_queues = []
        event_loop = None

        def process_res(res: list):
            for r in res:
                client_id = r.client_id
                nonlocal event_loop
                nonlocal async_queues

                if client_id not in self._results:
                    logger.warning(
                        f"Received response for unknown client_id: {client_id}")
                    continue

                queue = self._results[client_id].queue
                if isinstance(queue, _SyncQueue):
                    queue.put_nowait(r)
                    async_queues.append(queue)
                    # all the loops are identical
                    event_loop = event_loop or queue.loop
                else:
                    queue.put(r)

                if (is_llm_response(r) and r.result.is_final) or isinstance(
                        r, ErrorResponse):
                    self._results.pop(client_id)

        # Handle the case where responses might not be a list of lists
        if responses and not isinstance(responses[0], list):
            # If responses is a flat list, wrap it
            responses = [responses]

        for res in responses:
            global_tracer().log_instant("RPC.get")
            process_res(res)

        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_stats(self, stats):
        """Handle stats received from RPC worker and put them into the stats result queue.
        Args:
            stats: Statistics data from the RPC worker (can be dict, str, or list)
        """
        self._handle_iteration_data(stats, self._iter_stats_result, "stats")

    def handle_kv_cache_events(self, events):
        """Handle KV cache events received from RPC worker and put them into the events result queue.
        Args:
            events: KV cache events data from the RPC worker (can be dict, str, or list)
        """
        self._handle_iteration_data(events, self._iter_kv_events_result,
                                    "kv_cache_events")

    async def _generic_fetch_loop_async(self, fetch_method_name: str,
                                        handler_method: Callable,
                                        method_name: str):
        """Generic method for fetching data in a loop from RPC worker.
        Args:
            fetch_method_name: Name of the RPC client method to call
            handler_method: The handler method to call with the fetched data
            method_name: Name of the method for logging
        """
        try:
            fetch_method = getattr(self.rpc_client, fetch_method_name)
            async for data in fetch_method().remote_streaming():
                if self._shutdown_event.is_set():
                    return
                handler_method(data)
        except asyncio.CancelledError:
            logger.debug(f"{method_name} task cancelled")
        except Exception as e:
            logger.error(f"Error in {method_name}: {e}")
            raise

    async def _fetch_responses_loop_async(self):
        await self._generic_fetch_loop_async(
            fetch_method_name="fetch_responses_loop_async",
            handler_method=self.handle_responses,
            method_name="_fetch_responses_loop_async")

    async def _fetch_stats_loop_async(self):
        await self._generic_fetch_loop_async(
            fetch_method_name="fetch_stats_loop_async",
            handler_method=self.handle_stats,
            method_name="_fetch_stats_loop_async")

    async def _fetch_kv_cache_events_loop_async(self):
        await self._generic_fetch_loop_async(
            fetch_method_name="fetch_kv_cache_events_loop_async",
            handler_method=self.handle_kv_cache_events,
            method_name="_fetch_kv_cache_events_loop_async")

    def _handle_iteration_data(self, data, result_singleton, data_type: str):
        """Generic method to handle iteration data received from RPC worker.
        Args:
            data: Data from the RPC worker (can be dict, str, or list)
            result_singleton: The iteration result singleton to put data into
            data_type: Type of data for logging (e.g., "stats", "kv_cache_events")
        """
        # Make sure we have initialized the iteration results
        self._maybe_initialize_iteration_results()

        if not result_singleton:
            logger.debug(
                f"Skipping {data_type} handling while result_singleton=None")
            return

        # Get the queue from the result singleton
        queue = result_singleton.queue
        async_queues = []

        # Clear old data if queue is full (similar to _iteration_result_task)
        while queue.full():
            queue.get()

        try:
            # Handle different types of data
            if isinstance(data, str):
                # Already JSON serialized
                data_json = data
            elif isinstance(data, list):
                # Skip empty lists to avoid putting nothing in the queue
                if not data:
                    logger.debug(
                        f"rpc_proxy.py: Skipping empty {data_type} list")
                    return

                # Handle list of data (multiple iterations)
                for item in data:
                    if isinstance(item, str):
                        item_json = item
                    else:
                        item_json = json.dumps(item)

                    if isinstance(queue, _SyncQueue):
                        queue.put_nowait(item_json)
                        async_queues.append(queue)
                    else:
                        queue.put(item_json)

                if async_queues:
                    _SyncQueue.notify_many(queue.loop, async_queues)
                return
            else:
                # Convert dict/other to JSON string as expected by IterationResult
                data_json = json.dumps(data)

            if isinstance(queue, _SyncQueue):
                queue.put_nowait(data_json)
                async_queues.append(queue)
            else:
                queue.put(data_json)

            if async_queues:
                _SyncQueue.notify_many(queue.loop, async_queues)

        except AsyncQueue.EventLoopShutdownError:
            # This happens when the event loop is already closed
            logger.debug(
                f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}")
        except Exception as e:
            logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}")
            raise e


from .utils import create_mpi_comm_session, get_spawn_proxy_process_env
```
Missing NVIDIA copyright header.
As per coding guidelines, all TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top.
Add the NVIDIA copyright header at the top of the file:
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import threading
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
import threading
from typing import Callable, List, Optional
from typing import Optional

from .._utils import nvtx_range_debug
from ..llmapi.mpi_session import MpiPoolSession, MpiSession
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue, _SyncQueue, logger_debug
from ..llmapi.utils import logger_debug
from ..logger import logger
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .request import GenerationRequest
from .result import GenerationResult
from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr
from .rpc_proxy_mixin import RpcExecutorMixin
from .rpc_worker import RpcWorker
from .utils import (ErrorResponse, create_mpi_comm_session,
                    get_spawn_proxy_process_env, is_llm_response)


class RpcExecutorMixin:
    """Mixin for executors that use RPC client for hot path communication.

    Provides:
    - RPC client initialization
    - Response handling loop
    - Main loop thread management
    - Shutdown logic for RPC components

    The inheriting class should call init_rpc_executor() to set up RPC client.
    """

    def init_rpc_executor(self):
        self.rpc_addr = get_unique_ipc_addr()
        self.rpc_client = RPCClient(self.rpc_addr)
        self._results = {}
        self._shutdown_event = threading.Event()
        self.main_loop_task_obj = None
        self.main_loop = None
        self.main_loop_thread = None

    def setup_mainloop(self,
                       tasks: Optional[List[Callable]] = None,
                       thread_name: str = "rpc_proxy_main_loop"):
        """Setup main loop thread with custom async tasks.

        Args:
            tasks: List of async coroutine functions to run.
            thread_name: Name for the main loop thread
        """
        if tasks is None:
            tasks = [
                self._fetch_responses_loop_async,
                self._fetch_stats_loop_async,
            ]
            # Only add kv_cache_events loop if it's enabled
            if self._iter_kv_events_result:
                tasks.append(self._fetch_kv_cache_events_loop_async)

        async def main_loop_task():
            await asyncio.gather(*[task() for task in tasks])

        def _run_main_loop_task():
            """Local method to run the main loop task."""
            self.main_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.main_loop)

            self.main_loop_task_obj = self.main_loop.create_task(
                main_loop_task())
            try:
                self.main_loop.run_until_complete(self.main_loop_task_obj)
            except asyncio.CancelledError:
                pass  # Task cancellation is expected during shutdown
            finally:
                self.main_loop.close()

        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
                                                 daemon=True,
                                                 name=thread_name)
        self.main_loop_thread.start()
        atexit.register(self.shutdown)

    def submit(self, request: GenerationRequest) -> GenerationResult:
        request.set_id(self._get_next_client_id())
        logprob_params = self._get_logprob_params(request)

        # submit is a fire-and-forget operation, don't need to wait for response
        with nvtx_range_debug("RPCExecutor.submit",
                              color="green",
                              category="Proxy"):
            self.rpc_client.submit(request).remote(need_response=False)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)
        self._results[request.id] = result

        return result

    def handle_responses(self, responses: list[GenerationResult]) -> bool:
        async_queues = []
        event_loop = None

        def process_res(res: list):
            for r in res:
                client_id = r.client_id
                nonlocal event_loop
                nonlocal async_queues

                if client_id not in self._results:
                    logger.warning(
                        f"Received response for unknown client_id: {client_id}")
                    continue

                queue = self._results[client_id].queue
                if isinstance(queue, _SyncQueue):
                    queue.put_nowait(r)
                    async_queues.append(queue)
                    # all the loops are identical
                    event_loop = event_loop or queue.loop
                else:
                    queue.put(r)

                if (is_llm_response(r) and r.result.is_final) or isinstance(
                        r, ErrorResponse):
                    self._results.pop(client_id)

        # Handle the case where responses might not be a list of lists
        if responses and not isinstance(responses[0], list):
            # If responses is a flat list, wrap it
            responses = [responses]

        for res in responses:
            global_tracer().log_instant("RPC.get")
            process_res(res)

        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_stats(self, stats):
        """Handle stats received from RPC worker and put them into the stats result queue.

        Args:
            stats: Statistics data from the RPC worker (can be dict, str, or list)
        """
        self._handle_iteration_data(stats, self._iter_stats_result, "stats")

    def handle_kv_cache_events(self, events):
        """Handle KV cache events received from RPC worker and put them into the events result queue.

        Args:
            events: KV cache events data from the RPC worker (can be dict, str, or list)
        """
        self._handle_iteration_data(events, self._iter_kv_events_result,
                                    "kv_cache_events")

    async def _generic_fetch_loop_async(self, fetch_method_name: str,
                                        handler_method: Callable,
                                        method_name: str):
        """Generic method for fetching data in a loop from RPC worker.

        Args:
            fetch_method_name: Name of the RPC client method to call
            handler_method: The handler method to call with the fetched data
            method_name: Name of the method for logging
        """
        try:
            fetch_method = getattr(self.rpc_client, fetch_method_name)
            async for data in fetch_method().remote_streaming():
                if self._shutdown_event.is_set():
                    return
                handler_method(data)
        except asyncio.CancelledError:
            logger.debug(f"{method_name} task cancelled")
        except Exception as e:
            logger.error(f"Error in {method_name}: {e}")
            raise

    async def _fetch_responses_loop_async(self):
        await self._generic_fetch_loop_async(
            fetch_method_name="fetch_responses_loop_async",
            handler_method=self.handle_responses,
            method_name="_fetch_responses_loop_async")

    async def _fetch_stats_loop_async(self):
        await self._generic_fetch_loop_async(
            fetch_method_name="fetch_stats_loop_async",
            handler_method=self.handle_stats,
            method_name="_fetch_stats_loop_async")

    async def _fetch_kv_cache_events_loop_async(self):
        await self._generic_fetch_loop_async(
            fetch_method_name="fetch_kv_cache_events_loop_async",
            handler_method=self.handle_kv_cache_events,
            method_name="_fetch_kv_cache_events_loop_async")

    def _handle_iteration_data(self, data, result_singleton, data_type: str):
        """Generic method to handle iteration data received from RPC worker.

        Args:
            data: Data from the RPC worker (can be dict, str, or list)
            result_singleton: The iteration result singleton to put data into
            data_type: Type of data for logging (e.g., "stats", "kv_cache_events")
        """
        # Make sure we have initialized the iteration results
        self._maybe_initialize_iteration_results()

        if not result_singleton:
            logger.debug(
                f"Skipping {data_type} handling while result_singleton=None")
            return

        # Get the queue from the result singleton
        queue = result_singleton.queue
        async_queues = []

        # Clear old data if queue is full (similar to _iteration_result_task)
        while queue.full():
            queue.get()

        try:
            # Handle different types of data
            if isinstance(data, str):
                # Already JSON serialized
                data_json = data
            elif isinstance(data, list):
                # Skip empty lists to avoid putting nothing in the queue
                if not data:
                    logger.debug(
                        f"rpc_proxy.py: Skipping empty {data_type} list")
                    return

                # Handle list of data (multiple iterations)
                for item in data:
                    if isinstance(item, str):
                        item_json = item
                    else:
                        item_json = json.dumps(item)

                    if isinstance(queue, _SyncQueue):
                        queue.put_nowait(item_json)
                        async_queues.append(queue)
                    else:
                        queue.put(item_json)

                if async_queues:
                    _SyncQueue.notify_many(queue.loop, async_queues)
                return
            else:
                # Convert dict/other to JSON string as expected by IterationResult
                data_json = json.dumps(data)

            if isinstance(queue, _SyncQueue):
                queue.put_nowait(data_json)
                async_queues.append(queue)
            else:
                queue.put(data_json)

            if async_queues:
                _SyncQueue.notify_many(queue.loop, async_queues)

        except AsyncQueue.EventLoopShutdownError:
            # This happens when the event loop is already closed
            logger.debug(
                f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}")
        except Exception as e:
            logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}")
            raise e


from .utils import create_mpi_comm_session, get_spawn_proxy_process_env

# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import Optional

from ..llmapi.mpi_session import MpiPoolSession, MpiSession
from ..llmapi.utils import logger_debug
from ..logger import logger
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .rpc_proxy_mixin import RpcExecutorMixin
from .rpc_worker import RpcWorker
from .utils import create_mpi_comm_session, get_spawn_proxy_process_env
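The RpcExecutorMixin shown above documents a three-step contract: call init_rpc_executor() to bind the RPC client, call setup_mainloop() to start the background fetch loops, and let submit() send requests over RPC. A hypothetical sketch of an executor wiring it up (class and argument names are illustrative, not the actual proxy implementation):

```python
# Hypothetical subclass showing the intended wiring of RpcExecutorMixin.
class MyRpcExecutor(RpcExecutorMixin, GenerationExecutor):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.init_rpc_executor()  # creates RPCClient on a fresh IPC address
        # Start the response/stats fetch loops in a dedicated daemon thread
        # that owns its own asyncio event loop.
        self.setup_mainloop(thread_name="my_rpc_executor_main_loop")
```

With this wiring, submit() is fire-and-forget on the client side; results flow back through the fetch loops into self._results.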
🤖 Prompt for AI Agents
In tensorrt_llm/executor/rpc_proxy.py around lines 1 to 11, the file is missing
the required NVIDIA copyright header; insert the project's standard NVIDIA
copyright header at the very top of the file (before any imports), update the
year to 2025, and ensure the header matches the repository template (including
SPDX/license tag if used) so formatting and newline placement remain consistent
with other source files.
import multiprocessing
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from tensorrt_llm._utils import get_free_port
from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.builder import Engine
from tensorrt_llm.executor.executor import GenerationExecutor
from tensorrt_llm.executor.postproc_worker import PostprocWorkerConfig
from tensorrt_llm.executor.rpc_proxy_mixin import RpcExecutorMixin
from tensorrt_llm.executor.rpc_torch_dist_worker import RpcTorchDistWorker
from tensorrt_llm.llmapi.llm_args import BaseLlmArgs
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import BatchedLogitsProcessor
Missing NVIDIA copyright header.
As per coding guidelines, this file should contain the NVIDIA copyright header.
🤖 Prompt for AI Agents
In tensorrt_llm/executor/rpc_torch_dist_executor.py around lines 1 to 16, the
file is missing the required NVIDIA copyright header; add the standard NVIDIA
copyright and license header at the top of the file (before any imports),
matching our repository's header format and year/owner details, and ensure line
breaks and encoding are preserved so the header appears as the first block in
the file.
import os
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist

from tensorrt_llm.executor.base_worker import BaseWorker
from tensorrt_llm.executor.rpc_worker_mixin import RpcWorkerMixin
from tensorrt_llm.logger import logger
Missing NVIDIA copyright header.
As per coding guidelines, all TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top.
Add the copyright header at the top of the file:
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from typing import Any, Dict, Optional

🤖 Prompt for AI Agents
In tensorrt_llm/executor/rpc_torch_dist_worker.py around lines 1 to 10, the file
is missing the required NVIDIA copyright header; add the standard NVIDIA OSS
copyright header (including the current year) at the very top of the file before
any imports, ensuring it matches the project's header format and includes
copyright, license notice, and year.
import asyncio
from queue import Queue
from threading import Event
from typing import AsyncGenerator, Optional

from .._utils import nvtx_range_debug
from ..llmapi.utils import logger_debug
from .request import GenerationRequest
from .rpc import RPCServer
Missing NVIDIA copyright header.
As per coding guidelines, this file should contain the NVIDIA copyright header.
🤖 Prompt for AI Agents
In tensorrt_llm/executor/rpc_worker_mixin.py around lines 1 to 10, the file is
missing the required NVIDIA copyright header; add the project's standard NVIDIA
copyright header block at the very top of the file (before any imports), filling
in the correct year(s)/owner if required and including the SPDX license
identifier or exact text used in other files in the repo so the header matches
the project template.
        except asyncio.TimeoutError:
            # Check if we should exit due to cancellation
            if self._closed or (self._reader_asyncio_task
                                and self._reader_asyncio_task.cancelled()):
                raise asyncio.CancelledError("Reader task cancelled")
            # Otherwise continue polling
            continue
Use exception chaining for clarity.
When raising a new exception from within an except block, use from None to suppress the misleading exception chain.
- raise asyncio.CancelledError("Reader task cancelled")
+ raise asyncio.CancelledError("Reader task cancelled") from None

🧰 Tools
🪛 Ruff (0.14.5)
342-342: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
342-342: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tensorrt_llm/executor/rpc/rpc_client.py around lines 338 to 344, the except
asyncio.TimeoutError handler raises a new CancelledError inside the except block
which produces a misleading exception chain; change the raise to use exception
chaining suppression by raising asyncio.CancelledError("Reader task cancelled")
from None so the original TimeoutError is not shown as the cause, keeping the
rest of the logic (checking self._closed and reader task cancelled) unchanged.
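For reference, a self-contained illustration of the suggested pattern (generic asyncio code, not the RPC client itself): raising with from None inside the except block keeps the original TimeoutError out of the traceback, so callers only see the CancelledError.

```python
import asyncio

async def poll_once(fut: asyncio.Future, closed: bool):
    try:
        # Shield so a timeout here does not cancel the underlying future.
        return await asyncio.wait_for(asyncio.shield(fut), timeout=0.1)
    except asyncio.TimeoutError:
        if closed:
            # "from None" suppresses the implicit exception chain
            # ("During handling of the above exception, ...").
            raise asyncio.CancelledError("Reader task cancelled") from None
        return None
```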
    orchestrator_type: Optional[Literal[
        "rpc", "ray", "rpc_torch_dist"]] = Field(
            default=None,
            description=
            "The orchestrator type to use. Defaults to None, which uses MPI.",
            status="prototype",
        )
🧩 Analysis chain
Extension of orchestrator_type to support rpc_torch_dist.
The addition of "rpc_torch_dist" to the allowed literals aligns with the new RpcTorchDistExecutor introduced in this PR. The field documentation and prototype status are preserved.
Verify that there is validation logic elsewhere that handles the new "rpc_torch_dist" orchestrator type and routes to the appropriate executor.
🏁 Script executed:
#!/bin/bash
# Search for usage of orchestrator_type to verify routing logic exists
rg -n "orchestrator_type.*rpc_torch_dist" --type=py
rg -n "RpcTorchDistExecutor" --type=py -C3Length of output: 3804
---
🏁 Script executed:
#!/bin/bash
# Search for where orchestrator_type is used in LLM initialization and routing
rg -n "orchestrator_type" tensorrt_llm/llmapi/llm.py -B2 -A2
rg -n "def __init__" tensorrt_llm/llmapi/llm.py -A20 | head -50
Length of output: 2907
🏁 Script executed:
#!/bin/bash
# Look for executor selection/routing logic
rg -n "executor.*orchestrator\|orchestrator.*executor" --type=py -iLength of output: 0
---
🏁 Script executed:
#!/bin/bash
# Search for where executor is actually instantiated/created in LLM
rg -n "self\._executor\|GenerationExecutor\|executor_cls" tensorrt_llm/llmapi/llm.py -B1 -A3 | head -80
# Look for executor factory or selection logic
rg -n "def.*executor\|create.*executor\|select.*executor" tensorrt_llm/llmapi/ --type=py -i
Length of output: 0
---
🏁 Script executed:
#!/bin/bash
# Broader search for executor creation and orchestrator_type usage
rg -n "executor_cls" . --type=py -A5 | head -100
rg -n "orchestrator_type" . --type=py | grep -v test | head -40
# Check llm.py directly for executor flow
cat tensorrt_llm/llmapi/llm.py | grep -n "executor" -i | head -30
Length of output: 11386
---
🏁 Script executed:
#!/bin/bash
# Verify the exact validation logic in executor.py
cat tensorrt_llm/executor/executor.py | sed -n '467,490p'
Length of output: 1254
Update orchestrator_type validation to support "rpc_torch_dist".
The field definition in tensorrt_llm/llmapi/llm_args.py now allows "rpc_torch_dist", but the corresponding validation logic was not updated. The executor creation will fail with ValueError: Unsupported orchestrator_type: rpc_torch_dist at runtime.
Location: tensorrt_llm/executor/executor.py, lines 479-481
The validation currently allows only "ray" or "rpc". Update the condition to also accept "rpc_torch_dist":
elif orchestrator_type is not None and orchestrator_type not in ("rpc", "rpc_torch_dist"):
    raise ValueError(
        f"Unsupported orchestrator_type: {orchestrator_type}")

🤖 Prompt for AI Agents
In tensorrt_llm/executor/executor.py around lines 479-481 the orchestrator_type
validation only accepts "rpc" (and previously "ray"), causing a ValueError for
the newly allowed "rpc_torch_dist"; update the conditional to include
"rpc_torch_dist" as an allowed value (i.e., treat "rpc_torch_dist" the same as
"rpc" in the acceptance check) so the executor creation won't raise Unsupported
orchestrator_type for that string.
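A self-contained sketch of the accepted-values check after the fix (names are illustrative; the real check lives inline in executor.py):

```python
# Illustrative helper mirroring the suggested condition: None and "ray" pass
# through, "rpc" and "rpc_torch_dist" are the RPC-based orchestrators, and
# anything else is rejected.
def validate_orchestrator_type(orchestrator_type):
    if orchestrator_type in (None, "ray"):
        return
    if orchestrator_type not in ("rpc", "rpc_torch_dist"):
        raise ValueError(f"Unsupported orchestrator_type: {orchestrator_type}")

validate_orchestrator_type("rpc_torch_dist")  # accepted after the fix
```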
    @pytest.mark.parametrize("num_reqs", [1, 5, 10])
    def test_tp1(self, num_reqs):
        tokenizer = TransformersTokenizer.from_pretrained(model_path)
        prompt = "A B C D"
        prompt_token_ids = tokenizer.encode(prompt)
        max_tokens = 8

        with self.create_proxy(tp_size=1) as proxy:
            logger_debug(f"[Test] Proxy created", color="green")
            sampling_params = SamplingParams(max_tokens=max_tokens)
            for _ in range(num_reqs):
                logger_debug(f"[Test] Generating {_}th", color="green")
                result = proxy.generate(prompt_token_ids, sampling_params)
                print(f"get result: {result}")
                assert similar(tokenizer.decode(result.outputs[0].token_ids),
                               'E F G H I J K L')
                logger_debug(f"req {_} get result: {result}", color="green")
Remove unnecessary f prefixes in logger_debug calls
logger_debug calls such as on Line 65 use f-strings without any placeholders (f"[Test] Proxy created"), which Ruff flags (F541). They can be plain strings:
- logger_debug(f"[Test] Proxy created", color="green")
+ logger_debug("[Test] Proxy created", color="green")
@@
- logger_debug(f"[Test] Generating {_}th", color="green")
+ logger_debug(f"[Test] Generating {_}th", color="green")
@@
- logger_debug(f"req {_} get result: {result}", color="green")
+ logger_debug(f"req {_} get result: {result}", color="green")
@@
-if __name__ == "__main__":
- TestRpcProxy().test_tp1(20)
+if __name__ == "__main__":
+ TestRpcProxy().test_tp1(20)

(Only the first call needs the f removed; the others are already using interpolation correctly.)
This keeps the logs unchanged while satisfying the linter.
Also applies to: 99-100
🧰 Tools
🪛 Ruff (0.14.5)
65-65: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
tests/unittest/executor/test_rpc_proxy.py lines 57-73: remove the unnecessary
f-string prefixes on logger_debug calls that have no interpolations (e.g.,
change f"[Test] Proxy created" to a plain string "[Test] Proxy created"); apply
the same fix to the similar occurrences called out (around lines 99-100) so each
logger_debug uses a normal string when there are no placeholders.
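A quick illustration of the Ruff F541 distinction the comment is pointing at (plain Python, unrelated to the test file):

```python
name = "proxy"
print("[Test] Proxy created")      # no placeholders: plain string, no f prefix
print(f"[Test] Created {name}")    # interpolation: the f prefix is justified
```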
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Refactor
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
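For example, a typical pre-merge request that disables fail-fast and restricts testing to selected stages (stage names as in the examples above) would be posted as a PR comment:

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1, xxx"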
kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.