55from array import array
66from collections import defaultdict
77from dataclasses import dataclass
8+ from functools import cached_property , reduce
89from typing import TYPE_CHECKING , Any , Callable , Dict , List , Mapping , Optional
910from typing import Sequence as GenericSequence
1011from typing import Set , Tuple , Union , cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
169170 # It is used to compute mrope_position_ids.
170171 _mrope_position_delta : Optional [int ] = None
171172
173+ @staticmethod
174+ def from_counts (counts_by_token : Mapping [int , int ]) -> "SequenceData" :
175+ if len (counts_by_token ) == 0 :
176+ return SequenceData .from_seqs ([])
177+
178+ arrs = [
179+ array (VLLM_TOKEN_ID_ARRAY_TYPE , [token_id ]) * count
180+ for token_id , count in counts_by_token .items ()
181+ ]
182+
183+ return SequenceData (reduce (array .__add__ , arrs ))
184+
185+ @staticmethod
186+ def from_seqs (
187+ prompt_token_ids : GenericSequence [int ],
188+ output_token_ids : Optional [GenericSequence [int ]] = None ,
189+ ) -> "SequenceData" :
190+ prompt_token_ids_arr = array (VLLM_TOKEN_ID_ARRAY_TYPE ,
191+ prompt_token_ids )
192+
193+ if output_token_ids is None :
194+ return SequenceData (prompt_token_ids_arr )
195+
196+ output_token_ids_arr = array (VLLM_TOKEN_ID_ARRAY_TYPE ,
197+ output_token_ids )
198+
199+ return SequenceData (prompt_token_ids_arr ,
200+ _output_token_ids = output_token_ids_arr )
201+
172202 def __post_init__ (self ) -> None :
173203 assert self ._prompt_token_ids .typecode == "l"
174204 assert self ._output_token_ids .typecode == "l"
@@ -370,8 +400,6 @@ def __init__(
370400 self .lora_request = lora_request
371401 self .prompt_adapter_request = prompt_adapter_request
372402 self .from_decoder_prompt = from_decoder_prompt
373- self ._prompt : Optional [str ] = None
374- self ._prompt_token_ids : Optional [List [int ]] = None
375403
376404 # For decoder-only models, a Sequence is constructed
377405 # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ def __init__(
400428 f"invalid input { inputs } ; did you forget the "
401429 "encoder input prompt fields?" )
402430
403- self .data = SequenceData (
404- array (VLLM_TOKEN_ID_ARRAY_TYPE , self .prompt_token_ids ))
431+ self .data = SequenceData .from_seqs (self .prompt_token_ids )
405432 self .output_logprobs : SampleLogprobs = []
406433 self .output_text = ""
407434
@@ -422,37 +449,23 @@ def __init__(
422449 def n_blocks (self ) -> int :
423450 return (self .get_len () + self .block_size - 1 ) // self .block_size
424451
425- @property
452+ @cached_property
426453 def prompt (self ) -> Optional [str ]:
427- if self ._prompt is not None :
428- # Reuse precomputed prompt string
429- return self ._prompt
430-
431- # Select decoder or encoder input prompt str,
432- # as appropriate
454+ # Select decoder or encoder input prompt str, as appropriate
433455 prompt_key : str = ("prompt"
434456 if self .from_decoder_prompt else "encoder_prompt" )
435457
436- # Cache prompt
437- self ._prompt = cast (Optional [str ], self .inputs .get (prompt_key ))
438- return self ._prompt
458+ return cast (Optional [str ], self .inputs .get (prompt_key ))
439459
440- @property
460+ @cached_property
441461 def prompt_token_ids (self ) -> List [int ]:
442- if self ._prompt_token_ids is not None :
443- # Reuse precomputed prompt token ids
444- return self ._prompt_token_ids
445-
446- # Select decoder or encoder input prompt
447- # token ids, as appropriate
462+ # Select decoder or encoder input prompt token ids, as appropriate
448463 prompt_token_ids_key : str = ("prompt_token_ids"
449464 if self .from_decoder_prompt else
450465 "encoder_prompt_token_ids" )
451466
452467 # Cache computed prompt token ids
453- self ._prompt_token_ids = cast (List [int ],
454- self .inputs .get (prompt_token_ids_key ))
455- return self ._prompt_token_ids
468+ return cast (List [int ], self .inputs .get (prompt_token_ids_key ))
456469
457470 @property
458471 def multi_modal_data (self ) -> "MultiModalDataDict" :
0 commit comments