Skip to content

Commit 2b51947

Browse files
committed
[BugFix] Retain original dtype for topk_weights
Signed-off-by: huanghaoyan.hhy <[email protected]>
1 parent ff33895 commit 2b51947

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def flashinfer_alltoall_dispatch(
230230
max_num_token = (
231231
max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
232232
)
233+
topk_weights_dtype = topk_weights.dtype
233234
alltoall_info, topk_ids, topk_weights, _ = (
234235
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
235236
topk_ids,
@@ -244,6 +245,9 @@ def flashinfer_alltoall_dispatch(
244245
top_k,
245246
)
246247
)
248+
# NOTE: Restore original dtype, as FlashInfer casts topk_weights
249+
# to int32. Can be removed after the bug is fixed.
250+
topk_weights = topk_weights.view(topk_weights_dtype)
247251

248252
x, x_sf = moe_kernel_quantize_input(
249253
x,

0 commit comments

Comments (0)