138 changes: 87 additions & 51 deletions csrc/trtllm_fused_moe_kernel_launcher.cu

Large diffs are not rendered by default.

418 changes: 253 additions & 165 deletions csrc/trtllm_fused_moe_routing_deepseek.cu

Large diffs are not rendered by default.

227 changes: 143 additions & 84 deletions csrc/trtllm_fused_moe_routing_llama4.cu

Large diffs are not rendered by default.

322 changes: 239 additions & 83 deletions csrc/trtllm_fused_moe_routing_renormalize.cu

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions csrc/trtllm_fused_moe_runner.cu
@@ -70,12 +70,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   routingData.mUsePdl = true;

   // output:
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;

   routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
   routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
@@ -107,12 +107,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   routingData.mUsePdl = true;

   // output:
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;

   routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
   routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
@@ -149,12 +149,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   //
   // Outputs
   //
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;

   //
   // Grouped Gemm Launch Config Buffers
53 changes: 29 additions & 24 deletions flashinfer/fused_moe/core.py
@@ -171,7 +171,9 @@ def _maybe_get_cached_w3_w1_permute_indices(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w3_w1_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    if cache_key not in _cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
         if num_elts_per_sf is None:
@@ -185,10 +187,10 @@ def _maybe_get_cached_w3_w1_permute_indices(
             num_elts_per_sf=num_elts_per_sf,
         )
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to(
+        _cache_permute_indices[cache_key] = permute0[permute1].to(
             dst_w3_w1_weight.device
         )
-    permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape]
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices


@@ -198,7 +200,9 @@ def get_w2_permute_indices_with_cache(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w2_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w2", dst_w2_weight.shape)
+    if cache_key not in _cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = get_shuffle_matrix_a_row_indices(
                 dst_w2_weight, epilogue_tile_m
@@ -210,8 +214,8 @@ def get_w2_permute_indices_with_cache(
                 num_elts_per_sf=num_elts_per_sf,
             ).to(dst_w2_weight.device)
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w2_weight.shape] = permute_indices
-    permute_indices = _cache_permute_indices[dst_w2_weight.shape]
+        _cache_permute_indices[cache_key] = permute_indices
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
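
For context, a minimal standalone sketch of the memoization pattern both permute-index helpers now share; the names mirror the diff, but the toy shapes and the stand-in computation below are illustrative only:

import torch

# Module-level memo table, keyed by (weight_type, weight_shape) as in the diff,
# so w3_w1 and w2 weights with identical shapes no longer collide on one entry.
_cache_permute_indices: dict = {}

def _get_permute_indices(weight: torch.Tensor, weight_type: str) -> torch.Tensor:
    cache_key = (weight_type, weight.shape)
    if cache_key not in _cache_permute_indices:
        # Stand-in for the costly index computation
        # (get_reorder_rows_* / get_shuffle_matrix_* in the PR).
        _cache_permute_indices[cache_key] = torch.randperm(weight.shape[0]).to(weight.device)
    return _cache_permute_indices[cache_key]

w = torch.empty(256, 128)
# Same shape, different weight type: two distinct cache entries now.
assert _get_permute_indices(w, "w3_w1") is not _get_permute_indices(w, "w2")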


@@ -1073,12 +1077,12 @@ def trtllm_fp8_per_tensor_scale_moe_op(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1127,12 +1131,12 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1159,12 +1163,12 @@ def trtllm_fp8_block_scale_moe_op(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int,
     routing_method_type: int,
     use_shuffled_weight: bool = False,
@@ -1173,6 +1177,7 @@
 ) -> torch.Tensor:
     if enable_pdl is None:
         enable_pdl = device_support_pdl(hidden_states.device)
+
     # Call the C++ function for block scale MoE
     moe_op.trtllm_fp8_block_scale_moe(
         routing_logits,
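
The enable_pdl handling above treats None as "auto-detect". A minimal sketch of that tri-state idiom; the capability probe shown here is an assumption (PDL, i.e. programmatic dependent launch, requiring SM90+), not the library's actual implementation:

from typing import Optional
import torch

def device_support_pdl(device: torch.device) -> bool:
    # Assumed probe: PDL requires Hopper-class (SM90+) GPUs.
    if device.type != "cuda":
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9

def resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    # None means "auto": defer to the device capability, as in the hunk above.
    return device_support_pdl(device) if enable_pdl is None else enable_pdl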
@@ -1214,12 +1219,12 @@ def _fake_trtllm_fp8_block_scale_moe(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
@@ -1479,12 +1484,12 @@ def trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1552,12 +1557,12 @@ def trtllm_fp8_block_scale_moe(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
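
These signature changes relax n_group, topk_group, and routed_scaling_factor to Optional, so routing methods without expert grouping or extra scaling can pass None instead of dummy values. A hedged sketch of the kind of guard a caller or wrapper might apply; this helper and its values are hypothetical, not part of the PR:

from typing import Optional

def check_group_args(n_group: Optional[int], topk_group: Optional[int]) -> None:
    # Hypothetical guard: grouped routing needs both parameters, non-grouped neither.
    if (n_group is None) != (topk_group is None):
        raise ValueError("n_group and topk_group must be provided together")
    if n_group is not None and not (0 < topk_group <= n_group):
        raise ValueError("topk_group must be in (0, n_group]")

check_group_args(n_group=None, topk_group=None)  # e.g. renormalize-style routing
check_group_args(n_group=8, topk_group=4)        # grouped (DeepSeek-style) routing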