@@ -20,6 +20,8 @@
                           BatchFeature)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from tests.models.utils import (TokensTextLogprobs,
+                                TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
@@ -33,7 +35,6 @@
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
 
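Note: TokensTextLogprobs and TokensTextLogprobsPromptLogprobs come from tests/models/utils.py. Judging from the inline annotations they replace in the hunks below, they are presumably type aliases along these lines (a sketch of the assumed definitions, not the verbatim source):

    # Assumed shape of the aliases in tests/models/utils.py.
    # TokensTextLogprobs mirrors the old inline annotation
    # List[Tuple[List[int], str, Optional[SampleLogprobs]]]; the
    # PromptLogprobs variant appends one slot for prompt logprobs.
    from typing import List, Optional, Tuple

    from vllm.sequence import PromptLogprobs, SampleLogprobs

    TokensTextLogprobs = Tuple[List[int], str, Optional[SampleLogprobs]]
    TokensTextLogprobsPromptLogprobs = Tuple[List[int], str,
                                             Optional[SampleLogprobs],
                                             Optional[PromptLogprobs]]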
@@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit(
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[List[np.ndarray]] = None,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         all_logprobs: List[List[Dict[int, float]]] = []
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []
@@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
@@ -653,14 +654,16 @@ def generate(
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: List[RequestOutput],
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+    ) -> List[TokensTextLogprobsPromptLogprobs]:
+        outputs: List[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
+            assert len(req_output.outputs) > 0
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
-                outputs.append((output_ids, output_str, output_logprobs))
+                outputs.append((output_ids, output_str, output_logprobs,
+                                req_output.prompt_logprobs))
         return outputs
 
     def generate_w_logprobs(
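With this change, every element returned by _final_steps_generate_w_logprobs is a 4-tuple whose last field is req_output.prompt_logprobs. A minimal sketch of the resulting shape, with fabricated placeholder data:

    # Illustrative unpacking of the 4-tuple shape; the data here is
    # fabricated to show structure only.
    outputs = [([101, 102], "hello", None, None)]
    for output_ids, output_str, output_logprobs, prompt_logprobs in outputs:
        # The last field mirrors req_output.prompt_logprobs and stays
        # None unless prompt_logprobs was requested in SamplingParams.
        assert isinstance(output_ids, list) and isinstance(output_str, str)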
@@ -670,7 +673,8 @@ def generate_w_logprobs(
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         assert sampling_params.logprobs is not None
 
         if images is not None:
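Since generate_w_logprobs now returns either 3-tuples or 4-tuples depending on the sampling params, a caller can discriminate the two shapes by tuple length; a hypothetical helper, not part of this diff:

    # Hypothetical helper for telling the two result shapes apart;
    # not part of the diff.
    def has_prompt_logprobs(item: tuple) -> bool:
        # TokensTextLogprobs has 3 fields; the PromptLogprobs variant
        # carries a 4th field holding the prompt logprobs.
        return len(item) == 4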
@@ -695,21 +699,33 @@ def generate_w_logprobs(
 
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_encoder_decoder_w_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         '''
         Logprobs generation for vLLM encoder/decoder models
         '''
 
         assert sampling_params.logprobs is not None
         req_outputs = self.model.generate(encoder_decoder_prompts,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_greedy(
         self,
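The omission step in both methods above is just a slice that drops the final prompt-logprobs field from each 4-tuple; a standalone illustration with fabricated data:

    # Standalone illustration of the x[0:-1] trimming used above;
    # the tuples here are fabricated placeholders.
    quads = [([1, 2], "ab", None, None), ([3], "c", None, None)]
    trips = [x[0:-1] for x in quads]
    assert trips == [([1, 2], "ab", None), ([3], "c", None)]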
@@ -727,44 +743,48 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
+        num_prompt_logprobs: Optional[int] = None,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs,
-                                                stop_token_ids=stop_token_ids)
-        outputs = self.generate_w_logprobs(prompts,
-                                           greedy_logprobs_params,
-                                           images=images,
-                                           audios=audios,
-                                           videos=videos)
-
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=num_prompt_logprobs,
+            stop_token_ids=stop_token_ids)
+
+        return self.generate_w_logprobs(prompts,
+                                        greedy_logprobs_params,
+                                        images=images,
+                                        audios=audios,
+                                        videos=videos)
 
     def generate_encoder_decoder_greedy_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                use_beam_search=False,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs)
+        num_prompt_logprobs: Optional[int] = None,
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            use_beam_search=False,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=num_prompt_logprobs,
+        )
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
 
-        outputs = self.generate_encoder_decoder_w_logprobs(
+        return self.generate_encoder_decoder_w_logprobs(
             encoder_decoder_prompts, greedy_logprobs_params)
 
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
-
     def generate_beam_search(
         self,
         prompts: List[str],
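Taken together, a test can now opt into prompt logprobs through the new num_prompt_logprobs argument; a usage sketch, assuming the usual vllm_runner fixture and an arbitrary small model:

    # Usage sketch; the fixture and model name are assumptions, and
    # num_prompt_logprobs switches on the 4-tuple return shape.
    with vllm_runner("facebook/opt-125m") as vllm_model:
        results = vllm_model.generate_greedy_logprobs(
            ["Hello, my name is"],
            max_tokens=8,
            num_logprobs=5,
            num_prompt_logprobs=5,
        )
        for output_ids, output_str, logprobs, prompt_logprobs in results:
            assert prompt_logprobs is not None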