Skip to content

Commit 1fb4217

Browse files
authored
[Multimodal] Make MediaConnector extensible. (#27759)
Signed-off-by: Chenheli Hua <[email protected]>
1 parent 611c86e commit 1fb4217

File tree

5 files changed

+71
-22
lines changed

5 files changed

+71
-22
lines changed

vllm/entrypoints/chat_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@
4343
# pydantic needs the TypedDict from typing_extensions
4444
from typing_extensions import Required, TypedDict
4545

46+
from vllm import envs
4647
from vllm.config import ModelConfig
4748
from vllm.logger import init_logger
4849
from vllm.model_executor.models import SupportsMultiModal
4950
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
50-
from vllm.multimodal.utils import MediaConnector
51+
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
5152
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
5253
from vllm.transformers_utils.processor import cached_get_processor
5354
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -806,7 +807,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
806807
self._tracker = tracker
807808
multimodal_config = self._tracker.model_config.multimodal_config
808809
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
809-
self._connector = MediaConnector(
810+
811+
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
812+
envs.VLLM_MEDIA_CONNECTOR,
810813
media_io_kwargs=media_io_kwargs,
811814
allowed_local_media_path=tracker.allowed_local_media_path,
812815
allowed_media_domains=tracker.allowed_media_domains,
@@ -891,7 +894,8 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
891894
self._tracker = tracker
892895
multimodal_config = self._tracker.model_config.multimodal_config
893896
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
894-
self._connector = MediaConnector(
897+
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
898+
envs.VLLM_MEDIA_CONNECTOR,
895899
media_io_kwargs=media_io_kwargs,
896900
allowed_local_media_path=tracker.allowed_local_media_path,
897901
allowed_media_domains=tracker.allowed_media_domains,

vllm/envs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
7171
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
7272
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
73+
VLLM_MEDIA_CONNECTOR: str = "http"
7374
VLLM_MM_INPUT_CACHE_GIB: int = 4
7475
VLLM_TARGET_DEVICE: str = "cuda"
7576
VLLM_MAIN_CUDA_VERSION: str = "12.8"
@@ -738,6 +739,14 @@ def get_vllm_port() -> int | None:
738739
"VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv(
739740
"VLLM_VIDEO_LOADER_BACKEND", "opencv"
740741
),
742+
# Media connector implementation.
743+
# - "http": Default connector that supports fetching media via HTTP.
744+
#
745+
# Custom implementations can be registered
746+
# via `@MEDIA_CONNECTOR_REGISTRY.register("my_custom_media_connector")` and
747+
# imported at runtime.
748+
# If a non-existing backend is used, an AssertionError will be thrown.
749+
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
741750
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
742751
# Default is 4 GiB per API process + 4 GiB per engine core process
743752
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),

vllm/multimodal/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from vllm.connections import HTTPConnection, global_http_connection
2121
from vllm.logger import init_logger
2222
from vllm.utils.jsontree import json_map_leaves
23+
from vllm.utils.registry import ExtensionManager
2324

2425
from .audio import AudioMediaIO
2526
from .base import MediaIO
@@ -46,7 +47,10 @@
4647

4748
_M = TypeVar("_M")
4849

50+
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
4951

52+
53+
@MEDIA_CONNECTOR_REGISTRY.register("http")
5054
class MediaConnector:
5155
def __init__(
5256
self,

vllm/multimodal/video.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from vllm import envs
1616
from vllm.logger import init_logger
17+
from vllm.utils.registry import ExtensionManager
1718

1819
from .base import MediaIO
1920
from .image import ImageMediaIO
@@ -63,25 +64,7 @@ def load_bytes(
6364
raise NotImplementedError
6465

6566

66-
class VideoLoaderRegistry:
67-
def __init__(self) -> None:
68-
self.name2class: dict[str, type] = {}
69-
70-
def register(self, name: str):
71-
def wrap(cls_to_register):
72-
self.name2class[name] = cls_to_register
73-
return cls_to_register
74-
75-
return wrap
76-
77-
@staticmethod
78-
def load(cls_name: str) -> VideoLoader:
79-
cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
80-
assert cls is not None, f"VideoLoader class {cls_name} not found"
81-
return cls()
82-
83-
84-
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
67+
VIDEO_LOADER_REGISTRY = ExtensionManager()
8568

8669

8770
@VIDEO_LOADER_REGISTRY.register("opencv")

vllm/utils/registry.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Any
4+
5+
6+
class ExtensionManager:
7+
"""
8+
A registry for managing pluggable extension classes.
9+
10+
This class provides a simple mechanism to register and instantiate
11+
extension classes by name. It is commonly used to implement plugin
12+
systems where different implementations can be swapped at runtime.
13+
14+
Examples:
15+
Basic usage with a registry instance:
16+
17+
>>> FOO_REGISTRY = ExtensionManager()
18+
>>> @FOO_REGISTRY.register("my_foo_impl")
19+
... class MyFooImpl(Foo):
20+
... def __init__(self, value):
21+
... self.value = value
22+
>>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)
23+
24+
"""
25+
26+
def __init__(self) -> None:
27+
"""
28+
Initialize an empty extension registry.
29+
"""
30+
self.name2class: dict[str, type] = {}
31+
32+
def register(self, name: str):
33+
"""
34+
Decorator to register a class with the given name.
35+
"""
36+
37+
def wrap(cls_to_register):
38+
self.name2class[name] = cls_to_register
39+
return cls_to_register
40+
41+
return wrap
42+
43+
def load(self, cls_name: str, *args, **kwargs) -> Any:
44+
"""
45+
Instantiate and return a registered extension class by name.
46+
"""
47+
cls = self.name2class.get(cls_name)
48+
assert cls is not None, f"Extension class {cls_name} not found"
49+
return cls(*args, **kwargs)

0 commit comments

Comments
 (0)