7 changes: 7 additions & 0 deletions tests/distributed/test_context_parallel.py
@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@@ -66,6 +68,7 @@ def detailed(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -207,6 +213,7 @@ def _compare_cp_with_tp(
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
2 changes: 2 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -951,6 +951,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
dcp_kv_cache_interleave_size = 8

block_table = BlockTable(
block_size=block_size,
@@ -960,6 +961,7 @@ def test_hybrid_block_table_initialization():
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)

# Verify hybrid block configuration
1 change: 1 addition & 0 deletions vllm/attention/ops/common.py
@@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
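The added `tl.where` guard handles the case where every rank's LSE for a position is `-inf` (no valid contribution anywhere): subtracting an `-inf` maximum would produce NaN and poison the correction weights. A minimal sketch outside the Triton kernel, in plain PyTorch, to illustrate the failure mode and the fix:

```python
import torch

# All ranks report -inf LSE for this position (no valid context anywhere).
lse = torch.full((4,), float("-inf"))

# Without the guard: (-inf) - (-inf) = nan, which propagates through exp().
print(torch.exp(lse - lse.max()))  # tensor([nan, nan, nan, nan])

# With the guard (mirroring the new kernel line): clamp the max to 0 first.
lse_max = torch.where(lse.max() == float("-inf"), torch.tensor(0.0), lse.max())
print(torch.exp(lse - lse_max))    # tensor([0., 0., 0., 0.])
```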
11 changes: 11 additions & 0 deletions vllm/config/parallel.py
@@ -227,6 +227,17 @@ class is dynamically inherited by the worker class. This is used to inject
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""

dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1,
store interleave_size tokens on (d)cp i,
then store next interleave_size tokens on (d)cp i+1.
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level align, first fill the block on first rank,
token is stored on rank i+1 block j after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""

_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.
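A minimal sketch of the token-to-rank mapping the new docstring describes (the helper name below is made up for illustration and is not part of the PR):

```python
def rank_of_token(token_idx: int, dcp_size: int, interleave_size: int) -> int:
    # Tokens are grouped into runs of `interleave_size`; consecutive runs are
    # placed on consecutive (D)CP ranks in round-robin order.
    return (token_idx // interleave_size) % dcp_size

# interleave_size=1: token-level alignment, token i lands on rank i % dcp_size.
assert [rank_of_token(i, 4, 1) for i in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]

# interleave_size=4 (e.g. equal to block_size): block-level alignment,
# rank 0 receives 4 consecutive tokens before rank 1 receives any.
assert [rank_of_token(i, 4, 4) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```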
17 changes: 17 additions & 0 deletions vllm/config/vllm.py
@@ -608,6 +608,23 @@ def __post_init__(self):
)
current_platform.check_and_update_config(self)

assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be "
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)

assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."

# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()
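The two new assertions in `__post_init__` can be read as a single predicate over the configured sizes; a small sketch of accepted and rejected combinations (values are illustrative only):

```python
def interleave_size_is_valid(block_size: int, interleave_size: int) -> bool:
    # block_size must be >= the interleave size and divisible by it.
    return interleave_size <= block_size and block_size % interleave_size == 0

assert interleave_size_is_valid(64, 1)       # default: token-level interleaving
assert interleave_size_is_valid(64, 64)      # block-level interleaving
assert not interleave_size_is_valid(16, 64)  # interleave larger than a block
assert not interleave_size_is_valid(64, 48)  # block not divisible by interleave
```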
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -385,6 +385,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None
@@ -750,6 +751,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"-dcp",
**parallel_kwargs["decode_context_parallel_size"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)
@@ -1518,6 +1523,7 @@ def create_engine_config(
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
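A hedged usage sketch of the new option through `EngineArgs` (the model and sizes are only examples; `block_size` is set explicitly so the block_size/interleave constraint above holds):

```python
from vllm.engine.arg_utils import EngineArgs

# Block-aligned KV placement across 8 decode-context-parallel ranks.
engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    tensor_parallel_size=8,
    decode_context_parallel_size=8,
    block_size=64,
    dcp_kv_cache_interleave_size=64,
)
```

On the command line the same setting is `--dcp-kv-cache-interleave-size 64`, as wired up in `add_cli_args` above.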
13 changes: 11 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -43,6 +43,7 @@
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -238,6 +239,10 @@ def __init__(
self.dcp_world_size = 1
self.dcp_rank = 0

self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
@@ -352,8 +357,12 @@ def schedule(
- common_attn_metadata.query_start_loc_cpu[:-1]
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size

dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
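`get_dcp_local_seq_lens` replaces the inline formula removed above, which only handled token-level (interleave_size=1) placement. A hedged sketch of what such a helper could compute per rank; the real helper's signature and implementation may differ:

```python
import torch

def dcp_local_seq_lens_sketch(
    seq_lens: torch.Tensor,  # global context length per request
    dcp_world_size: int,
    dcp_rank: int,
    interleave_size: int,
) -> torch.Tensor:
    # Split each sequence into complete interleave groups plus a partial tail.
    num_groups = seq_lens // interleave_size
    remainder = seq_lens - num_groups * interleave_size
    # Every rank owns the same number of full round-robin passes; the next few
    # ranks own one extra full group; the rank right after those owns the tail.
    local = (num_groups // dcp_world_size) * interleave_size
    local = local + (dcp_rank < num_groups % dcp_world_size).int() * interleave_size
    local = local + (dcp_rank == num_groups % dcp_world_size).int() * remainder
    return local

# With interleave_size=1 this matches the removed formula for non-empty contexts:
# seq_lens // world_size + (rank <= (seq_lens - 1) % world_size).
print(dcp_local_seq_lens_sketch(torch.tensor([5, 10]), 4, 0, 1))  # tensor([2, 3])
```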