Commit ed83138

fix wrong size calculation
Signed-off-by: jiahanc <[email protected]>
1 parent 6d59e6f commit ed83138

5 files changed: +32 -35 lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 0 additions & 1 deletion
@@ -169,7 +169,6 @@ void TrtllmGenBatchedGemmRunner::run(
   auto const configs = bmm.getBatchedGemmConfigs();
 
   auto const& config = configs[configIndex];
-  std::cout << "config.mFunctionName: " << config.mFunctionName << std::endl;
   FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
   if (!mOptions.staticBatch) {
     FLASHINFER_CHECK(totalNumPaddedTokens,

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 24 additions & 26 deletions
@@ -62,8 +62,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
 
   if (use_routing_scales_on_input) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
-  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-             RoutingMethodType::DeepSeekV3) {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
   } else {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
@@ -99,8 +98,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
              RoutingMethodType::RenormalizeNaive) {
     TVM_FFI_LOG_AND_THROW(NotImplementedError)
         << "Don't support routing method type Renormalize(Naive).";
-  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-             RoutingMethodType::Llama4) {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TVM_FFI_ICHECK_EQ(top_k, 1)
         << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
@@ -144,7 +142,8 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
+  args.routed_scaling_factor =
+      routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseRoutingScalesOnInput = use_routing_scales_on_input;
 
@@ -300,9 +299,10 @@ void trtllm_fp8_per_tensor_scale_moe(
     TensorView gemm1_weights, TensorView output1_scales_scalar,
     TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
     TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
-    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
-    int64_t local_num_experts, Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
-    int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
+    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
+    int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
+    bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type,
+    bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
     trtllm_fp8_per_tensor_scale_moe_launcher(
@@ -320,10 +320,11 @@ void trtllm_fp8_block_scale_moe_launcher(
     TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
     TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
     TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
-    int64_t const num_experts, int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
-    int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
-    int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    int64_t const num_experts, int64_t const top_k, Optional<int64_t> const n_group,
+    Optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    Optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
+    int64_t const routing_method_type,
     tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
     bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
@@ -380,8 +381,7 @@ void trtllm_fp8_block_scale_moe_launcher(
              RoutingMethodType::RenormalizeNaive) {
     TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
         << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0.";
-  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-             RoutingMethodType::Llama4) {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TVM_FFI_ICHECK_EQ(top_k, 1)
         << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
@@ -424,7 +424,8 @@ void trtllm_fp8_block_scale_moe_launcher(
   args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
+  args.routed_scaling_factor =
+      routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseDeepSeekFp8 = true;
 
@@ -610,11 +611,11 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView>
                                 TensorView gemm1_weights, TensorView gemm1_weights_scale,
                                 TensorView gemm2_weights, TensorView gemm2_weights_scale,
                                 TensorView output, int64_t num_experts, int64_t top_k,
-                                Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
-                                int64_t local_expert_offset, int64_t local_num_experts,
-                                Optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
-                                int64_t routing_method_type, bool use_shuffled_weight,
-                                int64_t weight_layout, bool enable_pdl) {
+                                Optional<int64_t> n_group, Optional<int64_t> topk_group,
+                                int64_t intermediate_size, int64_t local_expert_offset,
+                                int64_t local_num_experts, Optional<double> routed_scaling_factor,
+                                int64_t tile_tokens_dim, int64_t routing_method_type,
+                                bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
     using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
@@ -829,11 +830,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   //     {args.num_tokens, args.top_k}, routing_bias_dtype, hidden_states->device);
   // Tensor expert_indexes = alloc_tensor(
   //     {args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
-  int constexpr MAX_NUM_EXPERTS = 384;
-  Tensor expert_count_histogram = alloc_tensor(
-      {2 * MAX_NUM_EXPERTS},
-      dl_int32,  // 256 is the max number of threads per block and max number of experts
-      hidden_states->device);
+  int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
+  Tensor expert_count_histogram =
+      alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states->device);
 
   auto const sf_vec_size = dtype_weights == btg::Dtype::MxE2m1 ? 32 : 16;
 
@@ -1035,7 +1034,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   workspace.gemm1_output_scale = gemm1_output_scale.has_value()
                                      ? static_cast<float*>(gemm1_output_scale.value()->data)
                                      : nullptr;
-
   // gemm2 intermediate ws
   workspace.gemm2_output = gemm2_output->data;
   workspace.gemm2_output_scale = nullptr;
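
The expert_count_histogram hunk above is the size fix named in the commit message: the buffer was previously sized from a hard-coded MAX_NUM_EXPERTS = 384, independent of the runtime num_experts, and is now sized as max(num_experts * 2, 256 * 2). A minimal standalone sketch of that sizing rule (the helper name and the sample expert counts below are illustrative, not taken from the source):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Sketch of the corrected sizing rule from the hunk above: two int32 counters
// per expert, with a floor of 2 * 256 entries (the removed comment cites 256
// as the max number of threads per block).
int64_t expert_count_histogram_size(int64_t num_experts) {
  return std::max(num_experts * 2, int64_t(256 * 2));
}

int main() {
  // The old code allocated a fixed 2 * 384 = 768 entries regardless of num_experts.
  assert(expert_count_histogram_size(128) == 512);   // floor applies
  assert(expert_count_histogram_size(384) == 768);   // matches the old fixed size
  assert(expert_count_histogram_size(512) == 1024);  // exceeds the old fixed 768
  return 0;
}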

include/flashinfer/trtllm/fused_moe/DevKernel.h

Lines changed: 3 additions & 3 deletions
@@ -122,7 +122,7 @@ namespace moe::dev {
         LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \
         numBlocks, numThreads, smemSize, stream); \
   } else { \
-    FLASHINFER_WARN("Unsupported dtypeExpW"); \
+    FLASHINFER_WARN("Unsupported dtypeExpW"); \
   }
 
 #define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
@@ -147,7 +147,7 @@ namespace moe::dev {
     LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
                kernel, numBlocks, numThreads, smemSize, stream); \
   } else { \
-    FLASHINFER_WARN("Unsupported dtypeExpW"); \
+    FLASHINFER_WARN("Unsupported dtypeExpW"); \
   }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -167,7 +167,7 @@ namespace moe::dev {
     LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
                kernel, numBlocks, numThreads, smemSize, stream); \
   } else { \
-    FLASHINFER_WARN("Unsupported dtypeExpW"); \
+    FLASHINFER_WARN("Unsupported dtypeExpW"); \
   }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -137,11 +137,11 @@ def is_cuda_oom_error_str(e: str) -> bool:
     return "CUDA" in e and "out of memory" in e
 
 
-@pytest.hookimpl(hookwrapper=True)
+@pytest.hookimpl(tryfirst=True)
 def pytest_runtest_call(item):
     # Wrap the test call so we don't invoke item.runtest() ourselves; yield lets pytest run it.
     try:
-        yield
+        item.runtest()
     except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
         if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)):
             pytest.skip("Skipping due to OOM")

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -1838,7 +1838,7 @@ def cache_permute_indices():
 
 @pytest.mark.parametrize("num_tokens", [1, 8, 1024])
 @pytest.mark.parametrize("hidden_size", [1024, 8192])
-@pytest.mark.parametrize("intermediate_size", [ 1024, 768, 384, 512])
+@pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 512])
 @pytest.mark.parametrize(
     "moe_impl",
     [
@@ -1913,7 +1913,7 @@ def cache_permute_indices():
             "routed_scaling": None,
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.Renormalize,
-            "compatible_moe_impls": [FP8PerTensorMoe, FP8BlockScaleMoe, FP4Moe],
+            "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe],
         },
         id="Renorm",
         # marks=pytest.mark.skip(
@@ -2085,7 +2085,7 @@ def test_moe_quantization_classes(
         )
         else 64,
     )
-
+    padding = tile_tokens_dim
     # Validation checks
     assert top_k <= num_experts
     # assert top_k <= 8
