Skip to content

Commit 566a849

Browse files
committed
fix w/ cuda graph
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent da1275a commit 566a849

File tree

5 files changed

+181
-111
lines changed

5 files changed

+181
-111
lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ def __init__(
33 33   device: str,
34 34   attn_metadata: AttentionMetadata,
35 35   spec_metadata: Optional[SpecMetadata] = None,
   36 + guided_metadata=None,
36 37   use_mrope: bool = False,
37 38   max_beam_width: int = 1,
38 39   ) -> None:
@@ -71,6 +72,7 @@ def __init__(
71 72
72 73   self.attn_metadata = attn_metadata
73 74   self.spec_metadata = spec_metadata
   75 + self.guided_metadata = guided_metadata
74 76   self._output = None
75 77   self._graph = None
76 78   self.optional_extra_model_inputs = ["mrope_position_deltas"]
@@ -90,6 +92,7 @@ def capture(
90 92   "position_ids": self.position_ids,
91 93   "inputs_embeds": None,
92 94   "spec_metadata": self.spec_metadata,
   95 + "guided_metadata": self.guided_metadata,
93 96   "mrope_position_deltas": self.mrope_position_deltas,
94 97   }
95 98

0 commit comments

Comments
 (0)