Commit d43ad5a

[BugFix] Fix DCP Assert (AssertionError: DCP not support reorder_batch_threshold > 1 now.) (#28100)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 0ff05e3 commit d43ad5a

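Context for the fix (not part of the commit itself): the AssertionError quoted in the title fires when decode context parallel (DCP) is enabled while a backend has widened reorder_batch_threshold above 1 (e.g. for spec decode). The snippet below is a minimal, hypothetical reproduction of that pre-existing check, written only to illustrate the failure mode; the standalone function name and signature are illustrative, not vLLM code.

# Hypothetical, standalone illustration of the assert quoted in the commit title.
def check_dcp_reorder_threshold(
    reorder_batch_threshold: int, decode_context_parallel_size: int
) -> None:
    if decode_context_parallel_size > 1:
        # Error message as quoted in the commit title.
        assert reorder_batch_threshold <= 1, (
            "DCP not support reorder_batch_threshold > 1 now."
        )

# Triggering the check the way an unpatched backend could: a spec-decode-widened
# threshold (> 1) combined with DCP (> 1) raises the AssertionError above.
try:
    check_dcp_reorder_threshold(reorder_batch_threshold=3, decode_context_parallel_size=2)
except AssertionError as exc:
    print(exc)  # -> DCP not support reorder_batch_threshold > 1 now.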
3 files changed: 18 additions & 3 deletions

vllm/v1/attention/backends/mla/common.py

Lines changed: 2 additions & 1 deletion

@@ -545,6 +545,7 @@ def __init__(
         vllm_config: VllmConfig,
         device: torch.device,
         metadata_cls: type[M] | None = None,
+        supports_dcp_with_varlen: bool = False,
     ):
         self.metadata_cls = (
            metadata_cls if metadata_cls is not None else MLACommonMetadata
@@ -638,7 +639,7 @@ def __init__(
 
         supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
         self._init_reorder_batch_threshold(
-            self.reorder_batch_threshold, supports_spec_decode
+            self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
         )
 
         # Validate consistency between query_len_support and reorder_batch_threshold

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 6 additions & 1 deletion

@@ -81,7 +81,12 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(
-            kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
+            kv_cache_spec,
+            layer_names,
+            vllm_config,
+            device,
+            FlashAttnMLAMetadata,
+            supports_dcp_with_varlen=True,
         )
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3

vllm/v1/attention/backends/utils.py

Lines changed: 10 additions & 1 deletion

@@ -264,7 +264,10 @@ def __init__(
         self.device = device
 
     def _init_reorder_batch_threshold(
-        self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
+        self,
+        reorder_batch_threshold: int = 1,
+        supports_spec_as_decode: bool = False,
+        supports_dcp_with_varlen: bool = False,
     ) -> None:
         self.reorder_batch_threshold = reorder_batch_threshold
         if self.reorder_batch_threshold is not None and supports_spec_as_decode:
@@ -281,6 +284,12 @@ def _init_reorder_batch_threshold(
                     1 + speculative_config.num_speculative_tokens,
                 )
 
+        if (
+            self.vllm_config.parallel_config.decode_context_parallel_size > 1
+            and not supports_dcp_with_varlen
+        ):
+            self.reorder_batch_threshold = 1
+
     @abstractmethod
     def build(
         self,

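Taken together, the three hunks thread a new supports_dcp_with_varlen flag from the backend builders into _init_reorder_batch_threshold, which now clamps the threshold back to 1 whenever DCP is enabled and the backend has not opted in (FlashAttn MLA opts in by passing supports_dcp_with_varlen=True). The sketch below condenses that guard into a standalone function under those assumptions; it is illustrative only, not the actual vLLM classes.

# Standalone sketch of the guard added in utils.py (simplified, illustrative).
def resolve_reorder_batch_threshold(
    reorder_batch_threshold: int,
    decode_context_parallel_size: int,
    supports_dcp_with_varlen: bool = False,
) -> int:
    # Backends that declare DCP-with-varlen support keep their (possibly
    # spec-decode-widened) threshold; all others fall back to 1 under DCP,
    # which avoids the "reorder_batch_threshold > 1" assertion.
    if decode_context_parallel_size > 1 and not supports_dcp_with_varlen:
        return 1
    return reorder_batch_threshold

assert resolve_reorder_batch_threshold(4, decode_context_parallel_size=2) == 1
assert resolve_reorder_batch_threshold(4, decode_context_parallel_size=2, supports_dcp_with_varlen=True) == 4
assert resolve_reorder_batch_threshold(4, decode_context_parallel_size=1) == 4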