
Conversation

@bnellnm
Collaborator

@bnellnm bnellnm commented Oct 17, 2025

Purpose

Make a new FusedMoEModularMethod subclass of FusedMoEMethodBase for use with modular kernels.

Instead of having every subclass of FusedMoEMethodBase check self.fused_experts, we swap out the quant_method of the FusedMoE layer to an instance of FusedMoEModularMethod. This will reduce the complexity of the various FusedMoEMethodBase subclass apply methods and isolate uses of modular kernels to the new class.
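For illustration, the swap could be sketched roughly as below. Only the class names and the quant_method-swap idea come from this PR; the class bodies and the toy arithmetic are invented stand-ins, not the actual vLLM implementation:

```python
# Illustrative sketch only -- not the actual vLLM implementation.
class FusedMoEMethodBase:
    def apply(self, layer, x):
        raise NotImplementedError


class Fp8MoEMethod(FusedMoEMethodBase):
    def apply(self, layer, x):
        # Plain quantized path; no per-subclass modular-kernel checks needed.
        return x * 2  # stand-in for the real fused-experts computation


class FusedMoEModularMethod(FusedMoEMethodBase):
    def __init__(self, old_method, fused_experts):
        self.old_method = old_method
        self.fused_experts = fused_experts

    def apply(self, layer, x):
        # Delegate to the modular kernel instead of the original method.
        return self.fused_experts(x)


class FusedMoE:
    def __init__(self, quant_method):
        self.quant_method = quant_method

    def init_prepare_finalize(self):
        kernel = lambda x: x * 3  # stand-in for a FusedMoEModularKernel
        # The key idea: swap quant_method rather than branching in every apply().
        self.quant_method = FusedMoEModularMethod(self.quant_method, kernel)

    def forward(self, x):
        return self.quant_method.apply(self, x)


layer = FusedMoE(Fp8MoEMethod())
print(layer.forward(5))   # 10: original quant method
layer.init_prepare_finalize()
print(layer.forward(5))   # 15: modular-kernel path, same call site
```

Note that the forward path never changes; only the object behind self.quant_method does.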

Test Plan

Ran by hand on some fp8 + modelopt models.
CI tests

Test Result

cc @varun-sundar-rabindranath , @wenscarl

@mergify

mergify bot commented Oct 17, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 17, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

    assert self.quant_method is not None
    return (
        self.quant_method.fused_experts is not None
        and self.quant_method.fused_experts.output_is_reduced()
    )
P0: Accessing missing fused_experts attribute

The commit removes fused_experts from FusedMoEMethodBase, but FusedMoE still unconditionally accesses self.quant_method.fused_experts here (and again later when staging tokens). When the quant method does not use a modular kernel (e.g. AWQ, BitsAndBytes, RTN), init_prepare_finalize now leaves the original quant method in place, and that method no longer defines a fused_experts attribute. These checks will therefore raise AttributeError before any routing happens. The guard should use hasattr or using_modular_kernel instead of dereferencing the attribute directly.
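A minimal sketch of the guard this review suggests might look like the following. The function and class names here are hypothetical stand-ins used only to show the getattr pattern, not the actual vLLM code:

```python
# Hypothetical sketch of the suggested guard -- not the actual vLLM code.
class QuantMethodWithoutKernel:
    """Models e.g. AWQ/BitsAndBytes/RTN: no fused_experts attribute at all."""
    pass


class QuantMethodWithKernel:
    """Models a quant method that was swapped for a modular-kernel wrapper."""
    class _Kernel:
        def output_is_reduced(self):
            return True

    fused_experts = _Kernel()


def must_reduce_shared_expert_outputs(quant_method):
    # getattr with a default avoids AttributeError when the quant method
    # was never swapped out and so never gained a fused_experts attribute.
    fused_experts = getattr(quant_method, "fused_experts", None)
    return fused_experts is not None and fused_experts.output_is_reduced()


print(must_reduce_shared_expert_outputs(QuantMethodWithoutKernel()))  # False
print(must_reduce_shared_expert_outputs(QuantMethodWithKernel()))     # True
```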


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the handling of modular kernels for Fused MoE layers by introducing a FusedMoEModularMethod wrapper. This is a good simplification that centralizes logic. However, I've identified two critical issues that could lead to runtime errors. One is related to an incorrect condition for EPLB support in the FP8 quantization method, and the other is an incorrect API usage for submodule replacement. I have provided detailed comments and suggestions to address these issues.

@mergify mergify bot removed the needs-rebase label Oct 18, 2025
@bnellnm bnellnm changed the title [Kernels] Swap quant method [Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. Oct 20, 2025
@mergify

mergify bot commented Oct 23, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@varun-sundar-rabindranath
Contributor

Thanks @bnellnm . This cleans up a bunch of redundant code 🙌 .

I have a suggestion. IIUC, the function call chain for the construction of FusedMoEModularMethod looks something like the following:

1. `DeviceCommunicatorBase::prepare_communication_buffer_for_model()` calls  `FusedMoE::init_prepare_finalize()` 
2. `FusedMoE::init_prepare_finalize()` calls `FusedMoEMethodBase::init_prepare_finalize()` and it returns the `FusedMoEModularKernel` object
3. `FusedMoE::init_prepare_finalize()` then makes a `FusedMoEModularMethod` object and overrides its `self.quant_method`

Here, note that FusedMoEMethodBase::init_prepare_finalize() calls FusedMoEMethodBase::maybe_make_prepare_finalize(), which in turn calls the static function FusedMoEMethodBase::_maybe_make_prepare_finalize(), which does most of the work anyway.

My suggestion is to move FusedMoEModularMethod into its own file and expose a function, say maybe_make_fused_moe_modular_method(), that attempts to construct the FusedMoEModularMethod object.

That way, we can get rid of most of the ModularKernel-specific code in fused_moe/layer.py and move it to a different file, cleaning up fused_moe/layer.py greatly.

What do you think ?

I am not suggesting we do it in this PR. I can take it up as well 👍
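The proposed factory might look roughly like this sketch. Apart from the maybe_make_fused_moe_modular_method and FusedMoEModularMethod names suggested above, everything here (signatures, the boolean flag, the placeholder prepare/finalize object) is a hypothetical stand-in:

```python
# Rough sketch of the suggested factory function in its own module -- hypothetical.
class FusedMoEModularMethod:
    """Wrapper whose apply() would dispatch to the modular kernel."""

    def __init__(self, old_quant_method, prepare_finalize):
        self.old_quant_method = old_quant_method
        self.prepare_finalize = prepare_finalize


def maybe_make_fused_moe_modular_method(quant_method, use_modular_kernel):
    # Mirrors the suggestion: try to build the prepare/finalize object and,
    # only if that succeeds, wrap the existing quant method.
    prepare_finalize = object() if use_modular_kernel else None
    if prepare_finalize is None:
        return None  # caller leaves layer.quant_method untouched
    return FusedMoEModularMethod(quant_method, prepare_finalize)


assert maybe_make_fused_moe_modular_method("awq_method", False) is None
m = maybe_make_fused_moe_modular_method("fp8_method", True)
print(type(m).__name__)  # FusedMoEModularMethod
```

With such a factory, fused_moe/layer.py would only need a single "maybe swap" call instead of hosting the construction logic itself.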

@bnellnm
Collaborator Author

bnellnm commented Oct 24, 2025

Yeah, that's a good idea. I was also considering splitting up layer.py in different ways, e.g. move UnquantizedMoEMethod to a separate file.

I'd rather do that in a separate PR though.

        if layer.w2_weight is None
        else layer.w2_weight
    )
    assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])


I think this setting of layer.w13_weight and layer.w2_weight fits better in the process_weights_after_loading function here.

That way we can get rid of having to differentiate between w13_weight_triton_tensor/w2_weight_triton_tensor and w13_weight/w2_weight.

Not suggesting for this PR. Fixing that I think should be its own PR.

Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment


LGTM! Very nice cleanups! Thanks @bnellnm

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 30, 2025
Signed-off-by: Bill Nell <[email protected]>
@mgoin mgoin merged commit 938772a into vllm-project:main Nov 4, 2025
61 checks passed
@wangshangsam
Collaborator

wangshangsam commented Nov 6, 2025

@bnellnm I have a ... maybe dumb ... question - how exactly is each derived MoEMethod class going to trigger FusedMoEModularMethod.apply() (thereby using the modular kernels)? Doesn't each subclass override the .apply() completely?

@bnellnm
Collaborator Author

bnellnm commented Nov 6, 2025

The FusedMoE layer calls self.quant_method.apply, so if no modular kernel has been constructed, this invokes the apply method of some subclass of FusedMoEMethodBase. When a modular kernel gets created, the FusedMoE layer swaps out self.quant_method for an instance of FusedMoEModularMethod, which calls the modular kernel instead.

So, subclasses of FusedMoEMethodBase no longer need to worry about modifying apply for modular kernels.

juliendenize pushed a commit to juliendenize/vllm that referenced this pull request Nov 6, 2025
zWaNg3 added a commit to fangyuchu/vllm that referenced this pull request Nov 7, 2025
Co-authored-by: a798347923 <[email protected]>
Co-authored-by: w00689259 <[email protected]>
Co-authored-by: fangyuchu <[email protected]>
Co-authored-by: TianZhuo <[email protected]>
Co-authored-by: a798347923 <[email protected]>
Co-authored-by: Sungyoon Jeong <[email protected]>
Co-authored-by: Chauncey <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: Rémi Delacourt <[email protected]>
Co-authored-by: Benjamin Chislett <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>
Co-authored-by: Misha Efimov <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: zhang-prog <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: gnovack <[email protected]>
Co-authored-by: pwschuurman <[email protected]>
Co-authored-by: ahao-anyscale <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
Co-authored-by: Sophie du Couédic <[email protected]>
Co-authored-by: Lucas Kabela <[email protected]>
Co-authored-by: Matthew Bonanni <[email protected]>
Co-authored-by: Ning Xie <[email protected]>
Co-authored-by: Hank_ <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: QiliangCui <[email protected]>
Co-authored-by: vllmellm <[email protected]>
Co-authored-by: li2haipeng <[email protected]>
Co-authored-by: liuzhenwei <[email protected]>
Co-authored-by: Mark McLoughlin <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
Co-authored-by: xiangze-arm <[email protected]>
Co-authored-by: Zhewen Li <[email protected]>
Co-authored-by: CSWYF3634076 <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: yugong333 <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Co-authored-by: tomeras91 <[email protected]>
Co-authored-by: bnellnm <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: lyrisz <[email protected]>
Co-authored-by: Faqin Zhong <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Vadim Gimpelson <[email protected]>
Co-authored-by: yt0428 <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Pleaplusone <[email protected]>
Co-authored-by: nadavkluger <[email protected]>
Co-authored-by: Chenheli Hua <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: tou <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
Co-authored-by: Alex Brooks <[email protected]>
Co-authored-by: Qiu <[email protected]>
Co-authored-by: Kuntai Du <[email protected]>
Co-authored-by: Eric Yue <[email protected]>
Co-authored-by: amirkl94 <[email protected]>
Co-authored-by: Boyuan Feng <[email protected]>
Co-authored-by: Frost Mitchell <[email protected]>
Co-authored-by: bigmoyan <[email protected]>
Co-authored-by: wangzhengtao <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Ilya Markov <[email protected]>
Co-authored-by: Sage Moore <[email protected]>
Co-authored-by: Alexei-V-Ivanov-AMD <[email protected]>
Co-authored-by: wuhuikx <[email protected]>
Co-authored-by: Jiaju Zhang <[email protected]>
Co-authored-by: Jiangyun Zhu <[email protected]>
Co-authored-by: Walter Beller-Morales <[email protected]>
Co-authored-by: gmagogsfm <[email protected]>
Co-authored-by: Samuel Shen <[email protected]>
Co-authored-by: Samuel Shen <[email protected]>
Co-authored-by: Yihua Cheng <[email protected]>
Co-authored-by: Paul Zhang <[email protected]>
Co-authored-by: wang.yuqi <[email protected]>
Co-authored-by: Michael Yao <[email protected]>
Co-authored-by: R3hankhan <[email protected]>
Co-authored-by: Richard Zou <[email protected]>
Co-authored-by: Snehlata <[email protected]>
Co-authored-by: Dayeol Lee <[email protected]>
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
@bnellnm bnellnm deleted the swap-quant-method branch November 11, 2025 19:43
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)

Projects: None yet

5 participants