diff --git a/tests/conftest.py b/tests/conftest.py
index 369acb92cfb9..8acc1f28559e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -57,7 +57,8 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import is_list_of, set_default_torch_num_threads
+from vllm.utils import set_default_torch_num_threads
+from vllm.utils.collections import is_list_of
 
 logger = init_logger(__name__)
 
diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py
index d9c1d53b61c2..c110f5598bee 100644
--- a/tests/models/multimodal/generation/vlm_utils/model_utils.py
+++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -25,7 +25,7 @@
 from transformers.video_utils import VideoMetadata
 
 from vllm.logprobs import SampleLogprobs
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 
 from .....conftest import HfRunner, ImageAsset, ImageTestAssets
 from .types import RunnerOutput
diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py
index 9029f09de8c8..00c46082df66 100644
--- a/tests/models/multimodal/processing/test_tensor_schema.py
+++ b/tests/models/multimodal/processing/test_tensor_schema.py
@@ -35,7 +35,7 @@
 from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 
 from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
 from ...utils import dummy_hf_overrides
diff --git a/tests/utils_/test_collections.py b/tests/utils_/test_collections.py
new file mode 100644
index 000000000000..cb96bf2b0d21
--- /dev/null
+++ b/tests/utils_/test_collections.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.utils.collections import swap_dict_values
+
+
+@pytest.mark.parametrize(
+    "obj,key1,key2",
+    [
+        # Tests for both keys exist
+        ({1: "a", 2: "b"}, 1, 2),
+        # Tests for one key does not exist
+        ({1: "a", 2: "b"}, 1, 3),
+        # Tests for both keys do not exist
+        ({1: "a", 2: "b"}, 3, 4),
+    ],
+)
+def test_swap_dict_values(obj, key1, key2):
+    original_obj = obj.copy()
+
+    swap_dict_values(obj, key1, key2)
+
+    if key1 in original_obj:
+        assert obj[key2] == original_obj[key1]
+    else:
+        assert key2 not in obj
+    if key2 in original_obj:
+        assert obj[key1] == original_obj[key2]
+    else:
+        assert key1 not in obj
diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py
index 3bc4d3536d58..efc83c0a31b8 100644
--- a/tests/utils_/test_utils.py
+++ b/tests/utils_/test_utils.py
@@ -38,7 +38,6 @@
     sha256,
     split_host_port,
     split_zmq_path,
-    swap_dict_values,
     unique_filepath,
 )
 
@@ -516,30 +515,6 @@ def build_ctx():
         _ = placeholder_attr.module
 
 
-@pytest.mark.parametrize(
-    "obj,key1,key2",
-    [
-        # Tests for both keys exist
-        ({1: "a", 2: "b"}, 1, 2),
-        # Tests for one key does not exist
-        ({1: "a", 2: "b"}, 1, 3),
-        # Tests for both keys do not exist
-        ({1: "a", 2: "b"}, 3, 4),
-    ],
-)
-def test_swap_dict_values(obj, key1, key2):
-    original_obj = obj.copy()
-    swap_dict_values(obj, key1, key2)
-    if key1 in original_obj:
-        assert obj[key2] == original_obj[key1]
-    else:
-        assert key2 not in obj
-    if key2 in original_obj:
-        assert obj[key1] == original_obj[key2]
-    else:
-        assert key1 not in obj
-
-
 def test_model_specification(
     parser_with_config, cli_config_file, cli_config_file_with_model
 ):
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 5883b92acd99..177d2b9174c5 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -75,7 +75,8 @@
     get_cached_tokenizer,
 )
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, Device, as_iter, is_list_of
+from vllm.utils import Counter, Device
+from vllm.utils.collections import as_iter, is_list_of
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.llm_engine import LLMEngine
 from vllm.v1.sample.logits_processor import LogitsProcessor
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 5dc7f7859226..b5b314e15ad8 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -70,7 +70,7 @@
     truncate_tool_call_ids,
     validate_request_params,
 )
-from vllm.utils import as_list
+from vllm.utils.collections import as_list
 
 logger = init_logger(__name__)
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index b60330f71d39..5c4199020574 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -34,8 +34,8 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import as_list
 from vllm.utils.asyncio import merge_async_iterators
+from vllm.utils.collections import as_list
 
 logger = init_logger(__name__)
 
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 7540a2e5d5cf..2e3129cbeb8e 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -39,8 +39,8 @@
     RequestOutput,
 )
 from vllm.pooling_params import PoolingParams
-from vllm.utils import chunk_list
 from vllm.utils.asyncio import merge_async_iterators
+from vllm.utils.collections import chunk_list
 
 logger = init_logger(__name__)
 
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 041e47c1c797..ffcde8c3024c 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -90,13 +90,14 @@
     log_tracing_disabled_warning,
 )
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import is_list_of, random_uuid
+from vllm.utils import random_uuid
 from vllm.utils.asyncio import (
     AsyncMicrobatchTokenizer,
     collect_from_async_generator,
     make_async,
     merge_async_iterators,
 )
+from vllm.utils.collections import is_list_of
 from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
index 3327ac99134f..c7363e442cdd 100644
--- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
@@ -12,7 +12,8 @@
 )
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import import_from_path, is_list_of
+from vllm.utils import import_from_path
+from vllm.utils.collections import is_list_of
 
 logger = init_logger(__name__)
 
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py
index 5cfef7f5b6d9..c84fc098f002 100644
--- a/vllm/inputs/parse.py
+++ b/vllm/inputs/parse.py
@@ -5,7 +5,7 @@
 
 from typing_extensions import TypeIs
 
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 
 from .data import (
     EmbedsPrompt,
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 50548d2e1afa..f48fad559efd 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -17,7 +17,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import LazyDict
+from vllm.utils.collections import LazyDict
 
 logger = init_logger(__name__)
 
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index f65c6156d040..0a1800590bda 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -28,7 +28,7 @@
     RowvLLMParameter,
 )
 from vllm.transformers_utils.config import get_safetensors_params_metadata
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 
 if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index b22c3c125ead..191c7a6388e8 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -57,7 +57,7 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.transformers_utils.config import get_safetensors_params_metadata
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 
 logger = init_logger(__name__)
 
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 094a7e73b3aa..d9e1523b048a 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -49,7 +49,7 @@
 )
 from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 17732f8a5490..56acb3ddf124 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -33,7 +33,7 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index ef2bbac75654..1c2e9042a6b1 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -86,7 +86,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
+from vllm.utils.collections import flatten_2d_lists
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index f114aae25c51..6b2dafbf1555 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -79,7 +79,7 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
+from vllm.utils.collections import is_list_of
 
 from .interfaces import (
     MultiModalEmbeddings,
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index dec2e0acab6b..9eab33b45e8a 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -22,7 +22,8 @@
 import numpy as np
 from typing_extensions import NotRequired, TypeVar, deprecated
 
-from vllm.utils import LazyLoader, full_groupby, is_list_of
+from vllm.utils import LazyLoader
+from vllm.utils.collections import full_groupby, is_list_of
 from vllm.utils.jsontree import json_map_leaves
 
 if TYPE_CHECKING:
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 748355309521..71e577f0c0ad 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -19,7 +19,8 @@
 import torch
 from typing_extensions import assert_never
 
-from vllm.utils import LazyLoader, is_list_of
+from vllm.utils import LazyLoader
+from vllm.utils.collections import is_list_of
 
 from .audio import AudioResampler
 from .inputs import (
@@ -364,7 +365,7 @@ def _is_embeddings(
         if isinstance(data, torch.Tensor):
             return data.ndim == 3
         if is_list_of(data, torch.Tensor):
-            return data[0].ndim == 2
+            return data[0].ndim == 2  # type: ignore[index]
         return False
 
 
@@ -422,6 +423,7 @@ def _parse_audio_data(
         if self._is_embeddings(data):
             return AudioEmbeddingItems(data)
 
+        data_items: list[AudioItem]
         if (
             is_list_of(data, float)
             or isinstance(data, (np.ndarray, torch.Tensor))
@@ -432,7 +434,7 @@
         elif isinstance(data, (np.ndarray, torch.Tensor)):
             data_items = [elem for elem in data]
         else:
-            data_items = data
+            data_items = data  # type: ignore[assignment]
 
         new_audios = list[np.ndarray]()
         for data_item in data_items:
@@ -485,6 +487,7 @@ def _parse_video_data(
         if self._is_embeddings(data):
             return VideoEmbeddingItems(data)
 
+        data_items: list[VideoItem]
        if (
             is_list_of(data, PILImage.Image)
             or isinstance(data, (np.ndarray, torch.Tensor))
@@ -496,7 +499,7 @@
         elif isinstance(data, tuple) and len(data) == 2:
             data_items = [data]
         else:
-            data_items = data
+            data_items = data  # type: ignore[assignment]
 
         new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
         metadata_lst: list[dict[str, Any] | None] = []
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 9e8bb91baed7..b47e82a19d70 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -25,7 +25,7 @@
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_processor_from_config
 from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
-from vllm.utils import flatten_2d_lists, full_groupby
+from vllm.utils.collections import flatten_2d_lists, full_groupby
 from vllm.utils.functools import get_allowed_kwarg_only_overrides
 from vllm.utils.jsontree import JSONTree, json_map_leaves
 
@@ -484,8 +484,11 @@ def modality(self) -> str: ...
 
 
 def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
-    """Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
-    based on modality."""
+    """
+    Convenience function to apply
+    [`full_groupby`][vllm.utils.collections.full_groupby]
+    based on modality.
+    """
     return full_groupby(values, key=lambda x: x.modality)
 
 
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 66d0bb7458c0..1e1b4acdb871 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -9,7 +9,7 @@
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
-from vllm.utils import ClassRegistry
+from vllm.utils.collections import ClassRegistry
 
 from .cache import BaseMultiModalProcessorCache
 from .processing import (
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index b85216f43fad..cbebca09e7b8 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -8,7 +8,8 @@
 from typing import TYPE_CHECKING, Any
 
 from vllm.logger import init_logger
-from vllm.utils import import_from_path, is_list_of
+from vllm.utils import import_from_path
+from vllm.utils.collections import is_list_of
 
 if TYPE_CHECKING:
     from vllm.entrypoints.openai.protocol import (
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 99a9225cb6a4..94cf6d0ff153 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -37,29 +37,19 @@
     RawDescriptionHelpFormatter,
     _ArgumentGroup,
 )
-from collections import UserDict, defaultdict
+from collections import defaultdict
 from collections.abc import (
     Callable,
     Collection,
     Generator,
-    Hashable,
-    Iterable,
     Iterator,
-    Mapping,
     Sequence,
 )
 from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from pathlib import Path
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Generic,
-    Literal,
-    TextIO,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, Any, TextIO, TypeVar
 from urllib.parse import urlparse
 from uuid import uuid4
 
@@ -78,7 +68,7 @@
 from packaging import version
 from packaging.version import Version
 from torch.library import Library
-from typing_extensions import Never, TypeIs, assert_never
+from typing_extensions import Never
 
 import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
@@ -170,9 +160,6 @@ def set_default_torch_num_threads(num_threads: int):
 T = TypeVar("T")
 U = TypeVar("U")
 
-_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V")
-
 
 class Device(enum.Enum):
     GPU = enum.auto()
@@ -421,12 +408,6 @@ def update_environment_variables(envs: dict[str, str]):
         os.environ[k] = v
 
 
-def chunk_list(lst: list[T], chunk_size: int):
-    """Yield successive chunk_size chunks from lst."""
-    for i in range(0, len(lst), chunk_size):
-        yield lst[i : i + chunk_size]
-
-
 def cdiv(a: int, b: int) -> int:
     """Ceiling division."""
     return -(a // -b)
@@ -743,53 +724,6 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
     )
 
 
-def as_list(maybe_list: Iterable[T]) -> list[T]:
-    """Convert iterable to list, unless it's already a list."""
-    return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
-
-
-def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
-    if isinstance(obj, str) or not isinstance(obj, Iterable):
-        return [obj]  # type: ignore[list-item]
-    return obj
-
-
-# `collections` helpers
-def is_list_of(
-    value: object,
-    typ: type[T] | tuple[type[T], ...],
-    *,
-    check: Literal["first", "all"] = "first",
-) -> TypeIs[list[T]]:
-    if not isinstance(value, list):
-        return False
-
-    if check == "first":
-        return len(value) == 0 or isinstance(value[0], typ)
-    elif check == "all":
-        return all(isinstance(v, typ) for v in value)
-
-    assert_never(check)
-
-
-def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
-    """Flatten a list of lists to a single list."""
-    return [item for sublist in lists for item in sublist]
-
-
-def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
-    """
-    Unlike [`itertools.groupby`][], groups are not broken by
-    non-contiguous data.
-    """
-    groups = defaultdict[_K, list[_V]](list)
-
-    for value in values:
-        groups[key(value)].append(value)
-
-    return groups.items()
-
-
 # TODO: This function can be removed if transformer_modules classes are
 # serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
@@ -1578,50 +1512,6 @@ def value(self):
         return self._value
 
 
-# Adapted from: https://stackoverflow.com/a/47212782/5082708
-class LazyDict(Mapping[str, T], Generic[T]):
-    def __init__(self, factory: dict[str, Callable[[], T]]):
-        self._factory = factory
-        self._dict: dict[str, T] = {}
-
-    def __getitem__(self, key: str) -> T:
-        if key not in self._dict:
-            if key not in self._factory:
-                raise KeyError(key)
-            self._dict[key] = self._factory[key]()
-        return self._dict[key]
-
-    def __setitem__(self, key: str, value: Callable[[], T]):
-        self._factory[key] = value
-
-    def __iter__(self):
-        return iter(self._factory)
-
-    def __len__(self):
-        return len(self._factory)
-
-
-class ClassRegistry(UserDict[type[T], _V]):
-    def __getitem__(self, key: type[T]) -> _V:
-        for cls in key.mro():
-            if cls in self.data:
-                return self.data[cls]
-
-        raise KeyError(key)
-
-    def __contains__(self, key: object) -> bool:
-        return self.contains(key)
-
-    def contains(self, key: object, *, strict: bool = False) -> bool:
-        if not isinstance(key, type):
-            return False
-
-        if strict:
-            return key in self.data
-
-        return any(cls in self.data for cls in key.mro())
-
-
 def weak_ref_tensor(tensor: Any) -> Any:
     """
     Create a weak reference to a tensor.
@@ -2588,22 +2478,6 @@ def __dir__(self) -> list[str]:
         return dir(self._module)
 
 
-def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
-    """
-    Helper function to swap values for two keys
-    """
-    v1 = obj.get(key1)
-    v2 = obj.get(key2)
-    if v1 is not None:
-        obj[key2] = v1
-    else:
-        obj.pop(key2, None)
-    if v2 is not None:
-        obj[key1] = v2
-    else:
-        obj.pop(key1, None)
-
-
 @contextlib.contextmanager
 def cprofile_context(save_file: str | None = None):
     """Run a cprofile
diff --git a/vllm/utils/collections.py b/vllm/utils/collections.py
new file mode 100644
index 000000000000..57271311828c
--- /dev/null
+++ b/vllm/utils/collections.py
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Contains helpers that are applied to collections.
+
+This is similar in concept to the `collections` module.
+"""
+
+from collections import UserDict, defaultdict
+from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
+from typing import Generic, Literal, TypeVar
+
+from typing_extensions import TypeIs, assert_never
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+
+
+class ClassRegistry(UserDict[type[T], _V]):
+    """
+    A registry that acts like a dictionary but searches for other classes
+    in the MRO if the original class is not found.
+    """
+
+    def __getitem__(self, key: type[T]) -> _V:
+        for cls in key.mro():
+            if cls in self.data:
+                return self.data[cls]
+
+        raise KeyError(key)
+
+    def __contains__(self, key: object) -> bool:
+        return self.contains(key)
+
+    def contains(self, key: object, *, strict: bool = False) -> bool:
+        if not isinstance(key, type):
+            return False
+
+        if strict:
+            return key in self.data
+
+        return any(cls in self.data for cls in key.mro())
+
+
+class LazyDict(Mapping[str, T], Generic[T]):
+    """
+    Evaluates dictionary items only when they are accessed.
+
+    Adapted from: https://stackoverflow.com/a/47212782/5082708
+    """
+
+    def __init__(self, factory: dict[str, Callable[[], T]]):
+        self._factory = factory
+        self._dict: dict[str, T] = {}
+
+    def __getitem__(self, key: str) -> T:
+        if key not in self._dict:
+            if key not in self._factory:
+                raise KeyError(key)
+            self._dict[key] = self._factory[key]()
+        return self._dict[key]
+
+    def __setitem__(self, key: str, value: Callable[[], T]):
+        self._factory[key] = value
+
+    def __iter__(self):
+        return iter(self._factory)
+
+    def __len__(self):
+        return len(self._factory)
+
+
+def as_list(maybe_list: Iterable[T]) -> list[T]:
+    """Convert iterable to list, unless it's already a list."""
+    return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
+
+
+def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
+    if isinstance(obj, str) or not isinstance(obj, Iterable):
+        return [obj]  # type: ignore[list-item]
+    return obj
+
+
+def is_list_of(
+    value: object,
+    typ: type[T] | tuple[type[T], ...],
+    *,
+    check: Literal["first", "all"] = "first",
+) -> TypeIs[list[T]]:
+    if not isinstance(value, list):
+        return False
+
+    if check == "first":
+        return len(value) == 0 or isinstance(value[0], typ)
+    elif check == "all":
+        return all(isinstance(v, typ) for v in value)
+
+    assert_never(check)
+
+
+def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
+    """Yield successive chunk_size chunks from lst."""
+    for i in range(0, len(lst), chunk_size):
+        yield lst[i : i + chunk_size]
+
+
+def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
+    """Flatten a list of lists to a single list."""
+    return [item for sublist in lists for item in sublist]
+
+
+def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
+    """
+    Unlike [`itertools.groupby`][], groups are not broken by
+    non-contiguous data.
+    """
+    groups = defaultdict[_K, list[_V]](list)
+
+    for value in values:
+        groups[key(value)].append(value)
+
+    return groups.items()
+
+
+def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
+    """Swap values between two keys."""
+    v1 = obj.get(key1)
+    v2 = obj.get(key2)
+    if v1 is not None:
+        obj[key2] = v1
+    else:
+        obj.pop(key2, None)
+    if v2 is not None:
+        obj[key1] = v2
+    else:
+        obj.pop(key1, None)
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 8a52fcef9e73..584956c1f0eb 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -29,8 +29,9 @@
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, as_list, cdiv
+from vllm.utils import Device, cdiv
 from vllm.utils.asyncio import cancel_task_threadsafe
+from vllm.utils.collections import as_list
 from vllm.utils.functools import deprecate_kwargs
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index b8751546f767..4a9cbeaea08c 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -12,7 +12,8 @@
 from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
+from vllm.utils import length_from_prompt_token_ids_or_embeds
+from vllm.utils.collections import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import (
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index 80b62066c8df..efd107b097c1 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -9,7 +9,8 @@
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingType
-from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
+from vllm.utils import length_from_prompt_token_ids_or_embeds
+from vllm.utils.collections import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.block_table import MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState