1313
1414from .audio import resample_audio
1515from .inputs import (AudioItem , HfAudioItem , HfImageItem , HfVideoItem ,
16- ImageItem , ModalityData , MultiModalDataDict ,
17- NestedTensors , VideoItem )
16+ ImageItem , ModalityData , MultiModalDataDict , VideoItem )
1817
1918_T = TypeVar ("_T" )
2019_I = TypeVar ("_I" )
2120
2221
2322class ModalityDataItems (ABC , Generic [_T , _I ]):
23+ """
24+ Represents data items for a modality in :class:`MultiModalDataItems`.
25+ """
2426
2527 def __init__ (self , data : _T , modality : str ) -> None :
2628 super ().__init__ ()
@@ -69,6 +71,7 @@ def get_passthrough_data(self) -> Mapping[str, object]:
6971
7072
7173class ProcessorBatchItems (ModalityDataItems [Sequence [_T ], _T ]):
74+ """Base class for data items that are arranged in a list."""
7275
7376 def get_count (self ) -> int :
7477 return len (self .data )
@@ -83,7 +86,12 @@ def get_passthrough_data(self) -> Mapping[str, object]:
8386 return {}
8487
8588
86- class EmbeddingItems (ModalityDataItems [NestedTensors , torch .Tensor ]):
89+ class EmbeddingItems (ModalityDataItems [Union [torch .Tensor , list [torch .Tensor ]],
90+ torch .Tensor ]):
91+ """
92+ Base class for data items that are expressed as a batched embedding tensor,
93+ or a list of embedding tensors (one per item).
94+ """
8795
8896 def get_count (self ) -> int :
8997 return len (self .data )
@@ -109,7 +117,7 @@ def __init__(self, data: Sequence[HfAudioItem]) -> None:
109117
class AudioEmbeddingItems(EmbeddingItems):
    """Audio items supplied as pre-computed embeddings.

    Accepts either a batched embedding tensor or a list of per-item
    embedding tensors and registers them under the ``"audio"`` modality.
    """

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "audio")
114122
115123
@@ -137,7 +145,7 @@ def get_image_size(self, item_idx: int) -> ImageSize:
137145
class ImageEmbeddingItems(EmbeddingItems):
    """Image items supplied as pre-computed embeddings.

    Accepts either a batched embedding tensor or a list of per-item
    embedding tensors and registers them under the ``"image"`` modality.
    """

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "image")
142150
143151
@@ -163,7 +171,7 @@ def get_frame_size(self, item_idx: int) -> ImageSize:
163171
class VideoEmbeddingItems(EmbeddingItems):
    """Video items supplied as pre-computed embeddings.

    Accepts either a batched embedding tensor or a list of per-item
    embedding tensors and registers them under the ``"video"`` modality.
    """

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "video")
168176
169177
@@ -172,8 +180,8 @@ def __init__(self, data: NestedTensors) -> None:
172180
173181class MultiModalDataItems (UserDict [str , ModalityDataItems [Any , Any ]]):
174182 """
175- As :class:` MultiModalDataDict`, but normalized such that each entry
176- corresponds to a list.
183+ As :data:`~vllm.multimodal.inputs. MultiModalDataDict`, but normalized
184+ such that each entry corresponds to a list.
177185 """
178186
179187 def get_count (self , modality : str , * , strict : bool = True ) -> int :
@@ -226,7 +234,8 @@ def get_items(
226234
227235class MultiModalDataParser :
228236 """
229- Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
237+ Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into
238+ :class:`MultiModalDataItems`.
230239
231240 Args:
232241 target_sr (float, optional): Enables automatic resampling of audio
@@ -238,7 +247,9 @@ def __init__(self, *, target_sr: Optional[float] = None) -> None:
238247
239248 self .target_sr = target_sr
240249
241- def _is_embeddings (self , data : object ) -> TypeGuard [NestedTensors ]:
250+ def _is_embeddings (
251+ self , data : object
252+ ) -> TypeGuard [Union [torch .Tensor , list [torch .Tensor ]]]:
242253 if isinstance (data , torch .Tensor ):
243254 return data .ndim == 3
244255 if is_list_of (data , torch .Tensor ):
0 commit comments