
Commit 618d877

Merge pull request vllm-project#2 from luccafong/mtp_config_enablement
fix mtp config and padding
2 parents: 386f9ae + 4273a15

File tree

5 files changed: +12 -13 lines

    vllm/config/speculative.py
    vllm/model_executor/models/deepseek_mtp.py
    vllm/model_executor/models/deepseek_v2.py
    vllm/v1/attention/backends/mla/flashmla_sparse.py
    vllm/v1/attention/backends/mla/indexer.py

vllm/config/speculative.py

Lines changed: 3 additions & 4 deletions
@@ -142,7 +142,7 @@ def compute_hash(self) -> str:

     @staticmethod
     def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
-        if hf_config.model_type == "deepseek_v3":
+        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
             hf_config.model_type = "deepseek_mtp"
         if hf_config.model_type == "deepseek_mtp":
             n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
@@ -204,9 +204,8 @@ def __post_init__(self):
             # mtp acceleration for more models besides deepseek_v3
             if self.target_model_config and \
                 (self.target_model_config.hf_text_config.model_type \
-                    == "deepseek_v3" or
-                 self.target_model_config.hf_text_config.model_type in
-                 ("mimo","ernie4_5_moe", "qwen3_next")):
+                    in ("deepseek_v3", "deepseek_v32",
+                        "mimo","ernie4_5_moe", "qwen3_next")):
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
                 # Align the quantization of draft model for cases such as
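
Net effect of both hunks: deepseek_v32 targets now take the same MTP path as deepseek_v3, both for the HF config override and for defaulting the draft model to the target model itself. A minimal sketch of the override behaviour, using a plain namespace object in place of PretrainedConfig (names and values outside the diff are illustrative):

    from types import SimpleNamespace

    def hf_config_override(hf_config):
        # Mirror of the diff: V3 and V3.2 checkpoints are rewritten to the MTP head type.
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            # The real override goes on to use num_nextn_predict_layers here.
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
        return hf_config

    cfg = SimpleNamespace(model_type="deepseek_v32", num_nextn_predict_layers=1)
    print(hf_config_override(cfg).model_type)  # deepseek_mtp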

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
             config, "index_topk"
         )
         if self.is_v32:
-            topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"]
+            topk_tokens = config.index_topk
             topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens,
                                               topk_tokens,
                                               dtype=torch.int32,
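
Since the DeepSeek-V3.2 config exposes index_topk directly, the MTP module can size its shared top-k indices buffer from that field instead of reaching into attn_module_list_cfg. A rough sketch of the resulting allocation, with illustrative sizes standing in for the scheduler and HF config values:

    import torch

    max_num_batched_tokens = 8192   # stand-in for vllm_config.scheduler_config.max_num_batched_tokens
    index_topk = 2048               # stand-in for config.index_topk on a V3.2 checkpoint

    # One row of int32 top-k KV indices per batched token, as allocated in __init__.
    topk_indices_buffer = torch.empty(max_num_batched_tokens, index_topk, dtype=torch.int32)
    print(topk_indices_buffer.shape)  # torch.Size([8192, 2048])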

vllm/model_executor/models/deepseek_v2.py

Lines changed: 4 additions & 4 deletions
@@ -690,18 +690,18 @@ def sparse_attn_indexer(
             padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:])
         # TODO: move and optimize below logic with triton kernels
         batch_size = padded_q_fp8_decode_tokens.shape[0]
+        next_n = padded_q_fp8_decode_tokens.shape[1]
         assert batch_size == decode_metadata.seq_lens.shape[0]
+        num_padded_tokens = batch_size * next_n
         logits = fp8_paged_mqa_logits(
             padded_q_fp8_decode_tokens,
             kv_cache,
-            weights[:num_decode_tokens],
+            weights[:num_padded_tokens],
             decode_metadata.seq_lens,
             decode_metadata.block_table,
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
-        # [B, N, L]
-        next_n = padded_q_fp8_decode_tokens.shape[1]
         # padded query len
         current_device = padded_q_fp8_decode_tokens.device
         padded_num_tokens = batch_size * next_n
@@ -721,7 +721,7 @@ def sparse_attn_indexer(
         topk_indices[topk_indices > index_end_pos] = -1
         if decode_metadata.requires_padding:
             # if padded, we need to unpack the topk indices removing padded tokens
-            topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, logits.shape[-1]), decode_lens)
+            topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens)
         topk_indices_buffer[:num_decode_tokens, :topk_indices.
                             shape[-1]] = topk_indices.to(
                                 dtype=torch.int32)
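
Both fixes here concern padded decode batches: when padding is required, every request is padded up to next_n query positions, so the fp8_paged_mqa_logits kernel sees batch_size * next_n rows, which exceeds num_decode_tokens whenever a request has fewer tokens, and the weights slice has to match that padded count. Likewise, the unpack reshape has to use the top-k width of topk_indices itself rather than the last dimension of logits. A small shape-bookkeeping sketch with made-up sizes:

    import torch

    # Illustrative decode batch: 3 requests queueing 2, 1 and 2 tokens (next_n = 2).
    decode_lens = torch.tensor([2, 1, 2])
    batch_size = decode_lens.shape[0]            # 3
    next_n = int(decode_lens.max())              # 2, padded query len per request
    num_decode_tokens = int(decode_lens.sum())   # 5 real tokens
    num_padded_tokens = batch_size * next_n      # 6 rows actually fed to the kernel

    # weights[:num_decode_tokens] would drop the last padded row; the fix slices
    # to the padded count instead.
    weights = torch.randn(num_padded_tokens + 4)     # stand-in for the indexer weights
    assert weights[:num_padded_tokens].shape[0] == 6

    # Top-k width vs. logits width: the reshape must use the former.
    topk = 4
    logits = torch.randn(num_padded_tokens, 16)      # [padded tokens, KV positions]
    topk_indices = logits.topk(topk, dim=-1).indices
    print(topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]).shape)  # [3, 2, 4]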

vllm/v1/attention/backends/mla/flashmla_sparse.py

Lines changed: 2 additions & 3 deletions
@@ -273,8 +273,6 @@ class FlashMLASparseMetadataBuilder(
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH

-    reorder_batch_threshold: ClassVar[int] = 128  # TODO(lucas): tune this
-
     reorder_batch_threshold: ClassVar[int] = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -309,7 +307,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config.speculative_config.num_speculative_tokens
             if vllm_config.speculative_config else 0
         )
-        self.reorder_batch_threshold += self.num_speculative_tokens
+        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
+        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

         # Equation taken from FlashMLA/csrc/pybind.cpp
         h_q, h_k = self.num_heads, 1
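
This builder and the indexer builder below apply the same cap: since the DeepGEMM fp8_paged_mqa_logits path does not support next_n > 2 (per the added comment), the decode reorder threshold grows by at most one speculative token rather than all of them. The arithmetic, restated as a throwaway helper (not a vLLM API):

    def reorder_batch_threshold(num_speculative_tokens: int) -> int:
        # Base threshold of 1 query token per decode request, plus at most one
        # speculative token, so the padded decode path stays within next_n <= 2.
        return 1 + min(num_speculative_tokens, 1)

    assert [reorder_batch_threshold(k) for k in range(4)] == [1, 2, 2, 2]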

vllm/v1/attention/backends/mla/indexer.py

Lines changed: 2 additions & 1 deletion
@@ -175,7 +175,8 @@ def __init__(self, *args, **kwargs):
             self.vllm_config.speculative_config.num_speculative_tokens
             if self.vllm_config.speculative_config else 0
         )
-        self.reorder_batch_threshold += self.num_speculative_tokens
+        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
+        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

         props = torch.cuda.get_device_properties(self.device)
         sm_count = props.multi_processor_count
