[Core] Support disaggregated prefill with Mooncake Transfer Engine #10884
Merged: KuntaiDu merged 25 commits into vllm-project:main from kvcache-ai:upstream-mooncake-integration on Dec 15, 2024.
Changes from 15 commits.

Commits (25, all by ShangmingCai):
- d52dbc8 Rebase from main to work with PR 10502.
- c8e9d07 Update format of mooncake config ValueError.
- b718f1e Modify metadata transfer logic to support tp.
- 08e2800 Fix format to make ruff happy.
- 8179746 Add instructions when mooncake is not installed.
- ba82d71 Merge branch 'main' into upstream-mooncake-integration
- 76d484c Merge branch 'main' into upstream-mooncake-integration
- e912055 fix import order to make isort happy.
- 2396f01 Fix format to make yapf happy.
- 31514a0 Add solution for ports conflict on the same node.
- 2ef10be Fix format to make mypy happy.
- 0823e47 Get head_size and num_heads from model config to address bugs on Volt…
- 6fb95fb Add support for other metadata server backend.
- a5758b1 Change code to align with PR 11058.
- 33e4455 Fix typo.
- f3312b9 Merge branch 'main' into upstream-mooncake-integration
- 343c474 Reuse simple connector for mooncake pipe.
- eaa1a45 Remove mooncake connector.
- 83e4db9 fix isort.
- e8ee5c2 fix mypy.
- bc01eae move PyNcclPipe import to fix mypy.
- b45ff65 still trying to fix mypy.
- aeccf4f fix typo and fix mypy.
- 8c2135a trying to fix mypy again.
- 875ca4c remove unused kvpipe base.
258 changes: 258 additions & 0 deletions
vllm/distributed/kv_transfer/kv_connector/mooncake_connector.py
```python
"""
Mooncake KV Cache Connector for Distributed Machine Learning Inference

The MooncakeConnector transfers KV caches between the prefill vLLM worker
(cache producer) and the decode vLLM worker (cache consumer) using
MooncakePipe.
"""
import os
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
    SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import MooncakePipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class MooncakeConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):

        self.config = config.kv_transfer_config

        logger.info(
            "Initializing MooncakeConnector under kv_transfer_config %s",
            self.config)

        self.lookup_buffer_size = self.config.kv_buffer_size

        self.producer_buffer: Optional[SimpleBuffer] = None
        self.consumer_buffer: Optional[SimpleBuffer] = None

        # Check if MOONCAKE_CONFIG_PATH is set
        use_mooncake_distributed_pipe = os.getenv(
            'MOONCAKE_CONFIG_PATH') is not None

        if not use_mooncake_distributed_pipe:
            raise ValueError(
                "To use MooncakeConnector, you need to pass the env variable: "
                "'MOONCAKE_CONFIG_PATH=/path/to/your/mooncake_config.json'.")

        # In disaggregated prefill, the prefill vLLM only uses the send pipe
        # and the decode vLLM only uses the recv pipe.
        if self.config.is_kv_producer:

            self.producer_data_pipe = MooncakePipe(
                local_rank=local_rank,
                config=self.config,
            )
            # We only need to initialize MooncakePipe once
            self.producer_signal_pipe = self.producer_data_pipe
            self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
                                                self.producer_data_pipe,
                                                self.config.kv_buffer_size)

        else:

            # The current vLLM instance is a KV consumer, so it needs to
            # connect its recv pipe to the send pipe of the KV producer.
            self.consumer_data_pipe = MooncakePipe(
                local_rank=local_rank,
                config=self.config,
            )
            self.consumer_signal_pipe = self.consumer_data_pipe
            self.consumer_buffer = SimpleBuffer(
                self.consumer_signal_pipe,
                self.consumer_data_pipe,
                self.config.kv_buffer_size,
            )

    def select(self, input_tokens: Optional[torch.Tensor],
               roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:

        assert self.consumer_buffer is not None, "Please initialize the "\
            "consumer buffer before calling select."
        return self.consumer_buffer.drop_select(input_tokens, roi)

    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
               key: torch.Tensor, value: torch.Tensor,
               hidden: torch.Tensor) -> None:

        assert self.producer_buffer is not None, "Please initialize the "\
            "producer buffer before calling insert."

        self.producer_buffer.insert(input_tokens, roi, key, value, hidden)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer

        model_config = model_executable.model.config
        num_heads = model_config.num_key_value_heads
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads
        head_size = int(hidden_size / num_attention_heads)

        # query_lens contains the new KV caches that are added to vLLM,
        # so we will send them to the decode instance.
        # FIXME(Kuntai): This assumes that all requests are prefill.
        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen
            current_tokens = input_tokens_tensor[start_pos:end_pos]

            keys, values = [], []

            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]

                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                values.append(value_cache[current_slot_mapping].unsqueeze(0))

            keys = torch.cat(keys, dim=0)
            values = torch.cat(values, dim=0)

            self.insert(current_tokens,
                        torch.ones_like(current_tokens,
                                        dtype=bool), keys, values,
                        hidden_or_intermediate_states[start_pos:end_pos])

        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        # When bypass_model_exec is set to False, it means that for at least
        # one request the corresponding KV cache or hidden state is missing.
        # In this case we need to do prefilling to recompute the missing KV
        # cache and hidden states.
        bypass_model_exec = True

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

        hidden_or_intermediate_states_for_one_req = []

        input_tokens_list = []
        num_computed_tokens_list = []
        start_pos_list = []

        # Enumerate different requests.
        # FIXME(Kuntai): This impl assumes that all requests are prefill.
        for idx, slen in enumerate(seq_lens):

            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen
            current_tokens = input_tokens_tensor[start_pos:end_pos]
            num_tokens = slen

            # Collect data for rebuilding the input.
            input_tokens_list.append(current_tokens)
            start_pos_list.append(start_pos)

            ret = self.select(current_tokens,
                              torch.ones_like(current_tokens, dtype=bool))
            if ret[0] is None:
                # Didn't find any match.
                bypass_model_exec = False
                num_computed_tokens_list.append(0)
                continue

            roi: torch.Tensor = ret[1]
            keys: torch.Tensor = ret[2]
            values: torch.Tensor = ret[3]
            hidden: torch.Tensor = ret[4]

            num_computed_tokens = roi.shape[0]
            num_computed_tokens_list.append(num_computed_tokens)

            # Check if both the KV cache and the hidden states are received.
            # If not, we need to redo the forward pass to compute missing
            # states.
            if not all([(num_computed_tokens == num_tokens), hidden is not None
                        ]):
                bypass_model_exec = False

            # Update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # Put received KV caches into paged memory.
            for i in range(model_executable.model.start_layer,
                           model_executable.model.end_layer):

                kv_cache = kv_caches[i - model_executable.model.start_layer]
                layer = model_executable.model.layers[i]

                key_cache, value_cache = kv_cache[0], kv_cache[1]
                ops.reshape_and_cache_flash(
                    keys[i - model_executable.model.start_layer].to(
                        key_cache.device),
                    values[i - model_executable.model.start_layer].to(
                        value_cache.device),
                    key_cache,
                    value_cache,
                    slot_mapping[start_pos:end_pos],
                    layer.self_attn.attn.kv_cache_dtype,
                    layer.self_attn.attn._k_scale,
                    layer.self_attn.attn._v_scale,
                )

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            # Some of the KV cache was not retrieved.
            # Here we fall back to normal model forwarding, but optionally you
            # can adjust model_input so that prefill runs only on the tokens
            # whose KV caches are missing.
            logger.debug(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None

        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self):
        self.producer_data_pipe.close()
        self.consumer_data_pipe.close()
```
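To make the data layout concrete, below is a small, self-contained sketch of the producer-side gather that `send_kv_caches_and_hidden_states` performs before handing tensors to `insert()`: requests are packed back to back, so each request's rows are pulled from the paged KV cache by slicing the flattened `slot_mapping`. The shapes, the random data, and the plain dict standing in for the `SimpleBuffer`/`MooncakePipe` transport are illustrative assumptions, not part of this PR.

```python
# Illustrative sketch of the per-request KV gather (CPU-only, made-up shapes).
# A plain dict stands in for SimpleBuffer.insert(); the real connector ships
# these tensors to the decode instance through MooncakePipe.
import torch

num_layers, num_blocks, block_size = 2, 4, 16
num_kv_heads, head_size = 8, 64

# One paged KV cache per layer: [2 (K/V), num_blocks * block_size, heads, head_size]
kv_caches = [
    torch.randn(2, num_blocks * block_size, num_kv_heads, head_size)
    for _ in range(num_layers)
]

# Two prefill requests of length 5 and 3, packed back to back.
seq_lens = [5, 3]
input_tokens = torch.arange(sum(seq_lens))                      # flattened token ids
slot_mapping = torch.randperm(num_blocks * block_size)[:sum(seq_lens)]

buffer = {}  # stand-in for SimpleBuffer

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    end_pos = start_pos + slen
    current_tokens = input_tokens[start_pos:end_pos]
    current_slots = slot_mapping[start_pos:end_pos]

    keys, values = [], []
    for layer_id in range(num_layers):
        key_cache = kv_caches[layer_id][0].reshape(-1, num_kv_heads, head_size)
        value_cache = kv_caches[layer_id][1].reshape(-1, num_kv_heads, head_size)
        # Gather this request's slots; keep a leading per-layer dimension.
        keys.append(key_cache[current_slots].unsqueeze(0))
        values.append(value_cache[current_slots].unsqueeze(0))

    # [num_layers, seq_len, num_kv_heads, head_size], as handed to insert().
    keys = torch.cat(keys, dim=0)
    values = torch.cat(values, dim=0)
    buffer[tuple(current_tokens.tolist())] = (keys, values)
    print(idx, tuple(keys.shape), tuple(values.shape))
```

On the decode side, `recv_kv_caches_and_hidden_states` walks the same flattened layout in reverse: it looks each request's tokens up via `select()`, scatters the returned keys and values into its own paged cache with `ops.reshape_and_cache_flash`, and skips the model forward only when every request's KV caches and hidden states arrived in full.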