2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -754,11 +754,13 @@
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: Optional[bool] = False,

Check failure on line 757 in tests/conftest.py (GitHub Actions / pre-commit, Ruff G004): tests/conftest.py:757:21: G004 Logging statement uses f-string
swap_space: int = 4,
enforce_eager: Optional[bool] = False,
**kwargs,
) -> None:
from vllm import envs
logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}")
self.model = LLM(
model=model_name,
task=task,
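Note on the Ruff G004 failure flagged above: G004 objects to building log messages with f-strings; the usual remedy is to pass the value as a logging argument so formatting is deferred to the logger. A minimal sketch of that form (the same change would apply to the other G004 hits in this PR):

import logging
from vllm import envs  # as in the diff above

logger = logging.getLogger(__name__)
# G004-compliant form of the flagged call: let the logger interpolate lazily.
logger.info("VLLM_USE_V1: %s", envs.VLLM_USE_V1)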
2 changes: 1 addition & 1 deletion tests/spec_decode/conftest.py
@@ -8,4 +8,4 @@ def use_v0_only(monkeypatch):
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
monkeypatch.setenv('VLLM_USE_V1', '1')
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -2416,6 +2416,9 @@
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type ==
"deepseek_mtp"):
self.method = "mtp"
else:
self.method = "draft_model"

@@ -2436,7 +2439,7 @@
self.draft_model_config.hf_config,
method=self.method)
self.draft_model_config.hf_config = eagle_config

Check failure on line 2442 in vllm/config.py (GitHub Actions / pre-commit, repeated across several workflow jobs): Incompatible types in assignment (expression has type "Literal['mtp']", variable has type "Optional[Literal['ngram', 'eagle', 'medusa', 'mlp_speculator', 'draft_model']]") [assignment]
if (self.num_speculative_tokens is not None
and hasattr(self.draft_model_config.hf_config,
"num_lookahead_tokens")):
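Note on the type-check failures flagged above: the checker rejects assigning "mtp" because the Literal annotated on the method field does not include it yet. A minimal sketch of the kind of widening that would satisfy it (the alias name and standalone variable below are illustrative, not vllm's actual declarations):

from typing import Literal, Optional

# Hypothetical alias mirroring the error message, extended with "mtp".
SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
                            "draft_model", "mtp"]
method: Optional[SpeculativeMethod] = None  # stands in for SpeculativeConfig.method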
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
@@ -957,6 +957,8 @@ def create_engine_config(
else:
envs.set_vllm_use_v1(use_v1)

logger.info("use_v1: %s", use_v1)

# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context)
@@ -1281,6 +1283,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
speculative_model = self.speculative_config.get("model")
if speculative_model in ("ngram", "[ngram]"):
is_ngram_enabled = True
logger.info("Forcing to use V1 for speculative decoding.")
return True
if not (is_ngram_enabled or is_eagle_enabled):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback(feature_name="Speculative Decoding",
2 changes: 2 additions & 0 deletions vllm/entrypoints/llm.py
@@ -240,7 +240,9 @@
**kwargs,
)

from vllm import envs
logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}")
# Create the Engine (autoselects V0 vs V1)

Check failure on line 245 in vllm/entrypoints/llm.py (GitHub Actions / pre-commit, Ruff G004): vllm/entrypoints/llm.py:245:21: G004 Logging statement uses f-string
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
self.engine_class = type(self.llm_engine)
8 changes: 6 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -304,6 +304,7 @@

num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor

# New for MLA (compared to FlashAttention)
@@ -341,8 +342,9 @@
metadata_cls: Optional[type[M]] = None):
self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata
logger.info(f"self.metadata_cls: {self.metadata_cls}")
self.runner = runner
scheduler_config = runner.scheduler_config

Check failure on line 347 in vllm/v1/attention/backends/mla/common.py (GitHub Actions / pre-commit, Ruff G004): vllm/v1/attention/backends/mla/common.py:347:21: G004 Logging statement uses f-string
model_config = runner.model_config
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
@@ -352,8 +354,9 @@
self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)

# Dont try to access the runner on AMD
if self.aot_schedule:
self.page_size = self.runner.block_size
#if self.aot_schedule:
# Need page_size to compute max_context_chunk
self.page_size = self.runner.block_size

if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
@@ -557,6 +560,7 @@
return self.metadata_cls(
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc,
block_table=block_table,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
# MLACommonMetadata Chunk prefill specific
217 changes: 217 additions & 0 deletions vllm/v1/spec_decode/mtp_proposer.py
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import triton
Comment from @MengqingCao (Contributor), May 14, 2025: Directly importing triton is not recommended in vllm, so that non-triton devices keep working. I suggest merging this PR after #17716. cc @mgoin

import triton.language as tl
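A minimal sketch of the kind of guarded import the comment above is pointing at, so the module can still be imported on devices without triton (generic Python only; vllm's actual helper from #17716 may differ, and HAS_TRITON is an illustrative name):

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:  # non-triton devices
    triton = None
    tl = None
    HAS_TRITON = False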

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.v1.sample.metadata import SamplingMetadata


# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
# Therefore, we can just return the logits.
probs = logits
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs

is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)

# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
# generating the draft tokens. We only use the temperature. While this
# could degrade the acceptance rate, it does not affect the distribution
# of the generated tokens after rejection sampling.

# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q.exponential_()
Reviewer (Collaborator): Curious why we have separate sampling logic here for MTP?
Author (Collaborator): Needs to be cleaned up.

next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
)
return next_token_ids, probs
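For context on the review exchange above: dividing the probabilities by i.i.d. Exp(1) noise and taking the argmax is the exponential-race form of Gumbel-max sampling, so argmax of p_i / q_i with q_i ~ Exp(1) selects index i with probability p_i, and the block above does draw from the temperature-scaled distribution. A quick self-contained check (illustrative only, not part of the PR):

import torch

torch.manual_seed(0)
p = torch.tensor([0.1, 0.3, 0.6])
q = torch.empty(200_000, 3).exponential_()                  # q_i ~ Exp(1)
draws = (p / q).argmax(dim=-1)                              # exponential-race sampling
print(torch.bincount(draws, minlength=3) / draws.numel())   # roughly [0.1, 0.3, 0.6]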


class MtpProposer:

def __init__(
self,
vllm_config: VllmConfig,
runner,
):
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.runner = runner

@staticmethod
def prepare_inputs(
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]

# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens

cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
Reviewer (Collaborator): Shall we move this into triton as well?
cu_num_tokens[0] = 0

# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)

batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_input_kernel[(batch_size, )](
token_indices,
cu_target_query_lens,
cu_num_tokens,
BLOCK_SIZE=BLOCK_SIZE,
)
return cu_num_tokens, token_indices
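A concrete, CPU-only walk-through of the index math above (illustrative; it reproduces what prepare_inputs plus prepare_input_kernel compute, without needing Triton):

import torch

cu_target_query_lens = torch.tensor([0, 3, 6, 9])   # query lens a=3, b=3, c=3
num_rejected_tokens = torch.tensor([1, 0, 2])        # n1=1, n2=0, n3=2

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
num_tokens_per_req = query_len_per_req - num_rejected_tokens         # [2, 3, 1]

cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])       # [0, 2, 5, 6]

# Same result as prepare_input_kernel: each request keeps the first
# (query_len - num_rejected) positions of its slice in the flattened layout.
token_indices = torch.cat([
    torch.arange(start, start + n) for start, n in
    zip(cu_target_query_lens[:-1].tolist(), num_tokens_per_req.tolist())
])
print(token_indices)   # tensor([0, 1, 3, 4, 5, 6])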

def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1

input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids

query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()

# FIXME: reorder_batch() needs to be called before build()
# because fields of attn_metadata_builder needs to be updated.
# However, currently reorder_batch() takes input_batch and
# scheduler_output as arguments, we should probably refactor
# the method to use new data structures which are independent
# from input_batch and scheduler_output.
# self.runner.attn_metadata_builder.reorder_batch(
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )

attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
)

with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
Reviewer (Collaborator): Is this for the single-layer prototype?
Author (Collaborator): Yes. I think DeepSeek only released one layer of weights for MTP. This can be extended to run multiple times as a follow-up.

input_ids=input_ids,
positions=target_positions,
previous_hidden_states=target_hidden_states,
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)

assert self.num_speculative_tokens == 1
# [batch_size, 1]
return draft_token_ids.view(-1, 1)

def load_model(self, target_model: nn.Module) -> None:
loader = get_model_loader(self.vllm_config.load_config)

draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
# FIXME(lily): This does not handle with distributed inference.
target_device = self.vllm_config.device_config.device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)

self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))


@triton.jit
def prepare_input_kernel(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

# [start_pos, end_pos)
start_pos = tl.load(cu_num_tokens_ptr + pid)
end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
num_tokens = end_pos - start_pos

index_start = tl.load(cu_query_lens_ptr + pid)

num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
index_start + offset,
mask=offset < num_tokens,
)
12 changes: 10 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -41,6 +41,7 @@
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.mtp_proposer import MtpProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
@@ -140,6 +141,7 @@ def __init__(

self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
print(f"self.attn_metadata_builder: {self.attn_metadata_builder}")
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

# Multi-modal data support
@@ -176,6 +178,9 @@ def __init__(
self.device) # type: ignore
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "mtp":
self.drafter = MtpProposer(self.vllm_config,
self) # type: ignore
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
@@ -1191,6 +1196,7 @@ def execute_model(
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# GPU tensor to CPU list? sync point?
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
@@ -1210,8 +1216,10 @@
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
elif (self.speculative_config.use_eagle() or
self.speculative_config.draft_model_config.hf_config.model_type \
== "deepseek_mtp"):
assert isinstance(self.drafter, (EagleProposer, MtpProposer))
# TODO(woosuk): Refactor the loop.
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):