[V1][Prototype] MTP Support #17683
@@ -0,0 +1,217 @@

```python
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import triton
import triton.language as tl

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.v1.sample.metadata import SamplingMetadata


# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
```
Collaborator: curious why we have separate sampling logic here for MTP?

Author: needs to be cleaned up
```python
    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs
```
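The `probs.div_(q).argmax(dim=-1)` step above is the exponential-race (Gumbel-max style) trick: with `q ~ Exponential(1)` drawn elementwise, `argmax(p_i / q_i)` selects index `i` with probability `p_i`, so it samples from the temperature-scaled distribution without an explicit `multinomial` call. A minimal standalone sketch (illustration only, not part of this PR) that checks this empirically:

```python
# Illustration only: dividing probabilities by Exponential(1) noise and
# taking the argmax samples from the categorical distribution `probs`.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

n = 200_000
q = torch.empty(n, probs.numel()).exponential_()  # q ~ Exponential(1)
samples = (probs / q).argmax(dim=-1)              # one draw per row

empirical = torch.bincount(samples, minlength=probs.numel()).float() / n
print(empirical)  # close to tensor([0.1, 0.2, 0.3, 0.4])
```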
```python
class MtpProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        runner,
    ):
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
        self.block_size = vllm_config.cache_config.block_size
        self.runner = runner

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
```
Collaborator: shall we move this into triton as well?
```python
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        prepare_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices
```
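To make the index bookkeeping concrete, here is a small worked example (plain PyTorch, illustration only) of what `prepare_inputs` computes for three requests with query lengths 3, 2, 4 and 1, 0, 2 rejected tokens:

```python
# Illustration only: reproduce prepare_inputs() for a tiny batch with
# plain PyTorch ops instead of the Triton kernel.
import torch

cu_target_query_lens = torch.tensor([0, 3, 5, 9])  # query lengths 3, 2, 4
num_rejected_tokens = torch.tensor([1, 0, 2])

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
num_tokens_per_req = query_len_per_req - num_rejected_tokens  # [2, 2, 2]

cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
print(cu_num_tokens)  # tensor([0, 2, 4, 6])

# Keep the first (query_len - rejected) positions of each request.
token_indices = torch.cat([
    torch.arange(start, start + n)
    for start, n in zip(cu_target_query_lens[:-1].tolist(),
                        num_tokens_per_req.tolist())
])
print(token_indices)  # tensor([0, 1, 3, 4, 5, 6])
```

Request 0 keeps positions 0-1, request 1 keeps 3-4, and request 2 keeps 5-6; the last `num_rejected_tokens[i]` positions of each request are dropped.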
```python
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        input_ids = torch.empty_like(target_token_ids)
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        input_ids[last_token_indices] = next_token_ids

        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
        max_query_len = query_lens.max().item()

        # FIXME: reorder_batch() needs to be called before build()
        # because fields of attn_metadata_builder need to be updated.
        # However, reorder_batch() currently takes input_batch and
        # scheduler_output as arguments; we should probably refactor
        # the method to use new data structures which are independent
        # of input_batch and scheduler_output.
        # self.runner.attn_metadata_builder.reorder_batch(
        #     input_batch=self.runner.input_batch,
        #     scheduler_output=self.runner.scheduler_output,
        # )

        attn_metadata = self.runner.attn_metadata_builder.build(
            num_reqs=batch_size,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            common_prefix_len=0,
        )

        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
```
Collaborator: is this for a single-layer prototype?

Author: yes. I think DeepSeek only released one layer of weights for MTP. This can be extended to run multiple times as a follow-up.
```python
                input_ids=input_ids,
                positions=target_positions,
                previous_hidden_states=target_hidden_states,
            )
        sample_hidden_states = hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        assert self.num_speculative_tokens == 1
        # [batch_size, 1]
        return draft_token_ids.view(-1, 1)
```
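As the review thread above notes, the proposer currently runs the MTP module exactly once (`num_speculative_tokens == 1`). A hedged sketch of how the single step could be chained to draft `k > 1` tokens as a follow-up; `propose_k_tokens` and `draft_step` are hypothetical names for illustration, not vLLM APIs:

```python
# Hypothetical follow-up sketch: chain a single-step drafter k times.
# `draft_step` maps (input_ids, positions, prev_hidden_states) to
# (next_token_ids, hidden_states); it stands in for one MTP forward
# pass plus greedy sampling.
from typing import Callable
import torch

DraftStep = Callable[[torch.Tensor, torch.Tensor, torch.Tensor],
                     tuple[torch.Tensor, torch.Tensor]]

def propose_k_tokens(
    draft_step: DraftStep,
    input_ids: torch.Tensor,      # [batch_size]
    positions: torch.Tensor,      # [batch_size]
    hidden_states: torch.Tensor,  # [batch_size, hidden_size]
    k: int,
) -> torch.Tensor:
    draft_token_ids = []
    for _ in range(k):
        next_ids, hidden_states = draft_step(input_ids, positions, hidden_states)
        draft_token_ids.append(next_ids)
        # Feed the newly drafted token back in for the next step.
        input_ids = next_ids
        positions = positions + 1
    # [batch_size, k]
    return torch.stack(draft_token_ids, dim=1)
```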
```python
    def load_model(self, target_model: nn.Module) -> None:
        loader = get_model_loader(self.vllm_config.load_config)

        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        # FIXME(lily): This does not handle distributed inference.
        target_device = self.vllm_config.device_config.device
        # We need to set the vllm_config here to register attention
        # layers in the forward context.
        with set_default_torch_dtype(
                draft_model_config.dtype), set_current_vllm_config(
                    self.vllm_config):
            self.model = DeepSeekMTP(
                vllm_config=self.vllm_config).to(target_device)

        self.model.load_weights(
            loader.get_all_weights(
                self.vllm_config.speculative_config.draft_model_config,
                self.model))


@triton.jit
def prepare_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    # [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    num_tokens = end_pos - start_pos

    index_start = tl.load(cu_query_lens_ptr + pid)

    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + start_pos + offset,
            index_start + offset,
            mask=offset < num_tokens,
        )
```
Comment: Directly importing triton is not recommended in vLLM, to keep non-Triton devices working. I suggest merging this PR after #17716. cc @mgoin
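Regarding the Triton-import concern above: the index expansion done by `prepare_input_kernel` can also be expressed with plain PyTorch ops. A hedged sketch (illustration only; `prepare_input_torch` is a hypothetical name and not part of this PR or of vLLM):

```python
# Illustration only: a Triton-free equivalent of prepare_input_kernel's
# index expansion, e.g. as a fallback on devices without Triton.
import torch

def prepare_input_torch(
    cu_query_lens: torch.Tensor,  # [batch_size + 1] target query offsets
    cu_num_tokens: torch.Tensor,  # [batch_size + 1] offsets after rejection
) -> torch.Tensor:
    num_tokens_per_req = cu_num_tokens[1:] - cu_num_tokens[:-1]
    total = int(cu_num_tokens[-1])
    # Position of each kept token within its own request: 0, 1, ..., n_i - 1.
    local_offsets = (torch.arange(total, device=cu_query_lens.device) -
                     torch.repeat_interleave(cu_num_tokens[:-1],
                                             num_tokens_per_req))
    # Shift by each request's start offset in the target token sequence.
    starts = torch.repeat_interleave(cu_query_lens[:-1], num_tokens_per_req)
    return (starts + local_offsets).to(torch.int32)

# Matches the worked example above:
# prepare_input_torch(torch.tensor([0, 3, 5, 9]), torch.tensor([0, 2, 4, 6]))
# -> tensor([0, 1, 3, 4, 5, 6], dtype=torch.int32)
```

The `int(cu_num_tokens[-1])` call forces a host sync, but the existing `prepare_inputs` already does the same via `.item()`, so this mainly trades the Triton launch for a few extra intermediate tensors.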