@@ -62,8 +62,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(

   if (use_routing_scales_on_input) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
-  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-             RoutingMethodType::DeepSeekV3) {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
   } else {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
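For context on the check being reflowed here: the launcher expects a different routing-logits dtype depending on the routing method. A minimal standalone sketch of that dispatch follows; the helper name, `LogitsDtype`, and all enum members other than `RoutingMethodType::DeepSeekV3` and `RoutingMethodType::Llama4` are assumptions for illustration, not the library's API.

```cpp
#include <cstdint>

// Hypothetical mirror of the dtype checks above; enum contents beyond the
// members named in the diff are assumed.
enum class RoutingMethodType : int64_t { Default, Renormalize, RenormalizeNaive, DeepSeekV3, Llama4 };
enum class LogitsDtype { BF16, FP32 };

LogitsDtype expected_routing_logits_dtype(bool use_routing_scales_on_input,
                                          int64_t routing_method_type) {
  if (use_routing_scales_on_input) return LogitsDtype::BF16;  // scales applied on the input path
  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3)
    return LogitsDtype::FP32;  // DeepSeekV3 routes on float32 scores
  return LogitsDtype::BF16;    // every other method takes bfloat16 logits
}
```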
@@ -99,8 +98,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
              RoutingMethodType::RenormalizeNaive) {
     TVM_FFI_LOG_AND_THROW(NotImplementedError)
         << "Don't support routing method type Renormalize(Naive).";
-  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-             RoutingMethodType::Llama4) {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TVM_FFI_ICHECK_EQ(top_k, 1)
         << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
@@ -144,7 +142,8 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
+  args.routed_scaling_factor =
+      routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseRoutingScalesOnInput = use_routing_scales_on_input;

@@ -300,9 +299,10 @@ void trtllm_fp8_per_tensor_scale_moe(
     TensorView gemm1_weights, TensorView output1_scales_scalar,
     TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
     TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
-    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
-    int64_t local_num_experts, Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
-    int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
+    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
+    int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
+    bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type,
+    bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
     trtllm_fp8_per_tensor_scale_moe_launcher(
@@ -320,10 +320,11 @@ void trtllm_fp8_block_scale_moe_launcher(
     TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
     TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
     TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
-    int64_t const num_experts, int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
-    int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
-    int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    int64_t const num_experts, int64_t const top_k, Optional<int64_t> const n_group,
+    Optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    Optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
+    int64_t const routing_method_type,
     tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
     bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
329330 static const std::tuple<int , int > device_props = [hidden_states] {
@@ -380,8 +381,7 @@ void trtllm_fp8_block_scale_moe_launcher(
              RoutingMethodType::RenormalizeNaive) {
     TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
         << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0.";
-  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-             RoutingMethodType::Llama4) {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TVM_FFI_ICHECK_EQ(top_k, 1)
         << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
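The constraints reflowed in this hunk mirror those in the per-tensor launcher above: renormalize-style routing accepts 0 < top_k <= 10 and Llama4 routing requires top_k == 1. A standalone sketch of that validation, with the helper name, error type, and enum layout chosen here purely for illustration:

```cpp
#include <cstdint>
#include <stdexcept>

// Assumed enum; only the members named in the diff are confirmed.
enum class RoutingMethodType : int64_t { Renormalize, RenormalizeNaive, DeepSeekV3, Llama4 };

// Hypothetical helper restating the top_k checks from the block-scale launcher.
void validate_top_k(RoutingMethodType method, int64_t top_k) {
  switch (method) {
    case RoutingMethodType::Renormalize:
    case RoutingMethodType::RenormalizeNaive:
      if (top_k <= 0 || top_k > 10)
        throw std::invalid_argument("renormalize routing supports 0 < top_k <= 10");
      break;
    case RoutingMethodType::Llama4:
      if (top_k != 1)
        throw std::invalid_argument("Llama4 routing supports top_k=1 only");
      break;
    default:
      break;  // other methods impose their constraints elsewhere
  }
}
```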
@@ -424,7 +424,8 @@ void trtllm_fp8_block_scale_moe_launcher(
   args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
+  args.routed_scaling_factor =
+      routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseDeepSeekFp8 = true;

@@ -610,11 +611,11 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView>
                                 TensorView gemm1_weights, TensorView gemm1_weights_scale,
                                 TensorView gemm2_weights, TensorView gemm2_weights_scale,
                                 TensorView output, int64_t num_experts, int64_t top_k,
-                                Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
-                                int64_t local_expert_offset, int64_t local_num_experts,
-                                Optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
-                                int64_t routing_method_type, bool use_shuffled_weight,
-                                int64_t weight_layout, bool enable_pdl) {
+                                Optional<int64_t> n_group, Optional<int64_t> topk_group,
+                                int64_t intermediate_size, int64_t local_expert_offset,
+                                int64_t local_num_experts, Optional<double> routed_scaling_factor,
+                                int64_t tile_tokens_dim, int64_t routing_method_type,
+                                bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
     using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
@@ -829,11 +830,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   //     {args.num_tokens, args.top_k}, routing_bias_dtype, hidden_states->device);
   // Tensor expert_indexes = alloc_tensor(
   //     {args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
-  int constexpr MAX_NUM_EXPERTS = 384;
-  Tensor expert_count_histogram = alloc_tensor(
-      {2 * MAX_NUM_EXPERTS},
-      dl_int32,  // 256 is the max number of threads per block and max number of experts
-      hidden_states->device);
+  int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
+  Tensor expert_count_histogram =
+      alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states->device);

   auto const sf_vec_size = dtype_weights == btg::Dtype::MxE2m1 ? 32 : 16;
@@ -1035,7 +1034,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   workspace.gemm1_output_scale = gemm1_output_scale.has_value()
                                      ? static_cast<float*>(gemm1_output_scale.value()->data)
                                      : nullptr;
-
   // gemm2 intermediate ws
   workspace.gemm2_output = gemm2_output->data;
   workspace.gemm2_output_scale = nullptr;