-
-
Notifications
You must be signed in to change notification settings - Fork 11k
[KV Connector] Make KVCacheConfig an explicit constructor argument #27887
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,275 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """ | ||
| Unit tests for backwards compatibility with external KV connector implementations. | ||
| This test ensures that external connectors (loaded via kv_connector_module_path) | ||
| implemented with the old signature continue to work: | ||
| - Old signature: __init__(self, vllm_config, role) | ||
| - New signature: __init__(self, vllm_config, role, kv_cache_config) | ||
| """ | ||
|
|
||
| from typing import TYPE_CHECKING | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory | ||
| from vllm.distributed.kv_transfer.kv_connector.v1 import ( | ||
| KVConnectorBase_V1, | ||
| KVConnectorRole, | ||
| ) | ||
| from vllm.v1.core.sched.output import SchedulerOutput | ||
|
|
||
| from .utils import create_scheduler, create_vllm_config | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.attention.backends.abstract import AttentionMetadata | ||
| from vllm.config import VllmConfig | ||
| from vllm.forward_context import ForwardContext | ||
| from vllm.v1.core.kv_cache_manager import KVCacheBlocks | ||
| from vllm.v1.kv_cache_interface import KVCacheConfig | ||
| from vllm.v1.request import Request | ||
|
|
||
|
|
||
| class OldStyleTestConnector(KVConnectorBase_V1): | ||
| """ | ||
| Test connector using the old signature with 2 required arguments. | ||
| This simulates external connectors that haven't been updated yet. | ||
| """ | ||
|
|
||
| def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): | ||
| # Old-style call to super().__init__ with only 2 arguments | ||
| super().__init__(vllm_config=vllm_config, role=role) | ||
|
|
||
| def get_num_new_matched_tokens( | ||
| self, request: "Request", num_computed_tokens: int | ||
| ) -> tuple[int | None, bool]: | ||
| return 0, False | ||
|
|
||
| def update_state_after_alloc( | ||
| self, | ||
| request: "Request", | ||
| blocks: "KVCacheBlocks", | ||
| num_external_tokens: int, | ||
| ): | ||
| pass | ||
|
|
||
| def build_connector_meta(self, scheduler_output: SchedulerOutput): | ||
| return None | ||
|
|
||
| def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: | ||
| pass | ||
|
|
||
| def wait_for_layer_load(self, layer_name: str) -> None: | ||
| pass | ||
|
|
||
| def save_kv_layer( | ||
| self, | ||
| layer_name: str, | ||
| kv_layer, | ||
| attn_metadata: "AttentionMetadata", | ||
| **kwargs, | ||
| ) -> None: | ||
| pass | ||
|
|
||
| def wait_for_save(self): | ||
| pass | ||
|
|
||
|
|
||
| class NewStyleTestConnector(KVConnectorBase_V1): | ||
| """ | ||
| Test connector using the new signature with 3 required arguments. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| vllm_config: "VllmConfig", | ||
| role: KVConnectorRole, | ||
| kv_cache_config: "KVCacheConfig", | ||
| ): | ||
| # New-style call to super().__init__ with all 3 arguments | ||
| super().__init__( | ||
| vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config | ||
| ) | ||
|
|
||
| def get_num_new_matched_tokens( | ||
| self, request: "Request", num_computed_tokens: int | ||
| ) -> tuple[int | None, bool]: | ||
| return 0, False | ||
|
|
||
| def update_state_after_alloc( | ||
| self, | ||
| request: "Request", | ||
| blocks: "KVCacheBlocks", | ||
| num_external_tokens: int, | ||
| ): | ||
| pass | ||
|
|
||
| def build_connector_meta(self, scheduler_output: SchedulerOutput): | ||
| return None | ||
|
|
||
| def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: | ||
| pass | ||
|
|
||
| def wait_for_layer_load(self, layer_name: str) -> None: | ||
| pass | ||
|
|
||
| def save_kv_layer( | ||
| self, | ||
| layer_name: str, | ||
| kv_layer, | ||
| attn_metadata: "AttentionMetadata", | ||
| **kwargs, | ||
| ) -> None: | ||
| pass | ||
|
|
||
| def wait_for_save(self): | ||
| pass | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) | ||
| def test_external_old_signature_factory_instantiation(role): | ||
| """ | ||
| Test that external connectors with old signature (2 required args) loaded | ||
| via kv_connector_module_path are correctly instantiated with backwards | ||
| compatibility support. | ||
| """ | ||
| vllm_config = create_vllm_config() | ||
| vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector" | ||
| vllm_config.kv_transfer_config.kv_connector_module_path = ( | ||
| "tests.v1.kv_connector.unit.test_backwards_compatibility" | ||
| ) | ||
|
|
||
| scheduler = create_scheduler(vllm_config) | ||
| kv_cache_config = scheduler.kv_cache_config | ||
|
|
||
| connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) | ||
|
|
||
| assert connector is not None | ||
| assert isinstance(connector, OldStyleTestConnector) | ||
| assert connector.role == role | ||
| assert connector._kv_cache_config is None | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) | ||
| def test_external_new_signature_factory_instantiation(role): | ||
| """ | ||
| Test that external connectors with new signature (3 required args) loaded | ||
| via kv_connector_module_path are correctly instantiated. | ||
| """ | ||
| vllm_config = create_vllm_config() | ||
| vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector" | ||
| vllm_config.kv_transfer_config.kv_connector_module_path = ( | ||
| "tests.v1.kv_connector.unit.test_backwards_compatibility" | ||
| ) | ||
|
|
||
| scheduler = create_scheduler(vllm_config) | ||
| kv_cache_config = scheduler.kv_cache_config | ||
|
|
||
| connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) | ||
|
|
||
| assert connector is not None | ||
| assert isinstance(connector, NewStyleTestConnector) | ||
| assert connector.role == role | ||
| assert connector._kv_cache_config is not None | ||
| assert connector._kv_cache_config == kv_cache_config | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) | ||
| def test_old_signature_super_init(role): | ||
| """ | ||
| Test that old-style connectors can call super().__init__() without | ||
| kv_cache_config parameter. | ||
| """ | ||
| vllm_config = create_vllm_config() | ||
|
|
||
| connector = OldStyleTestConnector(vllm_config, role) | ||
|
|
||
| assert connector is not None | ||
| assert connector.role == role | ||
| assert connector._kv_cache_config is None | ||
|
|
||
|
|
||
| def test_old_signature_super_init_with_kwargs(): | ||
| """ | ||
| Test that old-style connectors can call super().__init__() with keyword | ||
| arguments in different orders. | ||
| """ | ||
| vllm_config = create_vllm_config() | ||
|
|
||
| # Test with vllm_config= and role= kwargs | ||
| connector1 = OldStyleTestConnector( | ||
| vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER | ||
| ) | ||
| assert connector1 is not None | ||
| assert connector1._kv_cache_config is None | ||
|
|
||
| # Test with role= and vllm_config= in reversed order | ||
| connector2 = OldStyleTestConnector( | ||
| role=KVConnectorRole.WORKER, vllm_config=vllm_config | ||
| ) | ||
| assert connector2 is not None | ||
| assert connector2._kv_cache_config is None | ||
|
|
||
|
|
||
| def test_internal_connector_uses_new_signature(): | ||
| """ | ||
| Test that internal connectors (registered in factory) always use the new | ||
| signature and get kv_cache_config. | ||
| """ | ||
| from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( | ||
| SharedStorageConnector, | ||
| ) | ||
|
|
||
| vllm_config = create_vllm_config() | ||
| vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector" | ||
|
|
||
| scheduler = create_scheduler(vllm_config) | ||
| kv_cache_config = scheduler.kv_cache_config | ||
|
|
||
| connector = KVConnectorFactory.create_connector( | ||
| vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config | ||
| ) | ||
|
|
||
| assert connector is not None | ||
| assert isinstance(connector, SharedStorageConnector) | ||
| assert connector._kv_cache_config is not None | ||
| assert connector._kv_cache_config == kv_cache_config | ||
|
|
||
|
|
||
| def test_signature_detection_with_mocking(): | ||
| """ | ||
| Test that the factory correctly applies compat_sig flag returned from | ||
| _get_connector_class_with_compat. | ||
| """ | ||
| vllm_config = create_vllm_config() | ||
| scheduler = create_scheduler(vllm_config) | ||
| kv_cache_config = scheduler.kv_cache_config | ||
|
|
||
| # Mock _get_connector_class_with_compat to return old-style connector | ||
| with patch.object( | ||
| KVConnectorFactory, | ||
| "_get_connector_class_with_compat", | ||
| return_value=(OldStyleTestConnector, True), | ||
| ): | ||
| old_connector = KVConnectorFactory.create_connector( | ||
| vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config | ||
| ) | ||
| assert old_connector is not None | ||
| assert isinstance(old_connector, OldStyleTestConnector) | ||
| assert old_connector._kv_cache_config is None | ||
|
|
||
| # Mock _get_connector_class_with_compat to return new-style connector | ||
| with patch.object( | ||
| KVConnectorFactory, | ||
| "_get_connector_class_with_compat", | ||
| return_value=(NewStyleTestConnector, False), | ||
| ): | ||
| new_connector = KVConnectorFactory.create_connector( | ||
| vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config | ||
| ) | ||
| assert new_connector is not None | ||
| assert isinstance(new_connector, NewStyleTestConnector) | ||
| assert new_connector._kv_cache_config is not None | ||
| assert new_connector._kv_cache_config == kv_cache_config | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,10 +3,9 @@ | |
|
|
||
| import importlib | ||
| from collections.abc import Callable | ||
| from typing import TYPE_CHECKING, cast | ||
| from typing import TYPE_CHECKING, Optional, cast | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.config import VllmConfig | ||
| from vllm.distributed.kv_transfer.kv_connector.base import ( | ||
| KVConnectorBase, | ||
| KVConnectorBaseType, | ||
|
|
@@ -16,9 +15,12 @@ | |
| supports_hma, | ||
| ) | ||
| from vllm.logger import init_logger | ||
| from vllm.utils.func_utils import supports_kw | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.config import VllmConfig | ||
| from vllm.config.kv_transfer import KVTransferConfig | ||
| from vllm.v1.kv_cache_interface import KVCacheConfig | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -41,8 +43,9 @@ def loader() -> type[KVConnectorBase]: | |
| @classmethod | ||
| def create_connector( | ||
| cls, | ||
| config: VllmConfig, | ||
| config: "VllmConfig", | ||
| role: KVConnectorRole, | ||
| kv_cache_config: Optional["KVCacheConfig"] = None, | ||
| ) -> KVConnectorBase: | ||
| if not envs.VLLM_USE_V1: | ||
| raise ValueError( | ||
|
|
@@ -53,7 +56,9 @@ def create_connector( | |
| kv_transfer_config = config.kv_transfer_config | ||
| if kv_transfer_config is None: | ||
| raise ValueError("kv_transfer_config must be set to create a connector") | ||
| connector_cls = cls.get_connector_class(kv_transfer_config) | ||
| connector_cls, compat_sig = cls._get_connector_class_with_compat( | ||
| kv_transfer_config | ||
| ) | ||
|
|
||
| # check if the connector supports HMA | ||
| hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager | ||
|
|
@@ -76,7 +81,12 @@ def create_connector( | |
| # - Co-locate with worker process | ||
| # - Should only be used inside the forward context & attention layer | ||
| # We build separately to enforce strict separation | ||
| return connector_cls(config, role) | ||
| if compat_sig: | ||
| # Old signature: __init__(self, vllm_config, role) | ||
| return connector_cls(config, role) | ||
| else: | ||
| # New signature: __init__(self, vllm_config, role, kv_cache_config) | ||
| return connector_cls(config, role, kv_cache_config) | ||
|
|
||
| @classmethod | ||
| def get_connector_class_by_name( | ||
|
|
@@ -97,13 +107,13 @@ def get_connector_class_by_name( | |
| return cls._registry[connector_name]() | ||
|
|
||
| @classmethod | ||
| def get_connector_class( | ||
| def _get_connector_class_with_compat( | ||
| cls, kv_transfer_config: "KVTransferConfig" | ||
| ) -> type[KVConnectorBaseType]: | ||
| """Get the connector class by name.""" | ||
| ) -> tuple[type[KVConnectorBaseType], bool]: | ||
| connector_name = kv_transfer_config.kv_connector | ||
| if connector_name is None: | ||
| raise ValueError("Connector name is not set in KVTransferConfig") | ||
| compat_sig = False | ||
| if connector_name in cls._registry: | ||
| connector_cls = cls._registry[connector_name]() | ||
| else: | ||
|
|
@@ -118,6 +128,21 @@ def get_connector_class( | |
| f"Class {connector_name} not found in {connector_module_path}" | ||
| ) from e | ||
| connector_cls = cast(type[KVConnectorBaseType], connector_cls) | ||
| if not supports_kw(connector_cls, "kv_cache_config"): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to confirm: this means that we allow connector to include There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. Since we we are currently unconditionally attaching If we think that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That said, the current implementation looks good to me as it is simple enough. |
||
| compat_sig = True | ||
| logger.warning( | ||
| "Connector %s uses deprecated signature with 2 required arguments. " | ||
| "Please update to include kv_cache_config as the second argument.", | ||
| connector_cls.__name__, | ||
| ) | ||
| return connector_cls, compat_sig | ||
|
|
||
| @classmethod | ||
| def get_connector_class( | ||
| cls, kv_transfer_config: "KVTransferConfig" | ||
| ) -> type[KVConnectorBaseType]: | ||
| """Get the connector class by name.""" | ||
| connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config) | ||
| return connector_cls | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe rename this file? Like
test_connector_init_with_kv_cache_configor something.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Obviously I don't really mind renaming it, if it helps, but my thinking is that these tests are about testing support for connectors that have not yet been updated to the new signature so it's more like "without_kv_cache_config()"
Basically, because all connectors are expected to support the new signature, we'll soon see these tests as old cruft that we need to keep around for a while
(This is different from the
SupportsHMAapproach - in that case, maybe only a small subset of connectors would be updated to takeKVCacheConfig)