Merged

Changes from 15 of the 25 commits listed below (all commits by ShangmingCai):
d52dbc8  Rebase from main to work with PR 10502. (Dec 2, 2024)
c8e9d07  Update format of mooncake config ValueError. (Dec 2, 2024)
b718f1e  Modify metadata transfer logic to support tp. (Dec 3, 2024)
08e2800  Fix format to make ruff happy. (Dec 3, 2024)
8179746  Add instructions when mooncake is not installed. (Dec 3, 2024)
ba82d71  Merge branch 'main' into upstream-mooncake-integration (Dec 3, 2024)
76d484c  Merge branch 'main' into upstream-mooncake-integration (Dec 4, 2024)
e912055  fix import order to make isort happy. (Dec 4, 2024)
2396f01  Fix format to make yapf happy. (Dec 4, 2024)
31514a0  Add solution for ports conflict on the same node. (Dec 4, 2024)
2ef10be  Fix format to make mypy happy. (Dec 4, 2024)
0823e47  Get head_size and num_heads from model config to address bugs on Volt… (Dec 10, 2024)
6fb95fb  Add support for other metadata server backend. (Dec 10, 2024)
a5758b1  Change code to align with PR 11058. (Dec 10, 2024)
33e4455  Fix typo. (Dec 11, 2024)
f3312b9  Merge branch 'main' into upstream-mooncake-integration (Dec 15, 2024)
343c474  Reuse simple connector for mooncake pipe. (Dec 15, 2024)
eaa1a45  Remove mooncake connector. (Dec 15, 2024)
83e4db9  fix isort. (Dec 15, 2024)
e8ee5c2  fix mypy. (Dec 15, 2024)
bc01eae  move PyNcclPipe import to fix mypy. (Dec 15, 2024)
b45ff65  still trying to fix mypy. (Dec 15, 2024)
aeccf4f  fix typo and fix mypy. (Dec 15, 2024)
8c2135a  trying to fix mypy again. (Dec 15, 2024)
875ca4c  remove unused kvpipe base. (Dec 15, 2024)
vllm/config.py: 7 changes (4 additions, 3 deletions)
@@ -2103,13 +2103,14 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig":
         return KVTransferConfig.model_validate_json(cli_value)
 
     def model_post_init(self, __context: Any) -> None:
+        supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
         if all([
-            self.kv_connector is not None,
-            self.kv_connector != "PyNcclConnector"
+            self.kv_connector is not None, self.kv_connector
+            not in supported_kv_connector
         ]):
             raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
                              f"Supported connectors are "
-                             f"`PyNcclConnector`.")
+                             f"{supported_kv_connector}.")
 
         if self.kv_role is not None and self.kv_role not in [
                 "kv_producer", "kv_consumer", "kv_both"
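For reference, a minimal sketch of config values that pass the updated validation above; it only uses the fields visible in this hunk (kv_connector, kv_role) and assumes the remaining KVTransferConfig fields keep their defaults.

    from vllm.config import KVTransferConfig

    # Parsed the same way as the CLI value above (model_validate_json), then
    # checked by model_post_init.
    prefill_cfg = KVTransferConfig.from_cli(
        '{"kv_connector": "MooncakeConnector", "kv_role": "kv_producer"}')
    decode_cfg = KVTransferConfig.from_cli(
        '{"kv_connector": "MooncakeConnector", "kv_role": "kv_consumer"}')
    # An unrecognized connector name now raises
    # "Unsupported kv_connector: ... Supported connectors are [...]".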
vllm/distributed/kv_transfer/kv_connector/factory.py: 3 changes (3 additions, 0 deletions)
@@ -14,6 +14,9 @@ def create_connector(rank: int, local_rank: int,
         if config.kv_transfer_config.kv_connector == 'PyNcclConnector':
             from .simple_connector import SimpleConnector
             return SimpleConnector(rank, local_rank, config)
+        elif config.kv_transfer_config.kv_connector == 'MooncakeConnector':
+            from .mooncake_connector import MooncakeConnector
+            return MooncakeConnector(rank, local_rank, config)
         else:
             raise ValueError(f"Unsupported connector type: "
                              f"{config.kv_connector}")
vllm/distributed/kv_transfer/kv_connector/mooncake_connector.py: 258 changes (258 additions, 0 deletions)
@@ -0,0 +1,258 @@
"""
Mooncake KV Cache Connector for Distributed Machine Learning Inference

The MooncakeConnector transfers KV caches between prefill vLLM worker (Cache
producer) and decode vLLM worker (Cache consumer) using MooncakePipe.
"""
import os
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import MooncakePipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class MooncakeConnector(KVConnectorBase):

def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):

self.config = config.kv_transfer_config

logger.info(
"Initializing MooncakeConnector under kv_transfer_config %s",
self.config)

self.lookup_buffer_size = self.config.kv_buffer_size

self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None

# Check if MOONCAKE_CONFIG_PATH is set
use_mooncake_distributed_pipe = os.getenv(
'MOONCAKE_CONFIG_PATH') is not None

if not use_mooncake_distributed_pipe:
raise ValueError(
"To use MooncakeConnector, you need to pass the env variable: "
"'MOONCAKE_CONFIG_PATH=/path/to/your/mooncake_config.json'.")

# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:

self.producer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
# We only need to initialize MooncakePipe once
self.producer_signal_pipe = self.producer_data_pipe
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)

else:

# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producder
self.consumer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
self.consumer_signal_pipe = self.consumer_data_pipe
self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
self.config.kv_buffer_size,
)

def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:

assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)

def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:

assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."

self.producer_buffer.insert(input_tokens, roi, key, value, hidden)

def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:

input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer

model_config = model_executable.model.config
num_heads = model_config.num_key_value_heads
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)

# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
# FIXME(Kuntai): This assume that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]

keys, values = [], []

for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]

key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))

keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)

self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])

logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:

# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True

input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

hidden_or_intermediate_states_for_one_req = []

input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []

# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):

start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen

# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)

ret = self.select(current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue

roi: torch.Tensor = ret[1]
keys: torch.Tensor = ret[2]
values: torch.Tensor = ret[3]
hidden: torch.Tensor = ret[4]

num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)

# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False

# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens

# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):

kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]

key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)

hidden_or_intermediate_states_for_one_req.append(hidden)

if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None

else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)

return hidden_or_intermediate_states, bypass_model_exec, model_input

def close(self):
self.producer_data_pipe.close()
self.consumer_data_pipe.close()
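To tie the connector's two roles together, here is a minimal, illustrative sketch of the producer/consumer handshake built on the insert and select methods above. It assumes MOONCAKE_CONFIG_PATH is set, that producer and consumer are MooncakeConnector instances already constructed on the prefill and decode sides respectively, and uses dummy tensors whose shapes follow send_kv_caches_and_hidden_states.

    import torch

    # Dummy shapes: keys/values are [num_layers, num_tokens, num_heads,
    # head_size]; hidden is [num_tokens, hidden_size].
    num_layers, num_tokens, num_heads, head_size, hidden_size = 2, 4, 8, 64, 512
    tokens = torch.tensor([101, 2023, 2003, 102])
    roi = torch.ones_like(tokens, dtype=bool)
    keys = torch.randn(num_layers, num_tokens, num_heads, head_size)
    values = torch.randn(num_layers, num_tokens, num_heads, head_size)
    hidden = torch.randn(num_tokens, hidden_size)

    # Producer (prefill) side: publish the KV blocks and hidden states.
    producer.insert(tokens, roi, keys, values, hidden)

    # Consumer (decode) side: look the same tokens up. A miss returns
    # [None, ...], which keeps bypass_model_exec False in
    # recv_kv_caches_and_hidden_states.
    ret = consumer.select(tokens, roi)
    if ret[0] is not None:
        roi, keys, values, hidden = ret[1], ret[2], ret[3], ret[4]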