Commit d471655

[TRTLLM-7831][feat] Cherry-pick from #7423 Support fp8 block wide ep cherry pick (#7712)
1 parent 59f5759 commit d471655

File tree

7 files changed: +1049 / -22 lines


docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ There are multiple MOE backends inside TRT-LLM, not all of them supporting every
 | B200/GB200 EP<=8 | NVFP4 | CUTLASS, TRTLLM |
 | B200/GB200 EP<=8 | FP8 | DEEPGEMM |
 | GB200 NVL72 EP>8 | NVFP4 | WIDEEP |
-| GB200 NVL72 EP>8 | FP8 | N/A (WIP) |
+| GB200 NVL72 EP>8 | FP8 | WIDEEP without EPLB |
 
 The default moe backend is `CUTLASS`, so for the combination which is not supported by `CUTLASS`, one must set the `moe_config.backend` explicitly to run the model.
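The doc change above is the user-facing part of this commit: FP8 on GB200 NVL72 with EP>8 now maps to the WIDEEP backend (without EPLB), and since CUTLASS stays the default, that backend must be requested explicitly via `moe_config.backend`. A minimal sketch of doing so through the LLM API follows; the `MoeConfig` import location and the model reference are assumptions for illustration, not taken from this diff.

```python
# Illustrative sketch only: explicitly select the WIDEEP MoE backend,
# which the default CUTLASS backend does not cover for GB200 NVL72 EP>8 FP8.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MoeConfig  # assumed import location

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",         # placeholder checkpoint reference
    moe_config=MoeConfig(backend="WIDEEP"),  # override the CUTLASS default
)
```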

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 34 additions & 21 deletions
@@ -5,8 +5,9 @@
 import torch
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import logger
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import AllReduceStrategy
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ...distributed import AllReduce, allgather, reducescatter
@@ -16,7 +17,9 @@
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
+from .ops import MoEOp, MoEOpSelector
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
+                           DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
                            NVFP4CutlassFusedMoEMethod,
                            UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -90,6 +93,9 @@ def __init__(
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input
 
+        # Store original hidden size before any potential padding
+        self.unpadded_hidden_size = self.hidden_size
+
         moe_load_balancer = get_moe_load_balancer()
         self.layer_load_balancer = None
         self.repeat_idx = 0
@@ -227,6 +233,9 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
 
+        # MoE op will be lazily initialized when first accessed (see moe_op_impl property)
+        self._moe_op_impl = None
+
     def _check_configs(self):
         assert self._weights_created
 
@@ -316,7 +325,10 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
-            return DeepSeekFP8BlockScalesFusedMoEMethod()
+            if get_sm_version() == 100:
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
@@ -339,6 +351,19 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
+    @property
+    def moe_op_impl(self) -> MoEOp:
+        """
+        Lazily initialize and return the MoE op.
+
+        The op is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_op_impl is None:
+            assert self._weights_created, "Weights must be created before accessing moe_op"
+            self._moe_op_impl = MoEOpSelector.select_op(self)
+        return self._moe_op_impl
+
     def dummy_allreduce(self):
         """
         Debug function for eliminating imbalance during performance analysis.
@@ -389,8 +414,9 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
+        if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
+            pass
+
         weight_dtype = self.w3_w1_weight.dtype
 
         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -544,9 +570,8 @@ def forward_chunk(
             x_sf = x_sf.view((x_row, -1))
 
         elif self.has_deepseek_fp8_block_scales:
-            use_deepseek_fp8_block_scale = True
+            pass
         elif self.has_w4afp8:
-            use_w4_group_scaling = True
             weight_dtype = torch.quint4x2
         else:
             raise ValueError(
@@ -569,12 +594,8 @@
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
 
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
         if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +661,8 @@
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        final_hidden_states = torch.ops.trtllm.fused_moe(
+        final_hidden_states = self.moe_op_impl.run_moe(
+            self,
             x,
             token_selected_slots,
             token_final_scales,
@@ -652,17 +674,8 @@
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
         )
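Taken together, the `forward_chunk` edits above replace the direct `torch.ops.trtllm.fused_moe` call, which had tp/ep/cluster ranks and the `use_deepseek_fp8_block_scale` / `use_w4_group_scaling` flags threaded in by hand, with a call on a lazily selected op object that reads that configuration from the module itself. The toy sketch below (hypothetical `_ToyOp` / `_ToyModule` names, not part of the commit) shows only that call shape:

```python
# Toy illustration of the delegation pattern introduced above: the op pulls
# its configuration from the module instead of receiving it as arguments.
from dataclasses import dataclass
from typing import Optional


class _ToyOp:
    def run_moe(self, module, x):
        # Real ops (CutlassMoEOp / DeepGemmMoEOp) read tp/ep sizes, quant
        # scales, etc. from `module`; here we only echo the configuration.
        return f"dispatch {x} with tp={module.tp_size}, ep={module.ep_size}"


@dataclass
class _ToyModule:
    tp_size: int = 1
    ep_size: int = 8
    _op: Optional[_ToyOp] = None

    @property
    def moe_op_impl(self):
        # Lazily created on first use, mirroring WideEPMoE.moe_op_impl.
        if self._op is None:
            self._op = _ToyOp()
        return self._op

    def forward_chunk(self, x):
        return self.moe_op_impl.run_moe(self, x)


print(_ToyModule().forward_chunk([1, 2, 3]))
```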

tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MoE ops module for different computation implementations."""
+
+from .moe_op import MoEOp, MoEOpSelector
+from .moe_op_cutlass import CutlassMoEOp
+from .moe_op_deepgemm import DeepGemmMoEOp
+
+__all__ = ['MoEOp', 'MoEOpSelector', 'CutlassMoEOp', 'DeepGemmMoEOp']

tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+MoE Op abstraction for supporting different MoE computation implementations.
+This module provides a unified interface for different MoE ops (Cutlass, DeepGemm, etc.)
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from tensorrt_llm._utils import get_sm_version
+
+if TYPE_CHECKING:
+    from ..interface import MoE
+
+
+class MoEOp(ABC):
+    """Abstract base class for MoE computation ops.
+
+    This class provides a strategy pattern for different MoE computation implementations.
+    It is used by MoE modules (like WideEPMoE) to delegate the actual computation.
+
+    Note: MoEOp is NOT a MoE module itself, but a computation strategy.
+    The actual MoE module (e.g., WideEPMoE) inherits from MoE and uses MoEOp
+    for the computation implementation.
+    """
+
+    # Op-specific abstract methods
+    @abstractmethod
+    def finalize_tactic(
+        self,
+        module: 'MoE',
+        tuner_input: torch.Tensor,
+        output_dtype: torch.dtype,
+        min_latency_mode: bool = False,
+        use_fused_finalize: bool = True,
+        tuner_top_k: Optional[int] = None,
+    ) -> None:
+        """
+        Finalize tactics for the MoE computation.
+        For Cutlass op, this includes profiling and tactic selection.
+        For DeepGemm op, this can be a no-op.
+
+        Args:
+            module: The MoE module containing MoE configurations
+            tuner_input: Real input used for tuning (same shape/layout as non-alltoall)
+            output_dtype: Output dtype for tuner run
+            min_latency_mode: Whether to profile for min-latency path
+            use_fused_finalize: Whether to use fused finalize
+            tuner_top_k: Top-k value for tuning (Cutlass specific)
+        """
+
+    @abstractmethod
+    def compute_moe(
+        self,
+        module: 'MoE',
+        # Input tensors
+        x: torch.Tensor,
+        token_selected_slots: torch.Tensor,
+        token_final_scales: Optional[torch.Tensor],
+        # Weight tensors
+        w3_w1_weight: torch.Tensor,
+        w3_w1_bias: Optional[torch.Tensor],
+        w2_weight: torch.Tensor,
+        w2_bias: Optional[torch.Tensor],
+        # Output configuration
+        output_dtype: torch.dtype,
+        # Quantization parameters
+        quant_scales: List[torch.Tensor],
+        input_sf: Optional[torch.Tensor] = None,
+        swizzled_input_sf: bool = True,
+        # Performance tuning (only runtime-variable parameters)
+        min_latency_mode: bool = False,
+        use_fused_finalize: bool = True,
+        tuner_num_tokens: Optional[int] = None,
+        tuner_top_k: Optional[int] = None,
+        **kwargs) -> torch.Tensor:
+        """
+        Perform the actual MoE computation.
+
+        Configuration parameters (tp_size, ep_size, swiglu params, etc.) are
+        automatically extracted from the module parameter.
+
+        Args:
+            module: MoE module containing configuration and parameters.
+            x: Input tensor
+            token_selected_slots: Selected expert slots
+            token_final_scales: Scaling factors
+            w3_w1_weight: Fused gate and up projection weights
+            w3_w1_bias: Optional bias
+            w2_weight: Down projection weights
+            w2_bias: Optional bias
+            output_dtype: Output data type
+            quant_scales: Quantization scales
+            input_sf: Input scaling factor
+            swizzled_input_sf: Whether input_sf is swizzled
+            min_latency_mode: Use minimum latency optimizations
+            use_fused_finalize: Use fused finalization
+            tuner_num_tokens: Number of tokens for tuning
+            tuner_top_k: Top-k value for tuning
+
+        Returns:
+            Computed MoE output tensor
+        """
+
+    def run_moe(
+        self,
+        module: 'MoE',
+        # Input tensors
+        input: torch.Tensor,
+        token_selected_slots: torch.Tensor,
+        token_final_scales: torch.Tensor,
+        w3_w1_weight: torch.Tensor,
+        w3_w1_bias: Optional[torch.Tensor],
+        w2_weight: torch.Tensor,
+        w2_bias: Optional[torch.Tensor],
+        output_dtype: torch.dtype,
+        # Quantization parameters
+        quant_scales: List[torch.Tensor],
+        input_sf: Optional[torch.Tensor] = None,
+        swizzled_input_sf: bool = True,
+        # Performance tuning (only runtime-variable parameters)
+        min_latency_mode: bool = False,
+        use_fused_finalize: bool = True,
+        tuner_num_tokens: Optional[int] = None,
+        tuner_top_k: Optional[int] = None,
+        **kwargs) -> torch.Tensor:
+        """
+        Run the complete MoE computation pipeline.
+
+        Configuration parameters are automatically extracted from the module.
+
+        Args:
+            module: MoE module containing configuration
+            input: Input tensor to the MoE layer
+            token_selected_slots: Selected expert slots for each token
+            token_final_scales: Final scaling factors for each token
+            w3_w1_weight: Concatenated weights for w3 and w1 projections
+            w3_w1_bias: Optional bias for w3/w1 projections
+            w2_weight: Weight for w2 projection
+            w2_bias: Optional bias for w2 projection
+            output_dtype: Desired output data type
+            quant_scales: Quantization scales for weights
+            input_sf: Optional input scale factors for quantization
+            swizzled_input_sf: Whether input scale factors are swizzled
+            min_latency_mode: Use minimum latency optimizations
+            use_fused_finalize: Use fused finalization
+            tuner_num_tokens: Number of tokens for tuner input
+            tuner_top_k: Top-k value for tuning
+
+        Returns:
+            Computed MoE output tensor
+        """
+        self.finalize_tactic(module, input, output_dtype, min_latency_mode,
+                             use_fused_finalize, tuner_top_k)
+
+        # Call compute_moe with module
+        return self.compute_moe(module=module,
+                                x=input,
+                                token_selected_slots=token_selected_slots,
+                                token_final_scales=token_final_scales,
+                                w3_w1_weight=w3_w1_weight,
+                                w3_w1_bias=w3_w1_bias,
+                                w2_weight=w2_weight,
+                                w2_bias=w2_bias,
+                                output_dtype=output_dtype,
+                                quant_scales=quant_scales,
+                                input_sf=input_sf,
+                                swizzled_input_sf=swizzled_input_sf,
+                                min_latency_mode=min_latency_mode,
+                                use_fused_finalize=use_fused_finalize,
+                                tuner_num_tokens=tuner_num_tokens,
+                                tuner_top_k=tuner_top_k,
+                                **kwargs)
+
+
+class MoEOpSelector:
+    """
+    Utility class for selecting the appropriate MoE op based on
+    hardware capabilities and quantization configuration.
+
+    This class implements the strategy pattern for op selection,
+    choosing between Cutlass and DeepGemm implementations based on:
+    - Hardware capabilities (SM version)
+    - Quantization configuration (block FP8 support)
+    """
+
+    @staticmethod
+    def select_op(module: 'MoE') -> MoEOp:
+        """
+        Select the appropriate MoE op based on module configuration.
+
+        Selection criteria:
+        - Blackwell (SM100) with block FP8 quantization -> DeepGemm op
+        - All other configurations -> Cutlass op
+
+        Args:
+            module: The MoE module containing configuration information
+
+        Returns:
+            MoEOp: Selected op instance (CutlassMoEOp or DeepGemmMoEOp)
+
+        Example:
+            >>> op = MoEOpSelector.select_op(moe_module)
+            >>> output = op.run_moe(input, ...)
+        """
+        from .moe_op_cutlass import CutlassMoEOp
+        from .moe_op_deepgemm import DeepGemmMoEOp
+
+        # Check if we should use DeepGemm op
+        # Blackwell has SM version 100
+        is_blackwell = get_sm_version() == 100
+        has_block_fp8 = module.has_deepseek_fp8_block_scales
+
+        if is_blackwell and has_block_fp8:
+            # Use DeepGemm op for Blackwell with block FP8
+            return DeepGemmMoEOp()
+        else:
+            # Use Cutlass op for all other cases
+            return CutlassMoEOp()
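To make the op contract above concrete: a subclass only has to provide `finalize_tactic` and `compute_moe`, and `run_moe` simply chains them. The self-contained sketch below mirrors (rather than imports) that interface with a stripped-down `_MiniMoEOp` base and an `EchoOp` stand-in, so the names and simplified signatures are illustrative only:

```python
# Self-contained mirror of the MoEOp strategy contract defined above;
# _MiniMoEOp stands in for MoEOp, EchoOp for CutlassMoEOp / DeepGemmMoEOp.
from abc import ABC, abstractmethod

import torch


class _MiniMoEOp(ABC):

    @abstractmethod
    def finalize_tactic(self, module, tuner_input, output_dtype) -> None:
        """Profile and select kernels; may be a no-op for some backends."""

    @abstractmethod
    def compute_moe(self, module, x, output_dtype) -> torch.Tensor:
        """Run the actual expert computation."""

    def run_moe(self, module, x, output_dtype) -> torch.Tensor:
        # Same two-step pipeline as MoEOp.run_moe: tune first, then compute.
        self.finalize_tactic(module, x, output_dtype)
        return self.compute_moe(module, x, output_dtype)


class EchoOp(_MiniMoEOp):

    def finalize_tactic(self, module, tuner_input, output_dtype) -> None:
        pass  # DeepGemm-style: nothing to profile in this toy

    def compute_moe(self, module, x, output_dtype) -> torch.Tensor:
        return x.to(output_dtype)  # placeholder for the fused MoE kernel


out = EchoOp().run_moe(module=None, x=torch.ones(2, 4), output_dtype=torch.float16)
print(out.dtype)  # torch.float16
```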
