@@ -3,8 +3,8 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union, cast)
+from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
+                    TypeVar, Union, cast)
 
 import torch
 import torch.nn as nn
@@ -39,8 +39,7 @@
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -49,31 +48,42 @@
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+
+class PixtralHFImagePixelInputs(TypedDict):
+    type: Literal["pixel_values_pixtral"]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
 
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_embeds)`
-    """
-
-    num_crops: torch.Tensor
-    """Shape: `(batch_size, num_images)`"""
-
 
-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
+                         LlavaImageEmbeddingInputs]
 
 
 class LlavaMultiModalProjector(nn.Module):
@@ -357,13 +349,15 @@ def _call_hf_processor(
         ]
 
         hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         tile_sizes = [
-            get_pixtral_hf_image_feature_grid_size(
-                hf_config.vision_config,
+            encoder_info.get_patch_grid_size(
                 image_width=pixel_value.shape[-1],
-                image_height=pixel_value.shape[-2])
-            for pixel_value in processed_outputs["pixel_values"]
+                image_height=pixel_value.shape[-2],
+            ) for pixel_value in processed_outputs["pixel_values"]
         ]
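+        # Each row of ncols patch tokens is followed by one break/end token
+        # in Pixtral's token layout, hence (ncols + 1) * nrows per image.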
         num_crops = torch.tensor([(ncols + 1) * nrows
                                   for ncols, nrows in tile_sizes])
@@ -411,13 +405,13 @@ def _get_prompt_updates(
 
         vision_config = hf_config.vision_config
         assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         def get_replacement(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
-                vision_config,
+            ncols, nrows = encoder_info.get_patch_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
     *,
     require_post_norm: Optional[bool] = None,
     prefix: str = "",
-):
+) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
     vision_config = hf_config.vision_config
 
     # Initialize the vision tower only up to the deepest required feature layer
@@ -627,57 +621,52 @@ def _parse_and_validate_image_input(
         if pixel_values is None and image_embeds is None:
             return None
 
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if feat_is_patch is not None and not isinstance(
-                feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if embed_is_patch is not None and not isinstance(
-                embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
-        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=flatten_bn(pixel_values),
+                feat_is_patch = kwargs.pop("feat_is_patch")
+                if not isinstance(feat_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of feat_is_patch. "
+                                     f"Got type: {type(feat_is_patch)}")
+
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                num_crops = kwargs.pop("num_crops")
+                if not isinstance(num_crops, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_crops. "
+                                     f"Got type: {type(num_crops)}")
+
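+                # Pixtral keeps one tensor per image (native aspect ratios
+                # differ), so flatten_bn is used without concatenation here.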
+                return PixtralHFImagePixelInputs(
+                    type="pixel_values_pixtral",
+                    pixel_values=flatten_bn(pixel_values),
                     feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
                     num_crops=num_crops,
                 )
 
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
+                pixel_values=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         if image_embeds is not None:
             if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                raise ValueError("Pixtral-HF does not support image_embeds.")
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         raise AssertionError("This line should be unreachable.")
@@ -696,7 +685,7 @@ def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
-        pixel_values: torch.Tensor,
+        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@
             strategy=self.config.vision_feature_select_strategy,
         )
 
-    def _process_image_pixels(self,
-                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+    def _process_image_pixels(
+        self,
+        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]
 
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
 
-    def _process_image_input(self,
-                             image_input: LlavaImageInputs) -> torch.Tensor:
-
+    def _process_image_input(
+        self,
+        image_input: LlavaImageInputs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
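+        # The tuple return covers Pixtral-HF, whose vision tower yields one
+        # feature tensor per image rather than a single batched tensor.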
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
@@ -783,11 +775,11 @@ def get_multimodal_embeddings(
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
+
         vision_embeddings = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False) or \
-            image_input.get("feat_is_patch") is None or \
-            image_input.get("embed_is_patch") is None:
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values_pixtral"):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
 