
@peakcrosser7 peakcrosser7 commented Oct 31, 2025

Purpose

This PR fixes the bug reported in issue #27655.

There are two main changes:

  1. Fix AttributeError with the flashinfer_all2allv backend:
    When using the flashinfer_all2allv backend, an AttributeError: "'FlashInferAllToAllManager' object has no attribute 'prepare_workspace'" is raised. The cause is a call to prepare_workspace instead of the correct prepare_workspace_tensor method on FlashInferAllToAllManager. This PR corrects the method name.

  2. Retain original dtype for topk_weights:
    Additionally, there appears to be a potential bug in FlashInfer where calling MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather() incorrectly changes the data type of topk_weights from torch.float32 to torch.int32. This is likely related to the prepared_local_scales tensor being allocated as torch.int32 within FlashInfer's moe_prepare() function.
    As a temporary workaround, this PR saves the original dtype of topk_weights before the call and restores it immediately afterward, preserving the tensor's contents for subsequent operations (a runnable sketch of this pattern follows the list below).
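
As a rough illustration of the save/restore pattern, here is a runnable simulation (not the PR diff): fake_prepare is a hypothetical stand-in for MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(), and the tensor values are arbitrary.

import torch

def fake_prepare(topk_weights: torch.Tensor) -> torch.Tensor:
    # Stand-in for the FlashInfer prepare call: the original float32 bits
    # come back wearing an int32 dtype label.
    return topk_weights.view(torch.int32)

topk_weights = torch.tensor([0.25, 0.5, 1.0], dtype=torch.float32)

topk_weights_dtype = topk_weights.dtype               # save the original dtype
topk_weights = fake_prepare(topk_weights)             # dtype is now torch.int32
topk_weights = topk_weights.view(topk_weights_dtype)  # reinterpret the bits back

print(topk_weights.dtype, topk_weights)  # torch.float32 tensor([0.2500, 0.5000, 1.0000])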

Test Plan

#!/bin/bash

PORT=8235
NWORKERS=8
USE_DP=1
MAX_MODEL_LEN=40960
DO_NSYS=0


DP=1
TP=$NWORKERS
if (( USE_DP == 1 )); then
    DP=$NWORKERS
    TP=1
fi

MODEL_DIR=/root/workspace/models/Qwen3-235B-A22B-Instruct-2507-NVFP4

echo "MODEL_DIR: $MODEL_DIR"

env_vars=(
    # For resolving bugs
    "NCCL_SOCKET_IFNAME=eth0"
    "GLOO_SOCKET_IFNAME=eth0"
    "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False"
    # For vLLM
    "VLLM_USE_FLASHINFER_MOE_FP4=1"
    "VLLM_FLASHINFER_MOE_BACKEND=throughput"
    "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB={\"2\":32,\"4\":32,\"8\":8}"
)

for var in "${env_vars[@]}"; do
    var_name="${var%%=*}"
    var_value="${var#*=}"
    echo -e "\t$var_name=$var_value"
done

CMD=( env )
for var in "${env_vars[@]}"; do
    CMD+=( "$var" )
done
CMD+=(
    vllm serve
    $MODEL_DIR
    --port "$PORT"
    --gpu-memory-utilization 0.9
    -dp $DP
    -tp $TP
    --enable-expert-parallel
    --no-enable-prefix-caching
    --enable-chunked-prefill
    --all2all-backend flashinfer_all2allv
    --max-num-seqs 1024
    --kv-cache-dtype fp8
    --async-scheduling
    --max-num-batched-tokens 8192
    --max-model-len $MAX_MODEL_LEN
    --compilation-config "{\"pass_config\":{\"enable_fi_allreduce_fusion\":true,\"enable_attn_fusion\":true,\"enable_noop\":true},\"custom_ops\":[\"+quant_fp8\",\"+rms_norm\"],\"cudagraph_mode\":\"FULL_DECODE_ONLY\",\"splitting_ops\":[]}"
)

echo -e "\nExecuting command:"
printf " %s" "${CMD[@]}"
echo -e "\n"

"${CMD[@]}"

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they run only the fastcheck CI, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of the fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two important fixes. The first correctly resolves an AttributeError by changing a method call from prepare_workspace to prepare_workspace_tensor, which is a direct and necessary bug fix. The second change addresses a suspected bug in the FlashInfer library where the topk_weights tensor's data type is incorrectly modified. While the intent to restore the original data type is correct, the implementation uses torch.view(), which is unsafe and likely incorrect. I've provided a critical comment to change this to torch.to() to ensure data correctness.

top_k,
)
)
topk_weights = topk_weights.view(topk_weights_dtype)
@gemini-code-assist (Contributor)


critical

Using torch.view() to change the data type is unsafe and likely incorrect. view() reinterprets the underlying bits of the tensor, which means an int32 tensor's data will be interpreted as garbage float32 values (unless the value is zero).

If the FlashInfer function incorrectly returns a tensor with an int32 dtype, the correct way to restore the original floating-point type is to perform a cast using torch.to(). This will convert the integer values back to floating-point values (e.g., 1 becomes 1.0), which is the safe and intended operation.

This is a critical issue as it could lead to silent correctness problems in the model's computations.

Suggested change:
- topk_weights = topk_weights.view(topk_weights_dtype)
+ topk_weights = topk_weights.to(topk_weights_dtype)

@peakcrosser7 (Author)


view() is required here: FlashInfer only mislabels the dtype, so the buffer still holds the original float32 bits. Reinterpreting them with view() restores the values exactly, whereas to() would numerically cast the mislabeled integers.
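
As an illustrative aside (not part of the PR), the difference between the two calls once the dtype label has been swapped, with values chosen only for demonstration:

import torch

# 1.0 as float32 has the bit pattern 0x3F800000 == 1065353216.
w = torch.tensor([1.0, 0.5], dtype=torch.float32)
relabeled = w.view(torch.int32)        # same buffer, dtype label now int32

print(relabeled.view(torch.float32))   # tensor([1.0000, 0.5000]) -- exact bits back
print(relabeled.to(torch.float32))     # ~tensor([1.0654e+09, 1.0570e+09]) -- numeric cast of the raw ints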


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

top_k,
)
)
topk_weights = topk_weights.view(topk_weights_dtype)


P0: Restore topk_weights dtype via invalid Tensor.view call

The new code saves topk_weights.dtype and then calls topk_weights = topk_weights.view(topk_weights_dtype) to restore it. In PyTorch Tensor.view only accepts shape arguments; passing a dtype raises TypeError: 'torch.dtype' object cannot be interpreted as an integer, so any invocation of flashinfer_alltoall_dispatch now crashes before communication. If the intent is to cast back to the previous dtype, use .to(topk_weights_dtype) (or equivalent) instead of view.

Useful? React with 👍 / 👎.

@peakcrosser7 (Author)


Tensor.view() can accept torch.dtype arguments: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.view.html
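
For reference, a quick standalone check (not from the PR) that the dtype overload of Tensor.view() is accepted; 1065353216 is the IEEE-754 bit pattern of 1.0:

import torch

t = torch.ones(2, dtype=torch.float32)
print(t.view(torch.int32))                      # tensor([1065353216, 1065353216], dtype=torch.int32)
print(t.view(torch.int32).view(torch.float32))  # tensor([1., 1.]) -- round-trips exactly, no TypeError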

@peakcrosser7 peakcrosser7 force-pushed the fix/flashinfer_all2all branch 2 times, most recently from e6f8265 to 2b51947 on October 31, 2025 06:56
@peakcrosser7 peakcrosser7 changed the title Fixed 'FlashInferAllToAllManager' object has no attribute 'prepare_workspace' [BugFix] Fixed 'FlashInferAllToAllManager' object has no attribute 'prepare_workspace' Oct 31, 2025
@peakcrosser7 peakcrosser7 force-pushed the fix/flashinfer_all2all branch from 2b51947 to e5193b0 on October 31, 2025 10:15
@peakcrosser7 (Author)

Hi @mgoin @pavanimajety,

This PR is ready for review. Could you please take a look when you have a moment?

Thanks!
