 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal, SupportsQuant
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     is_pp_missing_parameter, maybe_prefix,
@@ -607,8 +606,7 @@ def _process_image_input(
         return self.multi_modal_projector(image_outputs, image_attn_mask)
 
     def get_multimodal_embeddings(
-            self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -618,7 +616,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[NestedTensors] = None,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
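
For reference, a minimal caller-side sketch of the updated signatures; the `model`, `input_ids`, and `pixel_values` names here are hypothetical illustrations, not part of the diff:

    # Hedged sketch: consuming the Optional[MultiModalEmbeddings] return
    # introduced by this change. `model` is a hypothetical instance of a
    # model implementing the SupportsMultiModal interface from the diff.
    embeddings = model.get_multimodal_embeddings(pixel_values=pixel_values)
    if embeddings is None:
        # No image input in this batch; fall back to text-only embeddings,
        # mirroring the early `return None` path in the diff above.
        inputs_embeds = model.get_input_embeddings(input_ids)
    else:
        inputs_embeds = model.get_input_embeddings(
            input_ids, multimodal_embeddings=embeddings)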