11from typing import (TYPE_CHECKING , Any , Dict , Generic , Iterable , List ,
2- Optional , Tuple , Union )
2+ Optional , Tuple , Union , cast )
33
44from typing_extensions import NotRequired , TypedDict , TypeVar
55
@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):
5151
5252SingletonPrompt = Union [str , TextPrompt , TokensPrompt ]
5353"""
54- Set of possible schemas for a single LLM input :
54+ Set of possible schemas for a single prompt :
5555
5656- A text prompt (:class:`str` or :class:`TextPrompt`)
5757- A tokenized prompt (:class:`TokensPrompt`)
@@ -120,13 +120,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
120120"""
121121
122122
123- class LLMInputs (TypedDict ):
124- """
125- The inputs in :class:`~vllm.LLMEngine` before they are
126- passed to the model executor.
127-
128- This specifies the data required for decoder-only models.
129- """
123+ class TokenInputs (TypedDict ):
124+ """Represents token-based inputs."""
130125 prompt_token_ids : List [int ]
131126 """The token IDs of the prompt."""
132127
@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
150145 """
151146
152147
153- class EncoderDecoderLLMInputs (LLMInputs ):
148+ def token_inputs (
149+ prompt_token_ids : List [int ],
150+ prompt : Optional [str ] = None ,
151+ multi_modal_data : Optional ["MultiModalDataDict" ] = None ,
152+ mm_processor_kwargs : Optional [Dict [str , Any ]] = None ,
153+ ) -> TokenInputs :
154+ """Construct :class:`TokenInputs` from optional values."""
155+ inputs = TokenInputs (prompt_token_ids = prompt_token_ids )
156+
157+ if prompt is not None :
158+ inputs ["prompt" ] = prompt
159+ if multi_modal_data is not None :
160+ inputs ["multi_modal_data" ] = multi_modal_data
161+ if mm_processor_kwargs is not None :
162+ inputs ["mm_processor_kwargs" ] = mm_processor_kwargs
163+
164+ return inputs
165+
166+
167+ SingletonInputs = TokenInputs
168+ """
169+ A processed :class:`SingletonPrompt` which can be passed to
170+ :class:`vllm.sequence.Sequence`.
171+ """
172+
173+ DecoderOnlyInputs = TokenInputs
174+ """
175+ The inputs in :class:`~vllm.LLMEngine` before they are
176+ passed to the model executor.
177+ This specifies the data required for decoder-only models.
178+ """
179+
180+
181+ class EncoderDecoderInputs (TokenInputs ):
154182 """
155183 The inputs in :class:`~vllm.LLMEngine` before they are
156184 passed to the model executor.
@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
204232 be zipped with the encoder/decoder prompts.
205233 """
206234 if mm_processor_kwargs is None :
207- mm_processor_kwargs = {}
208- if isinstance (mm_processor_kwargs , Dict ):
235+ mm_processor_kwargs = cast ( Dict [ str , Any ], {})
236+ if isinstance (mm_processor_kwargs , dict ):
209237 return [
210- build_explicit_enc_dec_prompt (encoder_prompt , decoder_prompt ,
211- mm_processor_kwargs )
238+ build_explicit_enc_dec_prompt (
239+ encoder_prompt , decoder_prompt ,
240+ cast (Dict [str , Any ], mm_processor_kwargs ))
212241 for (encoder_prompt ,
213242 decoder_prompt ) in zip (enc_prompts , dec_prompts )
214243 ]
@@ -229,14 +258,31 @@ def to_enc_dec_tuple_list(
229258
230259
231260def __getattr__ (name : str ):
232- if name == "PromptInput" :
233- import warnings
261+ import warnings
234262
263+ if name == "PromptInput" :
235264 msg = ("PromptInput has been renamed to PromptType. "
236265 "The original name will be removed in an upcoming version." )
237266
238267 warnings .warn (DeprecationWarning (msg ), stacklevel = 2 )
239268
240269 return PromptType
241270
271+ if name == "LLMInputs" :
272+ msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
273+ "The original name will be removed in an upcoming version." )
274+
275+ warnings .warn (DeprecationWarning (msg ), stacklevel = 2 )
276+
277+ return DecoderOnlyInputs
278+
279+ if name == "EncoderDecoderLLMInputs" :
280+ msg = (
281+ "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
282+ "The original name will be removed in an upcoming version." )
283+
284+ warnings .warn (DeprecationWarning (msg ), stacklevel = 2 )
285+
286+ return EncoderDecoderInputs
287+
242288 raise AttributeError (f"module { __name__ !r} has no attribute { name !r} " )
0 commit comments