5 changes: 3 additions & 2 deletions docs/source/contributing/model/multimodal.md
@@ -34,7 +34,8 @@ Further update the model as follows:
image_features = self.vision_encoder(image_input)
return self.multi_modal_projector(image_features)

-def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]:
+def get_multimodal_embeddings(
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs)
@@ -61,7 +62,7 @@ Further update the model as follows:
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:

# `get_input_embeddings` should already be implemented for the language
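For orientation, here is a minimal, self-contained sketch of the two hooks as they read after this change. It is not code from the PR: the `MultiModalEmbeddings` alias definition, the `ToyMultiModalModel` class, and the `IMAGE_TOKEN_ID` constant are illustrative assumptions; the diff itself only shows that the alias comes from `.interfaces` and replaces the previous `NestedTensors` / `Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]` annotations.

```python
from typing import Optional, Union

import torch
import torch.nn as nn

# Assumed shape of the alias; the diff only shows that MultiModalEmbeddings is
# imported from .interfaces and replaces the old Union[...] / NestedTensors hints.
MultiModalEmbeddings = Union[list[torch.Tensor], torch.Tensor,
                             tuple[torch.Tensor, ...]]

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder token id, for illustration only


class ToyMultiModalModel(nn.Module):

    def __init__(self, vocab_size: int = 32064, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.vision_encoder = nn.Linear(8, hidden_size)  # stand-in vision tower

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        # Validate the multimodal keyword arguments; return None when no image
        # inputs were passed, mirroring the pattern repeated across this diff.
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return None
        return self.vision_encoder(pixel_values)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            # Scatter the image features into the placeholder-token positions,
            # which is what merge_multimodal_embeddings does in the real models.
            mask = input_ids == IMAGE_TOKEN_ID
            inputs_embeds[mask] = multimodal_embeddings.to(inputs_embeds.dtype)
        return inputs_embeds


if __name__ == "__main__":
    model = ToyMultiModalModel()
    ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
    feats = model.get_multimodal_embeddings(pixel_values=torch.randn(2, 8))
    print(model.get_input_embeddings(ids, feats).shape)  # torch.Size([4, 16])
```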
10 changes: 4 additions & 6 deletions vllm/model_executor/models/aria.py
@@ -21,8 +21,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
@@ -35,7 +34,7 @@
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
-from .interfaces import SupportsMultiModal, SupportsQuant
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix,
@@ -607,8 +606,7 @@ def _process_image_input(
return self.multi_modal_projector(image_outputs, image_attn_mask)

def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -618,7 +616,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
10 changes: 4 additions & 6 deletions vllm/model_executor/models/blip2.py
@@ -15,8 +15,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
@@ -25,7 +24,7 @@
from vllm.sequence import IntermediateTensors

from .blip import BlipVisionModel
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

@@ -629,8 +628,7 @@ def _process_image_input(self,
return self.language_projection(query_output)

def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -640,7 +638,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
10 changes: 4 additions & 6 deletions vllm/model_executor/models/chameleon.py
@@ -30,16 +30,15 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
@@ -986,8 +985,7 @@ def _parse_and_validate_image_input(
)

def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -1000,7 +998,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:

inputs_embeds = self.model.get_input_embeddings(input_ids)
7 changes: 3 additions & 4 deletions vllm/model_executor/models/deepseek_vl2.py
@@ -36,7 +36,7 @@
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -605,8 +605,7 @@ def _process_image_input(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)

def get_multimodal_embeddings(
-self, **kwargs: object
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -616,7 +615,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
10 changes: 5 additions & 5 deletions vllm/model_executor/models/florence2.py
@@ -20,7 +20,7 @@
BartParallelLMHead,
BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
@@ -30,7 +30,8 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsMultiModal, SupportsV0Only
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
+SupportsV0Only)
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings


@@ -1037,8 +1038,7 @@ def _process_image_input(
return self._encode_image(pixel_values)

def get_multimodal_embeddings(
-self, **kwargs: object
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -1048,7 +1048,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
9 changes: 4 additions & 5 deletions vllm/model_executor/models/fuyu.py
@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
-from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import List, Literal, Optional, Set, Tuple, TypedDict

import torch
import torch.nn as nn
@@ -41,7 +41,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)

@@ -327,8 +327,7 @@ def _process_image_input(
return vision_embeddings_flat.split(patches_per_image, dim=0)

def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -338,7 +337,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
10 changes: 5 additions & 5 deletions vllm/model_executor/models/gemma3_mm.py
@@ -14,8 +14,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -24,7 +23,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -481,7 +480,8 @@ def _process_image_input(
)
return self.multi_modal_projector(vision_outputs)

-def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+def get_multimodal_embeddings(
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -491,7 +491,7 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
10 changes: 5 additions & 5 deletions vllm/model_executor/models/glm4v.py
@@ -28,7 +28,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
@@ -39,7 +39,8 @@
from vllm.transformers_utils.configs import ChatGLMConfig

from .chatglm import ChatGLMBaseModel, ChatGLMModel
-from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+SupportsMultiModal, SupportsPP)
from .utils import flatten_bn, merge_multimodal_embeddings


@@ -596,8 +597,7 @@ def _process_image_input(
return self.transformer.vision(pixel_values)

def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -608,7 +608,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)

7 changes: 3 additions & 4 deletions vllm/model_executor/models/idefics3.py
@@ -49,7 +49,7 @@
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
-from .interfaces import SupportsLoRA, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
@@ -617,8 +617,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.sampler = get_sampler()

def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self.model._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
@@ -628,7 +627,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[NestedTensors] = None,
+multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: