Skip to content

Commit 4dfdb82

Browse files
authored
[P/D] Dynamic kv_output_aggregator collect size (#26734)
Signed-off-by: NickLucche <[email protected]>
1 parent 58fab50 commit 4dfdb82

File tree

7 files changed

+90
-19
lines changed

7 files changed

+90
-19
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
703703

704704
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
705705
# done in MultiprocExecutor.execute_model
706-
aggregator = KVOutputAggregator(world_size=3)
706+
aggregator = KVOutputAggregator(expected_finished_count=3)
707707

708708
# Create stats for multiple workers with different transfer patterns
709709
worker1_stats = NixlKVConnectorStats()
@@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
768768
KVOutputAggregator (used by MultiprocExecutor).
769769
"""
770770

771-
aggregator = KVOutputAggregator(world_size=3)
771+
aggregator = KVOutputAggregator(expected_finished_count=3)
772772

773773
from dataclasses import dataclass
774774

tests/v1/kv_connector/unit/test_output_aggreagator.py renamed to tests/v1/kv_connector/unit/test_output_aggregator.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ def __init__(
1616
finished_sending: set[str] | None = None,
1717
finished_recving: set[str] | None = None,
1818
invalid_block_ids: set[int] | None = None,
19+
expected_finished_count: int = 0,
1920
):
2021
self.kv_connector_output = KVConnectorOutput(
2122
finished_sending=finished_sending,
2223
finished_recving=finished_recving,
2324
invalid_block_ids=invalid_block_ids or set(),
25+
expected_finished_count=expected_finished_count,
2426
)
2527

2628
def __repr__(self):
@@ -33,7 +35,7 @@ def __repr__(self):
3335

3436

3537
def test_aggregate_workers_output():
36-
aggregator = KVOutputAggregator(world_size=2)
38+
aggregator = KVOutputAggregator(expected_finished_count=2)
3739

3840
output1 = DummyModelRunnerOutput()
3941
output2 = DummyModelRunnerOutput()
@@ -85,7 +87,7 @@ def test_aggregate_workers_output():
8587

8688

8789
def test_async_aggregate_workers_output():
88-
aggregator = KVOutputAggregator(world_size=2)
90+
aggregator = KVOutputAggregator(expected_finished_count=2)
8991

9092
future1: Future[DummyModelRunnerOutput] = Future()
9193
future2: Future[DummyModelRunnerOutput] = Future()
@@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
158160
assert aggregated.finished_sending is None
159161
assert aggregated.finished_recving == {"req2"}
160162
assert aggregated.invalid_block_ids == {3, 4, 5}
163+
164+
165+
def test_aggregate_workers_output_with_expected_finished_count():
166+
# We create the aggregator expecting to collect from 4 workers
167+
aggregator = KVOutputAggregator(expected_finished_count=4)
168+
assert aggregator._expected_finished_count == 4
169+
# A request that relies on the aggregator's default expected finished count
170+
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
171+
aggregated = aggregator.aggregate([output1])
172+
# still expecting to collect from 4 workers
173+
assert aggregator._send_remaining_count["req1"] == 3
174+
assert not aggregated.kv_connector_output.finished_sending
175+
assert not aggregated.kv_connector_output.finished_recving
176+
177+
# Workers discover that in this setup they only need to
178+
# collect from 2
179+
output1 = DummyModelRunnerOutput(
180+
finished_sending={"req1"}, expected_finished_count=2
181+
)
182+
output2 = DummyModelRunnerOutput(
183+
finished_recving={"req2"}, expected_finished_count=2
184+
)
185+
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
186+
# Req2 only needs 2 acks
187+
aggregated = aggregator.aggregate([output1, output2, output3])
188+
assert aggregated.kv_connector_output.expected_finished_count == 2
189+
190+
assert not aggregated.kv_connector_output.finished_sending
191+
192+
# Req2 is finished
193+
assert "req2" not in aggregator._recv_remaining_count
194+
assert aggregated.kv_connector_output.finished_recving == {"req2"}
195+
196+
# Req1 is still waiting for 2 more acks (the updated expected_finished_count
# does not affect requests that are already being tracked)
197+
# NOTE: This is to showcase dynamic update. Workers are responsible for
198+
# ensuring "req1" termination in this case
199+
assert aggregator._send_remaining_count["req1"] == 2

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
KV cache helper for store.
55
"""
66

7-
from collections import defaultdict
87
from collections.abc import Sequence
98
from concurrent.futures import CancelledError, Future
10-
from typing import Literal, cast
9+
from typing import TYPE_CHECKING, Literal, cast
1110

1211
import torch
1312

@@ -18,6 +17,9 @@
1817
from vllm.logger import init_logger
1918
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
2019

20+
if TYPE_CHECKING:
21+
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
22+
2123
logger = init_logger(__name__)
2224

2325

@@ -124,11 +126,16 @@ class KVOutputAggregator:
124126
"""Utility class to aggregate the output of all workers into a single
125127
output corresponding to Rank 0 for scheduler."""
126128

127-
def __init__(self, world_size: int):
129+
def __init__(self, expected_finished_count: int):
128130
# Complete transfer tracker. Used to track finished requests
129131
# [req_id -> n_remaining_workers]
130-
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
131-
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
132+
self._recv_remaining_count = dict[str, int]()
133+
self._send_remaining_count = dict[str, int]()
134+
self._expected_finished_count = expected_finished_count
135+
136+
@classmethod
137+
def from_connector(cls, connector: "KVConnectorBase", world_size: int):
138+
return cls(connector.get_finished_count() or world_size)
132139

133140
def aggregate(
134141
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
@@ -141,7 +148,10 @@ def update_finished_set(
141148
finished_set: set[str],
142149
) -> None:
143150
for req_id in req_ids or ():
144-
remaining_count_dict[req_id] -= 1
151+
remaining_count = remaining_count_dict.get(
152+
req_id, self._expected_finished_count
153+
)
154+
remaining_count_dict[req_id] = remaining_count - 1
145155
if remaining_count_dict[req_id] == 0:
146156
finished_set.add(req_id)
147157
del remaining_count_dict[req_id]
@@ -154,6 +164,19 @@ def update_finished_set(
154164
kv_output = model_runner_output.kv_connector_output
155165
if not kv_output:
156166
continue
167+
# Allow the worker to dynamically update the expected number of
168+
# finished sending/recving for new requests.
169+
if (
170+
kv_output.expected_finished_count > 0
171+
and kv_output.expected_finished_count != self._expected_finished_count
172+
):
173+
logger.debug(
174+
"Expected finished requests updated from %d to %d",
175+
self._expected_finished_count,
176+
kv_output.expected_finished_count,
177+
)
178+
self._expected_finished_count = kv_output.expected_finished_count
179+
157180
update_finished_set(
158181
kv_output.finished_sending, self._send_remaining_count, finished_sending
159182
)
@@ -186,6 +209,7 @@ def update_finished_set(
186209
finished_recving=finished_recving or None,
187210
kv_connector_stats=aggregated_kv_connector_stats or None,
188211
invalid_block_ids=invalid_block_ids,
212+
expected_finished_count=self._expected_finished_count,
189213
)
190214

191215
return output

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,8 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
413413
def get_finished_count(self) -> int | None:
414414
"""
415415
Get the count of requests expected to complete send/receive operations
416-
via this connector.
416+
via this connector. This method is used to initialize the
417+
KVOutputAggregator, overriding the default world_size.
417418
418419
Returns:
419420
int: expected sending or receiving completion count.

vllm/v1/engine/core.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,7 @@ def __init__(
160160
)
161161
self.use_spec_decode = vllm_config.speculative_config is not None
162162
if self.scheduler.connector is not None: # type: ignore
163-
self.model_executor.init_kv_output_aggregator(
164-
self.scheduler.connector.get_finished_count() # type: ignore
165-
)
163+
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
166164

167165
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
168166
self.mm_receiver_cache = engine_receiver_cache_from_config(

vllm/v1/executor/abstract.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable
66
from concurrent.futures import Future
77
from functools import cached_property
8-
from typing import Literal, TypeVar, overload
8+
from typing import TYPE_CHECKING, Literal, TypeVar, overload
99

1010
from vllm.config import VllmConfig
1111
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
@@ -19,6 +19,9 @@
1919
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
2020
from vllm.v1.worker.worker_base import WorkerBase
2121

22+
if TYPE_CHECKING:
23+
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
24+
2225
logger = init_logger(__name__)
2326

2427
_R = TypeVar("_R")
@@ -233,10 +236,10 @@ def shutdown(self) -> None:
233236
"""Shutdown the executor."""
234237
self.collective_rpc("shutdown")
235238

236-
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
239+
def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
237240
"""Init KVOutputAggregator"""
238-
self.kv_output_aggregator = KVOutputAggregator(
239-
finished_count or self.parallel_config.world_size
241+
self.kv_output_aggregator = KVOutputAggregator.from_connector(
242+
connector, self.parallel_config.world_size
240243
)
241244

242245
@cached_property # Avoid unnecessary RPC calls

vllm/v1/outputs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,14 @@ class KVConnectorOutput:
8686
finished_recving: set[str] | None = None
8787
kv_connector_stats: KVConnectorStats | None = None
8888
# IDs of externally computed KV blocks that failed to load.
89-
# Requests referencing these blocks should be rescheduled to recompute them.
89+
# Requests referencing these blocks should be rescheduled to recompute them
9090
invalid_block_ids: set[int] = field(default_factory=set)
91+
# Configuration describing how many finished sending/receiving
92+
# notifications should be expected for each request. This allows
93+
# handshake-based connectors like Nixl to update the KVOutputAggregator.
94+
# It captures static setup information and should almost always remain constant
95+
# for a given connector after discovery. The default value (0) means no change.
96+
expected_finished_count: int = 0
9197

9298
def is_empty(self):
9399
return (

0 commit comments

Comments
 (0)