Merged
15 changes: 14 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -37,6 +37,7 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

logger = init_logger(__name__)
@@ -1070,7 +1071,19 @@ def apply_hf_chat_template(
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None and tokenizer.chat_template is None:
if chat_template is None:
chat_template = tokenizer.chat_template

# FIXME: Temporary workaround for
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31
if chat_template is None:
try:
processor = cached_get_processor(tokenizer.name_or_path)
chat_template = processor.chat_template
except Exception:
pass

if chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
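The net effect is a three-step template resolution: the explicit `chat_template` argument, then `tokenizer.chat_template`, then the processor's template (the workaround for repos such as mistral-community/pixtral-12b, which only ship the template with the processor). A minimal sketch of that order using plain `transformers` classes; `resolve_chat_template` is an illustrative helper, not part of this PR:

```python
# Sketch of the fallback order added above, using AutoTokenizer/AutoProcessor
# directly instead of vLLM's cached helpers.
from typing import Optional

from transformers import AutoProcessor, AutoTokenizer


def resolve_chat_template(model_id: str,
                          chat_template: Optional[str] = None) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if chat_template is None:
        chat_template = tokenizer.chat_template

    if chat_template is None:
        # Some repos only ship the template in the processor config,
        # so fall back to the processor's chat_template if it has one.
        try:
            processor = AutoProcessor.from_pretrained(model_id)
            chat_template = getattr(processor, "chat_template", None)
        except Exception:
            pass

    if chat_template is None:
        raise ValueError("No chat template found; pass one explicitly.")
    return chat_template
```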
15 changes: 8 additions & 7 deletions vllm/model_executor/models/fuyu.py
@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
from typing import Literal, Optional, Set, Tuple, TypedDict

import torch
import torch.nn as nn
@@ -31,8 +31,7 @@
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -58,10 +57,12 @@ class FuyuImagePatchInputs(TypedDict):
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""

patches_per_image: List[int]
patches_per_image: list[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`.
The number of total patches for each image in the batch.

This is used to split the embeddings, which have the first two dimensions
flattened just like `flat_data`.
"""


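As the updated docstring notes, `patches_per_image` holds one patch count per image so that embeddings whose first two dimensions were flattened can be split back into per-image tensors. A tiny illustration of that splitting with dummy shapes (not Fuyu's real ones):

```python
# Splitting flattened embeddings back into per-image tensors, as described
# in the patches_per_image docstring. Shapes are made up for the example.
import torch

patches_per_image = [3, 5]  # two images with 3 and 5 patches
hidden_size = 8
flat_embeds = torch.randn(sum(patches_per_image), hidden_size)

# torch.Tensor.split with a list of sizes yields one tensor per image:
# shapes (3, 8) and (5, 8).
per_image_embeds = flat_embeds.split(patches_per_image)
assert [e.shape[0] for e in per_image_embeds] == patches_per_image
```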
@@ -317,7 +318,7 @@ def _parse_and_validate_image_input(
return None

def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]

4 changes: 2 additions & 2 deletions vllm/model_executor/models/interfaces.py
@@ -5,7 +5,7 @@

import torch
from torch import Tensor
from typing_extensions import TypeIs
from typing_extensions import Self, TypeIs

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
@@ -451,7 +451,7 @@ class SupportsQuant:
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
quant_config: Optional[QuantizationConfig] = None

def __new__(cls, *args, **kwargs) -> "SupportsQuant":
def __new__(cls, *args, **kwargs) -> Self:
instance = super().__new__(cls)
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
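Returning `Self` from `__new__` (instead of the literal `"SupportsQuant"`) lets type checkers keep the concrete subclass type when a model class uses the mixin. A standalone sketch of the pattern with stand-in class names:

```python
# Why Self is preferable to a hard-coded class name in __new__ of a mixin.
from typing_extensions import Self


class SupportsQuantMixin:  # stand-in for the real SupportsQuant
    def __new__(cls, *args, **kwargs) -> Self:
        # With Self, constructing a subclass is inferred as that subclass,
        # not as the mixin.
        return super().__new__(cls)


class MyModel(SupportsQuantMixin):
    pass


model = MyModel()  # type checkers infer MyModel, not SupportsQuantMixin
```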
136 changes: 64 additions & 72 deletions vllm/model_executor/models/llava.py
@@ -3,8 +3,8 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, TypeVar, Union, cast)
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union, cast)

import torch
import torch.nn as nn
@@ -39,8 +39,7 @@

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel,
get_pixtral_hf_image_feature_grid_size)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -49,31 +48,42 @@

class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
pixel_values: torch.Tensor
"""
Shape: `(batch_size * num_images, num_channels, height, width)`

Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""

feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]

class PixtralHFImagePixelInputs(TypedDict):
type: Literal["pixel_values_pixtral"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`

Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""

feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.

Shape: `(batch_size, num_crops, num_patch)`
"""

embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.

Shape: `(batch_size, num_embeds)`
"""

num_crops: torch.Tensor
num_crops: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""


@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""

feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.

Shape: `(batch_size, num_crops, num_patch)`
"""

embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.

Shape: `(batch_size, num_embeds)`
"""

num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`"""


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
LlavaImageEmbeddingInputs]

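The `type` field of each TypedDict is a `Literal` discriminator, so both runtime checks and type checkers can tell the input variants apart with a plain equality test. A simplified sketch with stand-in names:

```python
# Discriminated-union dispatch on the "type" field, simplified.
from typing import Literal, TypedDict, Union

import torch


class PixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor


class EmbedInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor


ImageInputs = Union[PixelInputs, EmbedInputs]


def process(image_input: ImageInputs) -> torch.Tensor:
    if image_input["type"] == "image_embeds":
        return image_input["data"]  # narrowed to EmbedInputs
    return image_input["pixel_values"]  # narrowed to PixelInputs
```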

class LlavaMultiModalProjector(nn.Module):
@@ -357,13 +349,15 @@ def _call_hf_processor(
]

hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)

tile_sizes = [
get_pixtral_hf_image_feature_grid_size(
hf_config.vision_config,
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2])
for pixel_value in processed_outputs["pixel_values"]
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_crops = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
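The `(ncols + 1) * nrows` count appears to follow the Pixtral-HF token layout, where each row of `ncols` patch tokens is followed by one break/end token. A quick worked example with an assumed 16-pixel patch size:

```python
# Worked example of the (ncols + 1) * nrows count; the patch size here is
# an assumption for illustration, not taken from the model config.
patch_size = 16
image_width, image_height = 1024, 768
ncols, nrows = image_width // patch_size, image_height // patch_size  # 64, 48
num_crops = (ncols + 1) * nrows  # 65 * 48 = 3120 tokens for this image
print(num_crops)
```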
@@ -411,13 +405,13 @@ def _get_prompt_updates(

vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)

def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)

ncols, nrows = get_pixtral_hf_image_feature_grid_size(
vision_config,
ncols, nrows = encoder_info.get_patch_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
*,
require_post_norm: Optional[bool] = None,
prefix: str = "",
):
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
vision_config = hf_config.vision_config

# Initialize the vision tower only up to the deepest required feature layer
@@ -627,57 +621,52 @@ def _parse_and_validate_image_input(
if pixel_values is None and image_embeds is None:
return None

feat_is_patch = kwargs.pop("feat_is_patch", None)
if feat_is_patch is not None and not isinstance(
feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")

embed_is_patch = kwargs.pop("embed_is_patch", None)
if embed_is_patch is not None and not isinstance(
embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")

num_crops = kwargs.pop("num_crops", None)
if num_crops is not None and not isinstance(num_crops, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")

if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if self.config.vision_config.model_type == "pixtral":
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
feat_is_patch = kwargs.pop("feat_is_patch")
if not isinstance(feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")

embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")

num_crops = kwargs.pop("num_crops")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")

return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

if self.config.vision_config.model_type == "pixtral":
raise ValueError("Pixtral-HF does not support image_embeds.")

return LlavaImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

raise AssertionError("This line should be unreachable.")
@@ -696,7 +685,7 @@ def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: torch.Tensor,
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:

# NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@ def _image_pixels_to_features(
strategy=self.config.vision_feature_select_strategy,
)

def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
def _process_image_pixels(
self,
inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
) -> torch.Tensor:
assert self.vision_tower is not None

pixel_values = inputs["data"]
pixel_values = inputs["pixel_values"]

return self._image_pixels_to_features(self.vision_tower, pixel_values)

def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:

def _process_image_input(
self,
image_input: LlavaImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]

@@ -783,11 +775,11 @@ def get_multimodal_embeddings(
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

vision_embeddings = self._process_image_input(image_input)

if kwargs.get("v0_path", False) or \
image_input.get("feat_is_patch") is None or \
image_input.get("embed_is_patch") is None:
if (kwargs.get("v0_path", False)
or image_input["type"] != "pixel_values_pixtral"):
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings

7 changes: 4 additions & 3 deletions vllm/model_executor/models/llava_next.py
@@ -32,7 +32,7 @@

class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
@@ -315,7 +315,8 @@ def _parse_and_validate_image_input(

return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(flatten_bn(pixel_values)),
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
@@ -434,7 +435,7 @@ def _process_image_pixels(
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None

pixel_values = inputs["data"]
pixel_values = inputs["pixel_values"]

if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape