Merged
29 commits
- 1752239  WIP: consider num_attention_layers for kv cache estimation and add ma… (tomeras91, Jun 12, 2025)
- 7829ec9  Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91, Jun 18, 2025)
- 4403183  organize code and logging for max batch size calculation for trtllm-b… (tomeras91, Jun 19, 2025)
- 6ff4602  consider only attention layers when estimating number of tokens in Kv… (tomeras91, Jun 19, 2025)
- e6615a8  propagate kv_cache_gpu_mem_fraction to calc_engine_setting for trtllm… (tomeras91, Jun 19, 2025)
- 42d65f3  release mamba cache memory when shutting down MambaCacheManager (and … (tomeras91, Jun 19, 2025)
- 17d22e5  small refactor - MambaCacheManager method names to match BaseResource… (tomeras91, Jun 19, 2025)
- 7dfeab8  refactor - is_nemotron_hybrid works on dicts as well (tomeras91, Jun 19, 2025)
- ee85bac  remove log (tomeras91, Jun 19, 2025)
- d0d0b7e  Add comment explaining squaring of kv_cache_gpu_mem_fraction + save r… (tomeras91, Jun 19, 2025)
- 63bea92  remove debug print (tomeras91, Jun 19, 2025)
- c8c71df  fix - use config.get() only if config is a dict (tomeras91, Jun 19, 2025)
- 3e6a30e  Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91, Jun 24, 2025)
- 83e0673  optimistic tune max batch size only if not mamba attention hybrid model (tomeras91, Jun 25, 2025)
- 4b2ba21  Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91, Jun 25, 2025)
- e6e65fc  Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91, Jun 26, 2025)
- 8cf5ee7  Merge branch 'fix-trtllm-bench-for-nemotron-h' of github.com:tomeras9… (tomeras91, Jun 26, 2025)
- aa5d87c  fix: Mamba cache size estimation for FP8 - always use NO_QUANT for ma… (tomeras91, Jun 26, 2025)
- ac481b2  Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91, Jun 26, 2025)
- 7904672  introduce NemotronHybridConfig that inherits from ModelConfig (tomeras91, Jul 2, 2025)
- 04cba88  Move logic to compute extra model class to ModelConfig class (tomeras91, Jul 2, 2025)
- 337e7aa  refactor max batch size estimation and make it more general (less mam… (tomeras91, Jul 2, 2025)
- 4b0182b  remove redundant MambaConfig (tomeras91, Jul 2, 2025)
- ea4e816  simplify computation of total kv cache memory (tomeras91, Jul 2, 2025)
- 1975d38  remove whitespace (tomeras91, Jul 2, 2025)
- 1670ad9  compute cache memory fraction in ModelConfig + enable_optimistic_tuni… (tomeras91, Jul 2, 2025)
- 3e40792  reduce formatting diff (tomeras91, Jul 2, 2025)
- 0a0d2c8  Merge branch 'main' into fix-trtllm-bench-for-nemotron-h (tomeras91, Jul 3, 2025)
- 47b9eb8  Add get_num_attention_layers() function in _torch/model_config.py::Mo… (tomeras91, Jul 3, 2025)
9 changes: 8 additions & 1 deletion tensorrt_llm/_torch/model_config.py
@@ -7,6 +7,7 @@
import transformers

from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm._utils import torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
@@ -298,7 +299,7 @@ def get_bindings_model_config(self,
model_config_cpp = ModelConfigCpp(
vocab_size=self.pretrained_config.vocab_size,
num_layers=self.pretrained_config.num_hidden_layers,
num_attention_layers=self.pretrained_config.num_hidden_layers,
num_attention_layers=self.get_num_attention_layers(),
num_rnn_layers=0,
num_heads=num_heads,
hidden_size=hidden_size,
@@ -376,3 +377,9 @@ def get_layer_types(self) -> Optional[List[LayerTypeCpp]]:
] * self.pretrained_config.num_hidden_layers
else:
return None

def get_num_attention_layers(self):
if is_nemotron_hybrid(self.pretrained_config):
return self.pretrained_config.hybrid_override_pattern.count("*")
else:
return self.pretrained_config.num_hidden_layers
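For context, a minimal sketch of what get_num_attention_layers() returns for a hybrid model: in a Nemotron-H style hybrid_override_pattern, "*" marks attention layers and "M" marks mamba layers (see NemotronHybridConfig further down); the pattern string below is made up for illustration, not from a real checkpoint.

```python
# Illustrative only: a hypothetical hybrid_override_pattern.
hybrid_override_pattern = "M*M-M*M-M*M-"

num_attention_layers = hybrid_override_pattern.count("*")  # 3 attention layers
num_mamba_layers = hybrid_override_pattern.count("M")      # 6 mamba layers

# Non-hybrid models fall back to pretrained_config.num_hidden_layers.
print(num_attention_layers, num_mamba_layers)
```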
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -82,9 +82,9 @@ def _get_cache_size_per_token(model_config: ModelConfig,
) * num_key_value_heads // tp_size

# provide at least 1 layer to prevent division by zero cache size
num_hidden_layers = max(
len(mapping.pp_layers(config.num_hidden_layers)), 1)
mem_per_token *= num_hidden_layers * head_dim
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
mem_per_token *= num_attention_layers * head_dim
# K and V
mem_per_token *= kv_factor
return mem_per_token
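A hedged worked example of the per-token KV cache size that _get_cache_size_per_token now derives from attention layers only; all numbers are illustrative, not taken from a real checkpoint, and tp_size is assumed to be 1.

```python
# Toy numbers; bytes_per_elem assumes a 16-bit KV cache.
bytes_per_elem = 2
num_key_value_heads = 8
head_dim = 128
num_attention_layers = 4   # mamba layers contribute no KV cache
kv_factor = 2              # one K and one V entry per token

mem_per_token = bytes_per_elem * num_key_value_heads
mem_per_token *= num_attention_layers * head_dim
mem_per_token *= kv_factor
print(mem_per_token)       # 16384 bytes per token in this toy setup
```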
37 changes: 23 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -816,7 +816,7 @@ def __init__(
device=device,
dtype=torch.int32)

def prepare_mamba_cache_blocks(self, request_ids: List[int]):
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
state_indices = []
for r in request_ids:
# cache hit
@@ -832,23 +832,21 @@ def prepare_mamba_cache_blocks(self, request_ids: List[int]):
self.state_indices[:len(state_indices)] = torch.as_tensor(
state_indices, dtype=torch.int32, device=self.ssm_states.device)

def free_mamba_cache_blocks(self, request_id: int):
if request_id in self.mamba_cache_index:
block = self.mamba_cache_index.pop(request_id)
self.mamba_cache_free_blocks.append(block)

def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests):
def prepare_resources(self, scheduled_batch: ScheduledRequests):
context_ids = [
i.py_request_id for i in scheduled_batch.context_requests
]
generation_ids = [
i.py_request_id for i in scheduled_batch.generation_requests
]
request_ids = context_ids + generation_ids
self.prepare_mamba_cache_blocks(request_ids)
self._prepare_mamba_cache_blocks(request_ids)

def free_mamba_resources(self, request: LlmRequest):
self.free_mamba_cache_blocks(request.py_request_id)
def free_resources(self, request: LlmRequest):
request_id = request.py_request_id
if request_id in self.mamba_cache_index:
block = self.mamba_cache_index.pop(request_id)
self.mamba_cache_free_blocks.append(block)

def get_state_indices(self) -> torch.Tensor:
return self.state_indices
@@ -861,6 +859,13 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.ssm_states[layer_offset]

def shutdown(self):
# release tensor memory, keeping python references as tensors
self.conv_states = torch.tensor([])
self.ssm_states = torch.tensor([])
self.state_indices = torch.tensor([])
torch.cuda.empty_cache()


class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):

@@ -931,12 +936,16 @@ def __init__(
)

def prepare_resources(self, scheduled_batch: ScheduledRequests):
self.prepare_mamba_resources(scheduled_batch)
super().prepare_resources(scheduled_batch)
MambaCacheManager.prepare_resources(self, scheduled_batch)
KVCacheManager.prepare_resources(self, scheduled_batch)

def free_resources(self, request: LlmRequest):
self.free_mamba_resources(request)
super().free_resources(request)
MambaCacheManager.free_resources(self, request)
KVCacheManager.free_resources(self, request)

def shutdown(self):
MambaCacheManager.shutdown(self)
KVCacheManager.shutdown(self)


class SlotManager:
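A small sketch (not the PR code) of the cooperation pattern used by MambaHybridCacheManager above: since both parents now expose the same prepare_resources/free_resources/shutdown names, the hybrid class calls each parent explicitly rather than relying on super(), so each cache type gets prepared, freed, and torn down exactly once.

```python
# Minimal stand-ins for the two parents; the real classes manage KV cache pools
# and mamba conv/ssm state tensors on the GPU.
class KVCachePart:
    def shutdown(self):
        print("free KV cache pools")

class MambaCachePart:
    def shutdown(self):
        # The PR releases the state tensors by rebinding them to empty tensors
        # and calling torch.cuda.empty_cache(); here we just log.
        print("free mamba conv/ssm states")

class HybridCache(KVCachePart, MambaCachePart):
    def shutdown(self):
        MambaCachePart.shutdown(self)
        KVCachePart.shutdown(self)

HybridCache().shutdown()
# free mamba conv/ssm states
# free KV cache pools
```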
1 change: 1 addition & 0 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -132,6 +132,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
params.get("pp"),
dataset_metadata.avg_isl,
dataset_metadata.avg_osl,
params.get("kv_cache_free_gpu_mem_fraction"),
)

logger.info(
10 changes: 9 additions & 1 deletion tensorrt_llm/bench/build/build.py
@@ -1,10 +1,12 @@
from __future__ import annotations
from transformers import AutoConfig

from pathlib import Path
from typing import Tuple, get_args
import click
from click_option_group import AllOptionGroup, optgroup

from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -13,7 +15,7 @@
from tensorrt_llm.llmapi.llm_utils import QuantConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.bench.build.dataclasses import ModelConfig
from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
from tensorrt_llm.bench.build.tuning import calc_engine_setting

TUNED_QUANTS = {
@@ -31,6 +33,7 @@ def get_benchmark_engine_settings(
pp_size: int,
target_input_len: int,
target_output_len: int,
kv_cache_gpu_mem_fraction: float = 0.95,
) -> Tuple[int, int]:
""" Retrieve benchmark settings for a specific model + configuration.

@@ -58,6 +61,7 @@
pp_size,
target_input_len,
target_output_len,
kv_cache_gpu_mem_fraction,
)
else:
max_batch_size = DEFAULT_MAX_BATCH_SIZE
@@ -82,6 +86,10 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
Raises:
ValueError: When model is not supported.
"""
if is_nemotron_hybrid(
AutoConfig.from_pretrained(model_path or model_name,
trust_remote_code=True)):
return NemotronHybridConfig.from_hf(model_name, model_path)
return ModelConfig.from_hf(model_name, model_path)


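For reference, a hedged sketch of the dispatch that get_model_config() now performs; the model id is a placeholder, and trust_remote_code=True mirrors the diff above.

```python
from transformers import AutoConfig

from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig

model_name = "nvidia/Nemotron-H-8B-Base-8K"  # placeholder model id
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# Hybrid mamba+attention checkpoints get the hybrid-aware config so the tuner
# can account for the mamba cache; every other model keeps the old behavior.
config_cls = NemotronHybridConfig if is_nemotron_hybrid(hf_config) else ModelConfig
model_config = config_cls.from_hf(model_name, None)
```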
60 changes: 60 additions & 0 deletions tensorrt_llm/bench/build/dataclasses.py
@@ -124,6 +124,7 @@ class ModelConfig(BaseModel):
AliasPath("text_config", "num_hidden_layers"),
AliasPath("language_config", "num_hidden_layers"),
))
num_attention_layers: Optional[int] = Field(default=None)
num_attention_heads: int = Field(validation_alias=AliasChoices(
"num_attention_heads",
"n_head",
@@ -148,6 +149,7 @@
validation_alias=AliasChoices(
"head_size",
"head_dim",
"attention_head_dim",
AliasPath("text_config", "head_dim"),
))
max_position_embeddings: Optional[int] = Field(
@@ -171,6 +173,8 @@ def set_values_if_none(self):
self.num_key_value_heads = self.num_attention_heads
if self.head_size is None:
self.head_size = self.hidden_size // self.num_attention_heads
if self.num_attention_layers is None:
self.num_attention_layers = self.num_hidden_layers
return self

@classmethod
@@ -194,3 +198,59 @@ def from_hf(cls, model_hf_name, hf_model_path):
param_count = cls.get_param_count(model_hf_name, hf_model_path)

return cls(name=model_hf_name, param_count=param_count, **hf_config)

def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
return 0

def cache_memory_fraction(self, cache_memory_fraction):
return cache_memory_fraction


class NemotronHybridConfig(ModelConfig):
hybrid_override_pattern: str
d_state: int = Field(validation_alias=AliasChoices(
"d_state",
"mamba_d_state",
"ssm_state_size",
))
d_conv: int = Field(validation_alias=AliasChoices(
"d_conv",
"mamba_d_conv",
"conv_kernel",
))
expand: int = Field(validation_alias=AliasChoices(
"expand",
"mamba_expand",
))
n_groups: int
mamba_head_dim: int
d_inner: Optional[int] = Field(default=None)
mamba_num_heads: Optional[int] = Field(default=None)
num_mamba_layers: Optional[int] = Field(default=None)

@model_validator(mode="after")
def set_values_if_none(self):
""" Set the values if cannot get values from HF config.json. """
if not self.d_inner:
self.d_inner = self.hidden_size * self.expand
if not self.mamba_num_heads:
self.mamba_num_heads = self.d_inner // self.mamba_head_dim
if self.num_mamba_layers is None:
self.num_mamba_layers = self.hybrid_override_pattern.count("M")
if self.num_attention_layers is None:
self.num_attention_layers = self.hybrid_override_pattern.count("*")

super().set_values_if_none()
return self

def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
conv_state_elems = conv_dim * (self.d_conv - 1)
ssm_state_elems = self.mamba_num_heads * self.mamba_head_dim * self.d_state
gb_per_mamba_cache = bytes_per_elem * self.num_mamba_layers * (
conv_state_elems + ssm_state_elems) / (1024**3)
return gb_per_mamba_cache

def cache_memory_fraction(self, cache_memory_fraction):
# Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
return cache_memory_fraction**2
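A worked example of extra_model_cache_in_gb() with illustrative hyperparameters (assumptions for a roughly 8B-scale hybrid model, not an official Nemotron-H config); the result lands near the "~50MB for 8B model" figure mentioned in the comment above.

```python
# Assumed, illustrative values.
hidden_size = 4096
expand = 2
n_groups = 8
d_state = 128
d_conv = 4
mamba_head_dim = 64
num_mamba_layers = 27            # count of "M" in a hypothetical pattern
bytes_per_elem = 2               # NO_QUANT: mamba states stay in 16-bit

d_inner = hidden_size * expand
mamba_num_heads = d_inner // mamba_head_dim

conv_dim = d_inner + 2 * n_groups * d_state
conv_state_elems = conv_dim * (d_conv - 1)
ssm_state_elems = mamba_num_heads * mamba_head_dim * d_state

gb_per_request = bytes_per_elem * num_mamba_layers * (
    conv_state_elems + ssm_state_elems) / (1024**3)
print(f"{gb_per_request * 1024:.1f} MB of mamba cache per request")  # ~55 MB
```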
59 changes: 41 additions & 18 deletions tensorrt_llm/bench/build/tuning.py
@@ -3,7 +3,7 @@
from tensorrt_llm.llmapi.llm_utils import QuantConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.bench.build.dataclasses import ModelConfig
from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
from .utils import get_device_memory
import math

@@ -55,7 +55,11 @@ def calc_engine_setting(

# Each GPU in TP group has at least 1 kv head
adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \

logger.info(
f"Number of attention layers: {model_config.num_attention_layers}")

gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
* model_config.head_size * byte_per_kv_elem / (1024 ** 3)

# Number of GPU used for this run.
@@ -70,19 +74,33 @@
f"{available_memory:.2f} GB")

# Calculate max requests in KV cache based on target ISL and OSL.
kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction
kv_cache_max_tokens = kv_cache_memory / byte_per_token
kv_cache_max_requests = kv_cache_max_tokens / (target_input_len +
target_output_len)
logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB")
target_seq_len = target_input_len + target_output_len
cache_memory = available_memory * model_config.cache_memory_fraction(
kv_cache_gpu_mem_fraction)
gb_per_extra_cache = model_config.extra_model_cache_in_gb(
BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
gb_per_extra_cache)
extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
kv_cache_memory = cache_memory - extra_cache_memory
kv_cache_max_tokens = kv_cache_memory / gb_per_token

logger.info(
f"Estimated total cache memory: {cache_memory:.2f} GB. KV cache: {kv_cache_memory:.2f} GB, Extra cache: {extra_cache_memory:.2f} GB"
)
logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
logger.info("Estimated max number of requests in KV cache memory: "
f"{kv_cache_max_requests:.2f}")

# Fine-tune the max batch size and num token setting for performance.
max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests,
target_input_len,
target_output_len,
pp_size)
# For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache
max_batch_size, max_num_tokens = finetune_setting(
kv_cache_max_requests,
target_input_len,
target_output_len,
pp_size,
disable_optimistic_tuning=isinstance(model_config,
NemotronHybridConfig))

# Functional and performance
if total_gpu_memory < engine_size:
@@ -107,7 +125,7 @@
if kv_cache_max_requests < 1:
raise RuntimeError("The amount of KV cache memory is insufficient to "
"run this model. Please try with more GPUs.")
if kv_cache_memory / n_gpus < 10.0:
if cache_memory / n_gpus < 10.0:
logger.warning(
f"The KV cache memory per GPU is less than 10 GB. "
"Performance may be undesirable. Please consider using a different "
@@ -126,6 +144,7 @@ def finetune_setting(
input_len: int,
output_len: int,
pp_size: int,
disable_optimistic_tuning: bool = False,
) -> Tuple[int, int]:
""" Calculate and fine-tune the engine build settings (max batch size and
max num tokens). Both max batch size and max num tokens are fine-tuned
@@ -137,6 +156,7 @@
input_len (int): Input sequence length to compile the engine.
output_len (int): Output sequence length to compile the engine.
pp_size (int): Number of pipeline parallel stages.
disable_optimistic_tuning (bool): Whether to disable optimistic tuning.

Returns:
Tuple[int, int]: Tuple containing fine-tuned values for engine
@@ -148,13 +168,16 @@
raw_token = min(raw_bs * (1 + input_len / output_len), 32768)

# Fine-tune the max batch size.
# Set min BS to be 64.
if raw_bs < 256:
max_bs = max(64, 32 * math.ceil(raw_bs / 32))
elif raw_bs < 1024:
max_bs = 128 * math.ceil(raw_bs / 128)
if disable_optimistic_tuning:
max_bs = 2 * math.floor(raw_bs / 2)
else:
max_bs = 256 * math.ceil(raw_bs / 256)
# Set min BS to be 64.
if raw_bs < 256:
max_bs = max(64, 32 * math.ceil(raw_bs / 32))
elif raw_bs < 1024:
max_bs = 128 * math.ceil(raw_bs / 128)
else:
max_bs = 256 * math.ceil(raw_bs / 256)

# Fine-tune the max num tokens.
# Set min to 2048 to ensure Ctx/Gen overlap efficiency
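Finally, a hedged walk-through (with made-up numbers) of how the new budget split in calc_engine_setting() behaves: the memory reserved for caches is shared between per-token KV cache and per-request extra (mamba) cache, hybrid models square the memory fraction to stay conservative, and disable_optimistic_tuning replaces the upward rounding with a round-down to an even batch size. The raw_bs value below is a placeholder, since its derivation is elided in the diff.

```python
import math

# Illustrative inputs; none of these values come from a real benchmark run.
available_memory = 60.0                      # GB left for caches after weights
kv_cache_gpu_mem_fraction = 0.95
target_seq_len = 1024 + 1024                 # target ISL + OSL

gb_per_token = 16384 / (1024**3)             # per-token KV size from the earlier sketch
gb_per_extra_cache = 0.054                   # per-request mamba cache from the earlier sketch

# Hybrid path: cache_memory_fraction() squares the fraction to stay conservative.
cache_memory = available_memory * kv_cache_gpu_mem_fraction**2

# Each request needs (gb_per_token * target_seq_len) of KV cache plus one mamba cache entry.
kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len + gb_per_extra_cache)
kv_cache_memory = cache_memory - gb_per_extra_cache * kv_cache_max_requests

# disable_optimistic_tuning=True: round the raw batch size down to an even value
# instead of up to the next multiple of 32/128/256.
example_raw_bs = 190.7                       # hypothetical raw batch size
max_batch_size = 2 * math.floor(example_raw_bs / 2)   # -> 190

print(f"{kv_cache_max_requests:.0f} requests, "
      f"{kv_cache_memory:.1f} GB of KV cache, max_batch_size={max_batch_size}")
```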