diff --git a/tests/conftest.py b/tests/conftest.py
index f02b5a8c0520..cc528584d524 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -759,6 +759,8 @@ def __init__(
         enforce_eager: Optional[bool] = False,
         **kwargs,
     ) -> None:
+        from vllm import envs
+        logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}")
         self.model = LLM(
             model=model_name,
             task=task,
diff --git a/tests/spec_decode/conftest.py b/tests/spec_decode/conftest.py
index 1a20e2c135c2..eca20289aa46 100644
--- a/tests/spec_decode/conftest.py
+++ b/tests/spec_decode/conftest.py
@@ -8,4 +8,4 @@ def use_v0_only(monkeypatch):
     Since this module is V0 only, set VLLM_USE_V1=0 for
     all tests in the module.
     """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+    monkeypatch.setenv('VLLM_USE_V1', '1')
diff --git a/vllm/config.py b/vllm/config.py
index e96d872d693e..b92e9b74be56 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2416,6 +2416,9 @@ def __post_init__(self):
                 elif (self.draft_model_config.hf_config.model_type ==
                       "mlp_speculator"):
                     self.method = "mlp_speculator"
+                elif (self.draft_model_config.hf_config.model_type ==
+                      "deepseek_mtp"):
+                    self.method = "mtp"
                 else:
                     self.method = "draft_model"
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c7a580cf1051..405d4af3fa7f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -957,6 +957,8 @@ def create_engine_config(
         else:
             envs.set_vllm_use_v1(use_v1)
 
+        logger.info("use_v1: %s", use_v1)
+
         # Set default arguments for V0 or V1 Engine.
         if use_v1:
             self._set_default_args_v1(usage_context)
@@ -1281,6 +1283,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             speculative_model = self.speculative_config.get("model")
             if speculative_model in ("ngram", "[ngram]"):
                 is_ngram_enabled = True
+            logger.info("Forcing to use V1 for speculative decoding.")
+            return True
         if not (is_ngram_enabled or is_eagle_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 79f1d80f402c..f4308b06c791 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -240,6 +240,8 @@ def __init__(
             **kwargs,
         )
 
+        from vllm import envs
+        logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}")
         # Create the Engine (autoselects V0 vs V1)
         self.llm_engine = LLMEngine.from_engine_args(
             engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index fd3be901f4c3..952ff1d1d8e3 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -304,6 +304,7 @@ class MLACommonMetadata(Generic[D]):
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     query_start_loc: torch.Tensor
+    block_table: torch.Tensor
     slot_mapping: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
@@ -341,6 +342,7 @@ def __init__(self,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
+        logger.info(f"self.metadata_cls: {self.metadata_cls}")
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
@@ -352,8 +354,9 @@
         self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
 
         # Dont try to access the runner on AMD
-        if self.aot_schedule:
-            self.page_size = self.runner.block_size
+        #if self.aot_schedule:
+        # Need page_size to compute max_context_chunk
+        self.page_size = self.runner.block_size
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -557,6 +560,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         return self.metadata_cls(
             num_actual_tokens=num_actual_tokens,
             query_start_loc=query_start_loc,
+            block_table=block_table,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py
new file mode 100644
index 000000000000..b4b96bfc1eae
--- /dev/null
+++ b/vllm/v1/spec_decode/mtp_proposer.py
@@ -0,0 +1,217 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader.loader import get_model_loader
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
+from vllm.v1.sample.metadata import SamplingMetadata
+
+
+# FIXME(woosuk): The logic here is duplicated with the main sampling code.
+# We should refactor this to reuse the same sampling implementation.
+def compute_probs_and_sample_next_token(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if sampling_metadata.all_greedy:
+        # For greedy requests, draft_probs is not used in rejection sampling.
+        # Therefore, we can just return the logits.
+        probs = logits
+        next_token_ids = logits.argmax(dim=-1)
+        return next_token_ids, probs
+
+    is_greedy = sampling_metadata.temperature == -1
+    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
+    logits.div_(temperature.view(-1, 1))
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+
+    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
+    # generating the draft tokens. We only use the temperature. While this
+    # could degrade the acceptance rate, it does not affect the distribution
+    # of the generated tokens after rejection sampling.
+
+    # TODO(woosuk): Consider seeds.
+    q = torch.empty_like(probs)
+    q.exponential_()
+    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
+    if not sampling_metadata.all_random:
+        greedy_token_ids = probs.argmax(dim=-1)
+        next_token_ids = torch.where(
+            is_greedy,
+            greedy_token_ids,
+            next_token_ids,
+        )
+    return next_token_ids, probs
+
+
+class MtpProposer:
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        runner,
+    ):
+        self.vllm_config = vllm_config
+        self.num_speculative_tokens = (
+            vllm_config.speculative_config.num_speculative_tokens)
+        self.block_size = vllm_config.cache_config.block_size
+        self.runner = runner
+
+    @staticmethod
+    def prepare_inputs(
+        # [batch_size + 1]
+        cu_target_query_lens: torch.Tensor,
+        # [batch_size]
+        num_rejected_tokens: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # cu_target_query_lens: [0, a, a + b, a + b + c]
+        # num_rejected_tokens: [n1, n2, n3]
+        # num_tokens_per_req: [a - n1, b - n2, c - n3]
+        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
+        # token_indices: [0, 1, ..., a - n1 - 1,
+        #                 a, a + 1, ..., a + b - n2 - 1,
+        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
+
+        # [0, a, a + b, a + b + c] -> [a, b, c]
+        query_len_per_req = (cu_target_query_lens[1:] -
+                             cu_target_query_lens[:-1])
+        # [a, b, c] -> [a - n1, b - n2, c - n3]
+        num_tokens_per_req = query_len_per_req - num_rejected_tokens
+
+        cu_num_tokens = torch.empty_like(cu_target_query_lens)
+        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
+        cu_num_tokens[0] = 0
+
+        # FIXME(woosuk): Avoid synchronization.
+        num_tokens = cu_num_tokens[-1].item()
+        token_indices = torch.empty(
+            num_tokens,
+            dtype=torch.int32,
+            device=cu_num_tokens.device,
+        )
+
+        batch_size = num_rejected_tokens.shape[0]
+        BLOCK_SIZE = 1024
+        prepare_input_kernel[(batch_size, )](
+            token_indices,
+            cu_target_query_lens,
+            cu_num_tokens,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        return cu_num_tokens, token_indices
+
+    def propose(
+        self,
+        # [num_tokens]
+        target_token_ids: torch.Tensor,
+        # [num_tokens]
+        target_positions: torch.Tensor,
+        # [num_tokens, hidden_size]
+        target_hidden_states: torch.Tensor,
+        # [num_tokens]
+        target_slot_mapping: torch.Tensor,
+        # [batch_size]
+        next_token_ids: torch.Tensor,
+        # [batch_size + 1] starting with 0
+        cu_num_tokens: torch.Tensor,
+        # [batch_size, max_num_blocks_per_req]
+        block_table: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        num_tokens = target_token_ids.shape[0]
+        batch_size = next_token_ids.shape[0]
+        last_token_indices = cu_num_tokens[1:] - 1
+
+        input_ids = torch.empty_like(target_token_ids)
+        # Shift the input ids by one token.
+        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+        input_ids[:-1] = target_token_ids[1:]
+        # Replace the last token with the next token.
+        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+        input_ids[last_token_indices] = next_token_ids
+
+        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+        max_query_len = query_lens.max().item()
+
+        # FIXME: reorder_batch() needs to be called before build()
+        # because fields of attn_metadata_builder need to be updated.
+        # However, reorder_batch() currently takes input_batch and
+        # scheduler_output as arguments; we should probably refactor
+        # the method to use new data structures that are independent
+        # of input_batch and scheduler_output.
+        # self.runner.attn_metadata_builder.reorder_batch(
+        #     input_batch=self.runner.input_batch,
+        #     scheduler_output=self.runner.scheduler_output,
+        # )
+
+        attn_metadata = self.runner.attn_metadata_builder.build(
+            num_reqs=batch_size,
+            num_actual_tokens=num_tokens,
+            max_query_len=max_query_len,
+            common_prefix_len=0,
+        )
+
+        with set_forward_context(attn_metadata, self.vllm_config):
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=target_positions,
+                previous_hidden_states=target_hidden_states,
+            )
+        sample_hidden_states = hidden_states[last_token_indices]
+        logits = self.model.compute_logits(sample_hidden_states, None)
+        draft_token_ids = logits.argmax(dim=-1)
+
+        assert self.num_speculative_tokens == 1
+        # [batch_size, 1]
+        return draft_token_ids.view(-1, 1)
+
+    def load_model(self, target_model: nn.Module) -> None:
+        loader = get_model_loader(self.vllm_config.load_config)
+
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
+        # FIXME(lily): This does not handle distributed inference.
+        target_device = self.vllm_config.device_config.device
+        # We need to set the vllm_config here to register attention
+        # layers in the forward context.
+        with set_default_torch_dtype(
+                draft_model_config.dtype), set_current_vllm_config(
+                    self.vllm_config):
+            self.model = DeepSeekMTP(
+                vllm_config=self.vllm_config).to(target_device)
+
+        self.model.load_weights(
+            loader.get_all_weights(
+                self.vllm_config.speculative_config.draft_model_config,
+                self.model))
+
+
+@triton.jit
+def prepare_input_kernel(
+    out_ptr,
+    cu_query_lens_ptr,
+    cu_num_tokens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    # [start_pos, end_pos)
+    start_pos = tl.load(cu_num_tokens_ptr + pid)
+    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
+    num_tokens = end_pos - start_pos
+
+    index_start = tl.load(cu_query_lens_ptr + pid)
+
+    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
+    for i in tl.range(num_blocks):
+        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        tl.store(
+            out_ptr + start_pos + offset,
+            index_start + offset,
+            mask=offset < num_tokens,
+        )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 97d8c91b4659..e533fa781850 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -41,6 +41,7 @@
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.mtp_proposer import MtpProposer
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
 from vllm.v1.utils import bind_kv_cache
@@ -140,6 +141,7 @@ def __init__(
 
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        print(f"self.attn_metadata_builder: {self.attn_metadata_builder}")
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
 
         # Multi-modal data support
@@ -176,6 +178,9 @@ def __init__(
                                              self.device)  # type: ignore
                 if self.speculative_config.method == "eagle3":
                     self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "mtp":
+                self.drafter = MtpProposer(self.vllm_config,
+                                           self)  # type: ignore
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
@@ -1191,6 +1196,7 @@ def execute_model(
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
+            # GPU tensor to CPU list? sync point?
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
         else:
@@ -1210,8 +1216,10 @@ def execute_model(
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+        elif (self.speculative_config.use_eagle() or
+              self.speculative_config.draft_model_config.hf_config.model_type \
+              == "deepseek_mtp"):
+            assert isinstance(self.drafter, (EagleProposer, MtpProposer))
             # TODO(woosuk): Refactor the loop.
             next_token_ids: list[int] = []
             for i, token_ids in enumerate(valid_sampled_token_ids):
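
Note: the index bookkeeping in MtpProposer.prepare_inputs and prepare_input_kernel above is easiest to check against a plain-PyTorch version. The sketch below is illustrative only; it is not part of the patch, the name prepare_inputs_reference is invented, and it simply reproduces the cu_num_tokens / token_indices computation described in the comments, assuming the [0, a, a + b, a + b + c] layout of cu_target_query_lens.

import torch


def prepare_inputs_reference(cu_target_query_lens: torch.Tensor,
                             num_rejected_tokens: torch.Tensor):
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    # Prefix-sum the per-request token counts, keeping a leading zero.
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
    # For request i, keep its first (len_i - n_i) slots in the flattened
    # target token array, starting at cu_target_query_lens[i].
    token_indices = torch.cat([
        cu_target_query_lens[i] + torch.arange(int(n))
        for i, n in enumerate(num_tokens_per_req.tolist())
    ])
    return cu_num_tokens, token_indices


# Example: query lens [a, b, c] = [3, 2, 4], rejected tokens [1, 0, 2].
cu = torch.tensor([0, 3, 5, 9])
rejected = torch.tensor([1, 0, 2])
cu_num_tokens, token_indices = prepare_inputs_reference(cu, rejected)
print(cu_num_tokens.tolist())  # [0, 2, 4, 6]
print(token_indices.tolist())  # [0, 1, 3, 4, 5, 6]

The Triton kernel in the patch writes the same token_indices in parallel, one program per request, which avoids the Python loop at the cost of the cu_num_tokens[-1].item() synchronization noted in the FIXME.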