15 changes: 10 additions & 5 deletions vllm/model_executor/models/paligemma.py
@@ -24,9 +24,10 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info

logger = init_logger(__name__)

@@ -67,6 +68,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(PaliGemmaConfig)

def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

@@ -78,9 +82,8 @@ def get_mm_max_tokens_per_item(
return {"image": self.get_num_image_tokens()}

def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
return get_max_siglip_image_tokens(vision_config)
vision_encoder_info = self.get_vision_encoder_info()
return vision_encoder_info.get_max_image_tokens()


class PaliGemmaDummyInputsBuilder(
@@ -173,8 +176,10 @@ def apply(
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
prompt_token_ids = mm_inputs["prompt_token_ids"]

tokenizer = self.info.get_tokenizer()
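The paligemma.py change above swaps the SigLIP-specific helper (`get_max_siglip_image_tokens`) for the model-agnostic `get_vision_encoder_info` lookup, so the processing info only talks to a generic `VisionEncoderInfo` interface. Below is a minimal, self-contained sketch of that pattern, assuming a dispatch on the vision config type; the stub config class and the registry body are illustrative only and do not reproduce the actual vLLM implementation (which passes the full HF model config and reads its `vision_config`).

```python
from dataclasses import dataclass


@dataclass
class SiglipVisionConfigStub:
    """Hypothetical stand-in for transformers.SiglipVisionConfig."""
    image_size: int = 224
    patch_size: int = 14


class VisionEncoderInfo:
    """Generic interface the model code programs against."""

    def __init__(self, vision_config) -> None:
        self.vision_config = vision_config

    def get_max_image_tokens(self) -> int:
        raise NotImplementedError


class SiglipEncoderInfo(VisionEncoderInfo):

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        # Position embeddings are interpolated, so image_size need not be an
        # exact multiple of patch_size; floor division matches the diff.
        return self.get_image_size() // self.get_patch_size()

    def get_max_image_tokens(self) -> int:
        # One token per patch: a square grid of patches.
        return self.get_patch_grid_length() ** 2


def get_vision_encoder_info(vision_config) -> VisionEncoderInfo:
    # Illustrative dispatch; the real helper inspects the HF config type.
    if isinstance(vision_config, SiglipVisionConfigStub):
        return SiglipEncoderInfo(vision_config)
    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")


info = get_vision_encoder_info(SiglipVisionConfigStub())
print(info.get_max_image_tokens())  # (224 // 14) ** 2 == 256
```

With an interface like this in place, `PaliGemmaProcessingInfo.get_num_image_tokens` reduces to `vision_encoder_info.get_max_image_tokens()`, exactly as in the diff.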
75 changes: 4 additions & 71 deletions vllm/model_executor/models/siglip.py
@@ -6,7 +6,6 @@
from typing import Iterable, Optional, Set, Tuple, Union

import torch
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig

@@ -20,74 +19,10 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
# assert image_size % patch_size == 0
return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_siglip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
return get_siglip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
return get_siglip_image_feature_size(hf_config)


def dummy_seq_data_for_siglip(
hf_config: SiglipVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image",
):
if image_feature_size_override is None:
image_feature_size = get_siglip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}


def dummy_image_for_siglip(
hf_config: SiglipVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override

image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}


class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):

def get_num_image_tokens(
@@ -96,10 +31,10 @@ def get_num_image_tokens(
image_width: int,
image_height: int,
) -> int:
return get_siglip_image_feature_size(self.vision_config)
return self.get_patch_grid_length()**2

def get_max_image_tokens(self) -> int:
return get_max_siglip_image_tokens(self.vision_config)
return self.get_patch_grid_length()**2

def get_image_size(self) -> int:
return self.vision_config.image_size
@@ -108,10 +43,8 @@ def get_patch_size(self) -> int:
return self.vision_config.patch_size

def get_patch_grid_length(self) -> int:
return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
image_size, patch_size = self.get_image_size(), self.get_patch_size()
return image_size // patch_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
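For a concrete feel of the arithmetic that `SiglipEncoderInfo.get_patch_grid_length` and `get_max_image_tokens` now perform inline, here is a quick worked example. The 224/448/896 resolutions with 14-pixel patches are the published PaliGemma variants, assumed here purely for illustration; they are not part of this diff.

```python
# grid length per side = image_size // patch_size; tokens = grid ** 2
for image_size, patch_size in [(224, 14), (448, 14), (896, 14)]:
    grid = image_size // patch_size
    print(f"{image_size}px -> {grid}x{grid} patches -> {grid ** 2} image tokens")
# 224px -> 16x16 patches -> 256 image tokens
# 448px -> 32x32 patches -> 1024 image tokens
# 896px -> 64x64 patches -> 4096 image tokens
```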