1 change: 0 additions & 1 deletion examples/llm-api/quickstart_multimodal.py
@@ -154,7 +154,6 @@ def parse_arguments():
parser = add_lora_args(parser)
args = parser.parse_args()

args.disable_kv_cache_reuse = True # kv cache reuse does not work for multimodal, force overwrite
if args.kv_cache_fraction is None:
args.kv_cache_fraction = 0.6 # lower the default kv cache fraction for multimodal

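With the forced overwrite removed above, the quickstart now honors the user's KV cache reuse setting. As a minimal sketch (assuming the tensorrt_llm.llmapi KvCacheConfig API with enable_block_reuse and free_gpu_memory_fraction fields, and a hypothetical helper name), the parsed arguments could be mapped to a cache config like this:

from tensorrt_llm.llmapi import KvCacheConfig

def build_kv_cache_config(disable_kv_cache_reuse: bool,
                          kv_cache_fraction: float) -> KvCacheConfig:
    # Reuse stays enabled unless the user passes --disable_kv_cache_reuse.
    return KvCacheConfig(
        enable_block_reuse=not disable_kv_cache_reuse,
        free_gpu_memory_fraction=kv_cache_fraction,  # e.g. the 0.6 multimodal default above
    )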
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -25,7 +25,8 @@
from ..model_config import ModelConfig
from .modeling_auto import AutoModelForCausalLM
from .modeling_clip import CLIPVisionModel
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_multimodal_utils import (find_uncached_mm_embeds,
fuse_input_embeds)
from .modeling_utils import (filter_weights, register_auto_model,
register_vision_encoder)

@@ -469,6 +470,8 @@ def forward(
]
else:
mm_embeds = self.mm_encoder.forward(multimodal_params)
mm_embeds = find_uncached_mm_embeds(
mm_embeds, multimodal_params[:num_context_requests])
else:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
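The llava-next change above routes the encoder output through find_uncached_mm_embeds so that, when KV cache reuse is active, only embeddings for tokens not already served from cached blocks are fused into the input. A rough, hypothetical sketch of that slicing idea (the real helper in modeling_multimodal_utils takes multimodal_params and may differ):

import torch
from typing import List

def slice_uncached_mm_embeds(mm_embeds: List[torch.Tensor],
                             num_cached_mm_tokens: List[int]) -> List[torch.Tensor]:
    # Drop the leading embeddings whose tokens were already covered by reused KV blocks;
    # only the remaining (uncached) part needs to be fused and prefilled again.
    return [emb[cached:] for emb, cached in zip(mm_embeds, num_cached_mm_tokens)]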
36 changes: 22 additions & 14 deletions tensorrt_llm/inputs/multimodal.py
@@ -24,16 +24,17 @@ class MultimodalInput:
"""

multimodal_positions: List[int]
"""Starting positions of each multimodal chunk in the token sequence.
"""Starting positions of each contiguous multimodal token chunk in the token sequence.
Contains only the start position of each chunk, not all positions of multimodal tokens.
This is different from mm_positions elsewhere which contains all positions.
"""

multimodal_lengths: List[int]
"""Length (number of tokens) of each multimodal item.
"""Length of each contiguous multimodal token chunk, including any special tokens.
Combined with multimodal_positions, this defines the token spans for each multimodal item.
Each span is unique to its multimodal item and, for some models, may include special tokens
(e.g., image_end_token, image_break_token for mistral3) mixed with the actual multimodal tokens.
"""

def __post_init__(self):
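To make the relationship between the two fields above concrete, here is a small illustrative example with made-up values:

# A prompt with two images whose placeholder chunks start at token offsets 5 and 40,
# each 32 tokens long (special tokens such as image_break/image_end counted in the length).
multimodal_positions = [5, 40]
multimodal_lengths = [32, 32]

# Half-open [start, end) token span of each multimodal item:
spans = [(start, start + length)
         for start, length in zip(multimodal_positions, multimodal_lengths)]
assert spans == [(5, 37), (40, 72)]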
@@ -485,7 +486,13 @@ def hexdigest_to_int32(hex_digest: str) -> List[int]:

def find_mm_token_lengths(mm_data: Dict[str, Any],
input_processor: Any) -> List[int]:
"""Get multimodal token lengths from multimodal data items. """
"""Get the maximum contiguous multimodal token lengths from multimodal data items.
Returns the total token count for each multimodal item, including any special tokens
(e.g., image_begin, image_end, image_break) that may be mixed with the actual
multimodal content tokens. This mm_token_lengths represents the full contiguous chunk from beginning
to end, not just pure image/video/audio tokens.
"""

mm_items = {
modality: items if isinstance(items, list) else [items]
@@ -528,22 +535,23 @@ def find_mm_token_positions(
num_mm_tokens: List[int],
vocab_size: Optional[int] = None,
mm_token_ids: Optional[torch.Tensor] = None) -> List[int]:
"""Get multimodal token positions using IDs > vocab_size and known lengths.
"""Get starting positions of contiguous multimodal token chunks using known lengths.
This function finds multimodal tokens (with IDs > vocab_size or matching mm_token_ids)
and uses the provided lengths in num_mm_tokens to identify where each contiguous chunk starts.
Each chunk in num_mm_tokens is assumed to be a contiguous block of multimodal tokens
for one multimodal item, and may include special tokens (e.g., image_begin, image_end,
image_break) within the chunk.
This function finds multimodal tokens (with IDs > vocab_size) and uses the
provided lengths in num_mm_tokens to identify where each chunk starts.
This works even when there are no gaps between different image sequences
(e.g., when all images use the same token IDs).
Note at least one of vocab_size or mm_token_ids must be provided. If mm_token_ids is provided, vocab_size is ignored.
Note: at least one of vocab_size or mm_token_ids must be provided. If mm_token_ids
is provided, vocab_size is ignored.
Args:
input_ids: Token sequence (tensor, list, or numpy array)
num_mm_tokens: List of lengths for each multimodal token chunk
vocab_size: Size of the model's vocabulary
mm_token_ids: Possible token ids for multimodal tokens
num_mm_tokens: List of contiguous chunk lengths for each multimodal item
vocab_size: Size of the model's vocabulary (used to identify tokens > vocab_size)
mm_token_ids: Specific token IDs that represent multimodal tokens
Returns:
List of starting positions for each multimodal token chunk
List of starting positions for each contiguous multimodal token chunk
"""
if mm_token_ids is None and vocab_size is None:
raise ValueError(
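As a rough sketch of the behavior the updated docstring describes (not the library implementation; the comparison against vocab_size is assumed to be >= here), the chunk starts can be recovered by walking the multimodal-token indices with the known chunk lengths, with mm_token_ids taking precedence over vocab_size:

import torch
from typing import List, Optional

def mm_chunk_start_positions(input_ids,
                             num_mm_tokens: List[int],
                             vocab_size: Optional[int] = None,
                             mm_token_ids: Optional[torch.Tensor] = None) -> List[int]:
    # Illustrative sketch only: locate the start of each contiguous multimodal chunk.
    ids = torch.as_tensor(input_ids)
    if mm_token_ids is not None:
        is_mm = torch.isin(ids, mm_token_ids)  # explicit multimodal IDs, special tokens included
    elif vocab_size is not None:
        is_mm = ids >= vocab_size              # placeholder IDs outside the text vocabulary
    else:
        raise ValueError("Provide vocab_size or mm_token_ids")
    mm_indices = torch.nonzero(is_mm, as_tuple=True)[0].tolist()
    starts, cursor = [], 0
    for length in num_mm_tokens:               # each length covers one whole contiguous chunk
        starts.append(mm_indices[cursor])
        cursor += length
    return starts

# Two back-to-back 4-token image chunks starting at positions 2 and 6:
# mm_chunk_start_positions([1, 5, 900, 900, 900, 900, 900, 900, 900, 900, 7],
#                          [4, 4], vocab_size=100) -> [2, 6]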
2 changes: 2 additions & 0 deletions tensorrt_llm/inputs/registry.py
@@ -85,6 +85,8 @@ def get_vocab_size(self) -> Optional[int]:

def get_mm_token_ids(self) -> Optional[Tensor]:
"""Return multimodal token IDs if available; otherwise None.

The token IDs returned by this method should cover the full contiguous chunk of each
multimodal item, i.e., any special tokens in the chunk should be included.
"""
processor = self.get_processor()
if processor is not None and getattr(processor, 'mm_token_ids',
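For instance, a processor for a mistral3-style model would expose every ID that can appear inside an image chunk, not only the image placeholder itself (the IDs below are made up for illustration):

import torch

# Hypothetical IDs: the chunk is only contiguous if the break/end special tokens
# are matched along with the image placeholder token.
IMAGE_TOKEN_ID, IMAGE_BREAK_ID, IMAGE_END_ID = 32000, 32001, 32002
mm_token_ids = torch.tensor([IMAGE_TOKEN_ID, IMAGE_BREAK_ID, IMAGE_END_ID])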
103 changes: 103 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -2425,6 +2425,8 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
*accuracy_inputs[modality]["prompt"],
"--media",
*accuracy_inputs[modality]["media"],
# TODO: remove this once kv cache reuse is supported for all VLM models
"--disable_kv_cache_reuse",
]
# NOTE: Qwen2-VL and Qwen2.5-VL models need a larger max_num_tokens for video.
if model_name in ["qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"
@@ -2510,6 +2512,94 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
_check_mem_usage(running_log, [peak, 0, 0, 0])


@pytest.mark.parametrize("modality", ["image", "video"])
@pytest.mark.parametrize("model_name,model_path", [
("llava-v1.6-mistral-7b", "llava-v1.6-mistral-7b-hf"),
("qwen2.5-vl-7b-instruct", "Qwen2.5-VL-7B-Instruct"),
])
def test_ptp_quickstart_multimodal_kv_cache_reuse(llm_root, llm_venv,
model_name, model_path,
modality):
# NOTE: individual tests need to be enabled in
# tests/integration/test_lists/qa/examples_test_list.txt

example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
test_data_root = Path(
os.path.join(llm_models_root(), "multimodals", "test_data"))
print(f"Accuracy test {model_name} {modality} mode with example inputs.")
if modality == "video" and model_name == "llava-v1.6-mistral-7b":
pytest.skip("Skipping video modality test for llava-v1.6-mistral-7b")

num_same_requests = 3 # test kv cache reuse with multiple same requests
accuracy_inputs = {
"image": {
"prompt": [
"Describe the natural environment in the image.",
] * num_same_requests,
"media": [
str(test_data_root / "seashore.png"),
] * num_same_requests,
},
"video": {
"prompt": [
"Tell me what you see in the video briefly.",
] * num_same_requests,
"media": [
str(test_data_root / "OAI-sora-tokyo-walk.mp4"),
] * num_same_requests,
},
}

expected_keywords = {
"llava-v1.6-mistral-7b": {
"image": [
["ocean", "sky", "large", "waves", "shore", "blue"],
] * num_same_requests,
},
"qwen2.5-vl-7b-instruct": {
"image": [
["dramatic", "moody", "ocean", "stormy", "sky", "waves"],
] * num_same_requests,
"video": [
["woman", "neon", "night", "jacket", "wet"],
] * num_same_requests,
},
}

cmd = [
str(example_root / "quickstart_multimodal.py"),
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--modality",
modality,
"--prompt",
*accuracy_inputs[modality]["prompt"],
"--media",
*accuracy_inputs[modality]["media"],
"--max_batch_size", # single request at a time to test kv cache reuse
"1",
]
    # NOTE: Qwen2-VL and Qwen2.5-VL models need a larger max_num_tokens for video.
if model_name in ["qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"
] and modality == "video":
cmd.append("--max_num_tokens=16384")

output = llm_venv.run_cmd(cmd, caller=check_output)
match_ratio = 4.0 / 5
for prompt_output, prompt_keywords in zip(
parse_output(output), expected_keywords[model_name][modality]):
matches = [
keyword in prompt_output.lower() for keyword in prompt_keywords
]
obs_match_ratio = 1. * sum(matches) / len(matches)
print(
f"Prompt output: {prompt_output}\nExpected keywords: {prompt_keywords}\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} given threshold {match_ratio}"
)
assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"

print("All answers are correct!")


@pytest.mark.parametrize("modality", ["image", "audio", "image_audio"])
def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
model_name = "Phi-4-multimodal-instruct"
@@ -2583,6 +2673,8 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
"--load_lora",
"--auto_model_name",
"Phi4MMForCausalLM",
# TODO: remove this once kv cache reuse is supported for Phi-4-multimodal
"--disable_kv_cache_reuse",
]
output = llm_venv.run_cmd(cmd, caller=check_output)

@@ -2683,7 +2775,12 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
cmd.append("--max_seq_len=4096")
cmd.append("--load_lora")
cmd.append("--auto_model_name")
cmd.append("Phi4MMForCausalLM")
# TODO: remove this once kv cache reuse is supported for Phi-4-multimodal
cmd.append("--disable_kv_cache_reuse")
elif model_name == "mistral-small-3.1-24b-instruct":
# TODO: remove this once kv cache reuse is supported for Mistral
cmd.append("--disable_kv_cache_reuse")

output = llm_venv.run_cmd(cmd, caller=check_output)

@@ -2784,6 +2881,12 @@ def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
cmd.append("--load_lora")
cmd.append("--auto_model_name")
cmd.append("Phi4MMForCausalLM")
# TODO: remove this once kv cache reuse is supported for Phi-4
cmd.append("--disable_kv_cache_reuse")

elif model_name == "mistral-small-3.1-24b-instruct":
# TODO: remove this once kv cache reuse is supported for Mistral
cmd.append("--disable_kv_cache_reuse")

output = llm_venv.run_cmd(cmd, caller=check_output)
print("output:", output)
3 changes: 3 additions & 0 deletions tests/integration/test_lists/qa/llm_function_core.txt
@@ -631,6 +631,9 @@ test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistr
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image]
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]