
Commit 4c47a67

DarkLight1337 authored and richardsliu committed
[Misc] Clean up type annotation for SupportsMultiModal (vllm-project#14794)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Richard Liu <[email protected]>
1 parent a4f2d9e commit 4c47a67

27 files changed: +121 −141 lines
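The recurring change in this commit replaces per-model return annotations such as Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]] and Optional[NestedTensors] with a single MultiModalEmbeddings alias imported from .interfaces. As a rough sketch of what that alias amounts to, assuming it simply names the union it replaces (the actual definition lives in vllm/model_executor/models/interfaces.py and is not shown in the hunks below):

# Hypothetical sketch of the alias this commit standardizes on; the real
# definition in vllm/model_executor/models/interfaces.py is not part of
# the diff shown here.
from typing import Union

import torch

# One name for "the embeddings of all multimodal items in a request":
# a single stacked tensor, or a list/tuple with one tensor per item.
MultiModalEmbeddings = Union[list[torch.Tensor], torch.Tensor,
                             tuple[torch.Tensor, ...]]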

docs/source/contributing/model/multimodal.md
3 additions & 2 deletions

@@ -34,7 +34,8 @@ Further update the model as follows:
         image_features = self.vision_encoder(image_input)
         return self.multi_modal_projector(image_features)
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
 
         # Validate the multimodal input keyword arguments
         image_input = self._parse_and_validate_image_input(**kwargs)

@@ -61,7 +62,7 @@ Further update the model as follows:
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
 
         # `get_input_embeddings` should already be implemented for the language
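To make the documented pattern concrete, here is a minimal, hypothetical model skeleton using the updated signatures. The class name and the attributes it references (vision_encoder, multi_modal_projector, language_model, config.image_token_index) are illustrative assumptions, not part of this commit; the merge_multimodal_embeddings call mirrors the helper the models below already import from .utils.

# Illustrative skeleton only: names are made up to show the updated
# SupportsMultiModal signatures; they are not taken from this commit.
from typing import Optional

import torch

from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
                                                   SupportsMultiModal)
from vllm.model_executor.models.utils import merge_multimodal_embeddings


class MyVisionLanguageModel(torch.nn.Module, SupportsMultiModal):

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        # Validate the multimodal input keyword arguments.
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run the vision tower and project into the LM embedding space.
        image_features = self.vision_encoder(image_input)
        return self.multi_modal_projector(image_features)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # Splice the image embeddings into the placeholder-token
            # positions (the token index attribute is illustrative).
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds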

tests/distributed/test_pipeline_parallel.py
2 additions & 2 deletions

@@ -214,7 +214,7 @@ def iter_params(self, model_id: str):
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
     "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
     "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
-    "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(),
+    "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
     "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),

@@ -237,7 +237,7 @@ def iter_params(self, model_id: str):
     "BAAI/bge-multilingual-gemma2",
     # [MULTIMODAL GENERATION]
     "OpenGVLab/InternVL2-1B",
-    "microsoft/Phi-3-vision-128k-instruct",
+    "microsoft/Phi-3.5-vision-instruct",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     # [LANGUAGE GENERATION - HYBRID ARCH]
     "ai21labs/Jamba-tiny-dev",

vllm/model_executor/models/aria.py
4 additions & 6 deletions

@@ -21,8 +21,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,

@@ -35,7 +34,7 @@
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal, SupportsQuant
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     is_pp_missing_parameter, maybe_prefix,

@@ -607,8 +606,7 @@ def _process_image_input(
         return self.multi_modal_projector(image_outputs, image_attn_mask)
 
     def get_multimodal_embeddings(
-        self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -618,7 +616,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:

vllm/model_executor/models/blip2.py
4 additions & 6 deletions

@@ -15,8 +15,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptIndexTargets,

@@ -25,7 +24,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
@@ -629,8 +628,7 @@ def _process_image_input(self,
         return self.language_projection(query_output)
 
     def get_multimodal_embeddings(
-        self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -640,7 +638,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:

vllm/model_executor/models/chameleon.py
4 additions & 6 deletions

@@ -30,16 +30,15 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)

@@ -986,8 +985,7 @@ def _parse_and_validate_image_input(
         )
 
     def get_multimodal_embeddings(
-        self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -1000,7 +998,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
 
         inputs_embeds = self.model.get_input_embeddings(input_ids)

vllm/model_executor/models/deepseek_vl2.py
3 additions & 4 deletions

@@ -36,7 +36,7 @@
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from vllm.utils import is_list_of
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)

@@ -605,8 +605,7 @@ def _process_image_input(
         pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
 
     def get_multimodal_embeddings(
-        self, **kwargs: object
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -616,7 +615,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:

vllm/model_executor/models/florence2.py
5 additions & 5 deletions

@@ -20,7 +20,7 @@
                                        BartParallelLMHead,
                                        BartScaledWordEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
 from vllm.multimodal.processing import (BaseProcessingInfo,

@@ -30,7 +30,8 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsMultiModal, SupportsV0Only
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
+                         SupportsV0Only)
 from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
 
 
@@ -1037,8 +1038,7 @@ def _process_image_input(
         return self._encode_image(pixel_values)
 
     def get_multimodal_embeddings(
-        self, **kwargs: object
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -1048,7 +1048,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:

vllm/model_executor/models/fuyu.py
4 additions & 5 deletions

@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import List, Literal, Optional, Set, Tuple, TypedDict
 
 import torch
 import torch.nn as nn

@@ -41,7 +41,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 
@@ -327,8 +327,7 @@ def _process_image_input(
         return vision_embeddings_flat.split(patches_per_image, dim=0)
 
     def get_multimodal_embeddings(
-        self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -338,7 +337,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:

vllm/model_executor/models/gemma3_mm.py
5 additions & 5 deletions

@@ -14,8 +14,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,

@@ -24,7 +23,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)

@@ -481,7 +480,8 @@ def _process_image_input(
         )
         return self.multi_modal_projector(vision_outputs)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -491,7 +491,7 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         if multimodal_embeddings is None:
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)

vllm/model_executor/models/glm4v.py
5 additions & 5 deletions

@@ -28,7 +28,7 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BatchFeature,

@@ -39,7 +39,8 @@
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
-from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsPP)
 from .utils import flatten_bn, merge_multimodal_embeddings
 
 
@@ -596,8 +597,7 @@ def _process_image_input(
         return self.transformer.vision(pixel_values)
 
     def get_multimodal_embeddings(
-        self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

@@ -608,7 +608,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.transformer.get_input_embeddings(input_ids)
 
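Each model file above makes the same mechanical swap, so the caller-side contract is unchanged: get_multimodal_embeddings may return None or any MultiModalEmbeddings value, and get_input_embeddings accepts that value directly. A hedged sketch of how the two hooks compose (function and variable names are illustrative, not vLLM's actual runner code):

# Caller-side sketch only; names are assumptions, not taken from this commit.
def compute_inputs_embeds(model, input_ids, mm_kwargs):
    # May be None when the request carries no multimodal data.
    mm_embeddings = model.get_multimodal_embeddings(**mm_kwargs)
    # Text embeddings, with multimodal embeddings merged in at the
    # placeholder-token positions when present.
    return model.get_input_embeddings(input_ids,
                                      multimodal_embeddings=mm_embeddings)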
