Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 275 additions & 0 deletions tests/v1/kv_connector/unit/test_backwards_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename this file? Like test_connector_init_with_kv_cache_config or something.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obviously I don't really mind renaming it, if it helps, but my thinking is that these tests are about testing support for connectors that have not yet been updated to the new signature so it's more like "without_kv_cache_config()"

Basically, because all connectors are expected to support the new signature, we'll soon see these tests as old cruft that we need to keep around for a while

(This is different from the SupportsHMA approach - in that case, maybe only a small subset of connectors would be updated to take KVCacheConfig)

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.
This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""

from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest

from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.v1.core.sched.output import SchedulerOutput

from .utils import create_scheduler, create_vllm_config

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request


class OldStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the old signature with 2 required arguments.
This simulates external connectors that haven't been updated yet.
"""

def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
# Old-style call to super().__init__ with only 2 arguments
super().__init__(vllm_config=vllm_config, role=role)

def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False

def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass

def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass

def wait_for_layer_load(self, layer_name: str) -> None:
pass

def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
pass

def wait_for_save(self):
pass


class NewStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the new signature with 3 required arguments.
"""

def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
# New-style call to super().__init__ with all 3 arguments
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)

def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False

def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass

def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass

def wait_for_layer_load(self, layer_name: str) -> None:
pass

def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
pass

def wait_for_save(self):
pass


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
"""
Test that external connectors with old signature (2 required args) loaded
via kv_connector_module_path are correctly instantiated with backwards
compatibility support.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)

scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config

connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)

assert connector is not None
assert isinstance(connector, OldStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is None


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
"""
Test that external connectors with new signature (3 required args) loaded
via kv_connector_module_path are correctly instantiated.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)

scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config

connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)

assert connector is not None
assert isinstance(connector, NewStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
"""
Test that old-style connectors can call super().__init__() without
kv_cache_config parameter.
"""
vllm_config = create_vllm_config()

connector = OldStyleTestConnector(vllm_config, role)

assert connector is not None
assert connector.role == role
assert connector._kv_cache_config is None


def test_old_signature_super_init_with_kwargs():
"""
Test that old-style connectors can call super().__init__() with keyword
arguments in different orders.
"""
vllm_config = create_vllm_config()

# Test with vllm_config= and role= kwargs
connector1 = OldStyleTestConnector(
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
)
assert connector1 is not None
assert connector1._kv_cache_config is None

# Test with role= and vllm_config= in reversed order
connector2 = OldStyleTestConnector(
role=KVConnectorRole.WORKER, vllm_config=vllm_config
)
assert connector2 is not None
assert connector2._kv_cache_config is None


def test_internal_connector_uses_new_signature():
"""
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
SharedStorageConnector,
)

vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector"

scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config

connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)

assert connector is not None
assert isinstance(connector, SharedStorageConnector)
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config


def test_signature_detection_with_mocking():
"""
Test that the factory correctly applies compat_sig flag returned from
_get_connector_class_with_compat.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config

# Mock _get_connector_class_with_compat to return old-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(OldStyleTestConnector, True),
):
old_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert old_connector is not None
assert isinstance(old_connector, OldStyleTestConnector)
assert old_connector._kv_cache_config is None

# Mock _get_connector_class_with_compat to return new-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(NewStyleTestConnector, False),
):
new_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert new_connector is not None
assert isinstance(new_connector, NewStyleTestConnector)
assert new_connector._kv_cache_config is not None
assert new_connector._kv_cache_config == kv_cache_config
2 changes: 1 addition & 1 deletion tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def create_model_runner_output(


class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
Expand Down
41 changes: 33 additions & 8 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
Expand All @@ -16,9 +15,12 @@
supports_hma,
)
from vllm.logger import init_logger
from vllm.utils.func_utils import supports_kw

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.v1.kv_cache_interface import KVCacheConfig

logger = init_logger(__name__)

Expand All @@ -41,8 +43,9 @@ def loader() -> type[KVConnectorBase]:
@classmethod
def create_connector(
cls,
config: VllmConfig,
config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
raise ValueError(
Expand All @@ -53,7 +56,9 @@ def create_connector(
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config)
connector_cls, compat_sig = cls._get_connector_class_with_compat(
kv_transfer_config
)

# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
Expand All @@ -76,7 +81,12 @@ def create_connector(
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
if compat_sig:
# Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role)
else:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)

@classmethod
def get_connector_class_by_name(
Expand All @@ -97,13 +107,13 @@ def get_connector_class_by_name(
return cls._registry[connector_name]()

@classmethod
def get_connector_class(
def _get_connector_class_with_compat(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
) -> tuple[type[KVConnectorBaseType], bool]:
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
compat_sig = False
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
Expand All @@ -118,6 +128,21 @@ def get_connector_class(
f"Class {connector_name} not found in {connector_module_path}"
) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
if not supports_kw(connector_cls, "kv_cache_config"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm: this means that we allow connector to include kv_cache_config field as init args even when it does not support hybrid allocator (not a subclass of SupportsHMA)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Since we we are currently unconditionally attaching KVCacheConfig to VllmConfig, I didn't even consider making it conditional until now

If we think that KVCacheConfig will only be used as part of implementing request_finished_all_groups() we could add init_hma() or init_kv_cache_config() to SupportsHMA ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KVCacheConfig takes effect on almost all connector methods when HMA is enabled and it is not useful at all when HMA is disabled.

That said, the current implementation looks good to me as it is simple enough.

compat_sig = True
logger.warning(
"Connector %s uses deprecated signature with 2 required arguments. "
"Please update to include kv_cache_config as the second argument.",
connector_cls.__name__,
)
return connector_cls, compat_sig

@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
return connector_cls


Expand Down
Loading