
Commit 9164f69

afeldman-nm authored and LeiWang1999 committed
[Core] *Prompt* logprobs support in Multi-step (vllm-project#8199)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 4afc6a5 commit 9164f69
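
This commit adds prompt-logprobs support to vLLM's multi-step path and updates the test utilities to carry prompt logprobs through. For orientation, here is a minimal sketch (not part of the diff) of how prompt logprobs are requested through the public API; the model name and prompt are illustrative placeholders:

# Minimal sketch of requesting prompt logprobs via SamplingParams.
# Model name and prompt are placeholders, not from this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    logprobs=5,         # top-5 logprobs for each sampled token
    prompt_logprobs=5,  # top-5 logprobs for each prompt token
)
for req_output in llm.generate(["Hello, my name is"], params):
    # One entry per prompt token; the first entry is None because the
    # first prompt token has no preceding context to condition on.
    print(req_output.prompt_logprobs)
    # One entry per generated token.
    print(req_output.outputs[0].logprobs)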

File tree

5 files changed (+300 -59 lines)

tests/conftest.py

Lines changed: 52 additions & 32 deletions
@@ -20,6 +20,8 @@
                              BatchFeature)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from tests.models.utils import (TokensTextLogprobs,
+                                TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
@@ -33,7 +35,6 @@
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
 
@@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit(
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[List[np.ndarray]] = None,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         all_logprobs: List[List[Dict[int, float]]] = []
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []
@@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
@@ -653,14 +654,16 @@ def generate(
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: List[RequestOutput],
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+    ) -> List[TokensTextLogprobsPromptLogprobs]:
+        outputs: List[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
+            assert len(req_output.outputs) > 0
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
-            outputs.append((output_ids, output_str, output_logprobs))
+            outputs.append((output_ids, output_str, output_logprobs,
+                            req_output.prompt_logprobs))
         return outputs
 
     def generate_w_logprobs(
@@ -670,7 +673,8 @@ def generate_w_logprobs(
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         assert sampling_params.logprobs is not None
 
         if images is not None:
@@ -695,21 +699,33 @@
 
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_encoder_decoder_w_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         '''
         Logprobs generation for vLLM encoder/decoder models
         '''
 
         assert sampling_params.logprobs is not None
         req_outputs = self.model.generate(encoder_decoder_prompts,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_greedy(
         self,
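
The `x[0:-1]` slice above is what downgrades each 4-tuple back to the legacy 3-tuple when prompt logprobs were not requested. A toy illustration with fabricated values:

# Toy illustration (not from the commit) of the slice used above
four_tuple = ([1, 2], "hi", [{1: -0.1}, {2: -0.2}], None)
three_tuple = four_tuple[0:-1]  # drop the trailing prompt-logprobs element
assert three_tuple == ([1, 2], "hi", [{1: -0.1}, {2: -0.2}])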
@@ -727,44 +743,48 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
+        num_prompt_logprobs: Optional[int] = None,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs,
-                                                stop_token_ids=stop_token_ids)
-        outputs = self.generate_w_logprobs(prompts,
-                                           greedy_logprobs_params,
-                                           images=images,
-                                           audios=audios,
-                                           videos=videos)
-
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=(num_prompt_logprobs),
+            stop_token_ids=stop_token_ids)
+
+        return self.generate_w_logprobs(prompts,
+                                        greedy_logprobs_params,
+                                        images=images,
+                                        audios=audios,
+                                        videos=videos)
 
     def generate_encoder_decoder_greedy_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                use_beam_search=False,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs)
+        num_prompt_logprobs: Optional[int] = None,
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            use_beam_search=False,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=(num_prompt_logprobs),
+        )
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
 
-        outputs = self.generate_encoder_decoder_w_logprobs(
+        return self.generate_encoder_decoder_w_logprobs(
             encoder_decoder_prompts, greedy_logprobs_params)
 
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
-
     def generate_beam_search(
         self,
         prompts: List[str],
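
Taken together, these conftest.py changes let a test opt into prompt logprobs with one extra argument. A hedged usage sketch, assuming these methods live on the VllmRunner helper defined in tests/conftest.py (model name and prompt are placeholders):

# Hypothetical test usage; VllmRunner context-manager support, the model
# name, and the prompt are assumptions, not shown in this diff.
with VllmRunner("facebook/opt-125m") as vllm_model:
    results = vllm_model.generate_greedy_logprobs(
        ["The capital of France is"],
        max_tokens=8,
        num_logprobs=5,
        num_prompt_logprobs=5,  # omit (None) to get legacy 3-tuples
    )
# With num_prompt_logprobs set, each result is a 4-tuple
token_ids, text, sample_logprobs, prompt_logprobs = results[0]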

tests/models/utils.py

Lines changed: 102 additions & 6 deletions
@@ -1,7 +1,7 @@
 import warnings
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-from vllm.sequence import Logprob, SampleLogprobs
+from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
 
 TokensText = Tuple[List[int], str]
 
@@ -34,20 +34,47 @@ def check_outputs_equal(
     assert output_ids_0 == output_ids_1, fail_msg
 
 
+# Representation of generated sequence as a tuple of
+# * Token ID list
+# * String
+# * List of top sample logprobs for each sampled token
+#
+# Assumes prompt logprobs were not requested.
 TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
                                                                     float]],
                                                           SampleLogprobs]]]
 
-# Allow for tokens to be represented as str's rather than IDs
+# Allow for tokens to be represented as str's rather than IDs;
+# tuple of
+# * Token string representations list
+# * String
+# * Optional list of top sample logprobs for each sampled token
+#
+# Assumes prompt logprobs were not requested.
 TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
                                                         List[Dict[str,
                                                                   Logprob]]]]]
 
+# Representation of generated sequence as a tuple of
+# * Token ID list
+# * String
+# * Optional list of top sample logprobs for each sampled token
+# * Optional list of top prompt logprobs for each prompt token
+#
+# Allows prompt logprobs to be requested.
+TokensTextLogprobsPromptLogprobs = Tuple[
+    List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
+    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
+
 
 def check_logprobs_close(
     *,
-    outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
-    outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
+    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs,
+                                  TextTextLogprobs]],
+    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs,
+                                  TextTextLogprobs]],
     name_0: str,
     name_1: str,
     num_outputs_0_skip_tokens: int = 0,
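
The new alias makes the prompt-logprobs element explicit. A small sketch of the two shapes a caller may now receive (values fabricated for illustration):

# 3-tuple (TokensTextLogprobs): prompt logprobs were not requested
out3 = ([101, 102], "hi", [{101: -0.3}, {102: -0.5}])
# 4-tuple (TokensTextLogprobsPromptLogprobs): prompt logprobs requested;
# the first prompt entry is None (no preceding context for token 0)
out4 = ([101, 102], "hi", [{101: -0.3}, {102: -0.5}],
        [None, {7: -1.2, 9: -2.0}])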
@@ -57,6 +84,18 @@ def check_logprobs_close(
     """Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
 
+    How sample logprobs are compared:
+    * `always_check_logprobs == True`: set of highest-logprob token ids
+      must match between seq0 and seq1 at all sampled token offsets
+    * `always_check_logprobs == False`: highest-logprob token ids are
+      only compared at sampled token offsets for which generated token
+      ids don't match
+
+    Prompt logprobs must be provided either for both input sequences, or
+    for neither. If prompt logprobs are provided, then highest-logprob
+    prompt token ids must match between seq0 and seq1 at all prompt token
+    offsets.
+
     Args:
       outputs_0_lst: First sequence to compare
       outputs_1_lst: Second sequence to compare
@@ -78,8 +117,65 @@ def check_logprobs_close(
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,
                                                  outputs_1_lst)):
-        output_ids_0, output_str_0, logprobs_0 = outputs_0
-        output_ids_1, output_str_1, logprobs_1 = outputs_1
+        assert len(outputs_0) == len(outputs_1)
+        if len(outputs_0) == 3:
+            assert len(outputs_1) == 3
+            # Break out tokens, text & sample logprobs
+            # (prompt logprobs were not provided)
+            output_ids_0, output_str_0, logprobs_0 = outputs_0
+            output_ids_1, output_str_1, logprobs_1 = outputs_1
+        elif len(outputs_0) == 4:
+            assert len(outputs_1) == 4
+            # Break out tokens, text, sample logprobs & prompt logprobs
+            (
+                output_ids_0,
+                output_str_0,
+                logprobs_0,
+                prompt_logprobs_0,
+            ) = outputs_0
+            (
+                output_ids_1,
+                output_str_1,
+                logprobs_1,
+                prompt_logprobs_1,
+            ) = outputs_1
+
+            # Test prompt logprobs closeness
+            if (prompt_logprobs_0 is not None
+                    and prompt_logprobs_1 is not None):
+                # Both sequences' prompt logprobs lists are not `None`
+                # (although individual list elements may be `None`);
+                # for each token's logprobs:
+                for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
+                        zip(prompt_logprobs_0, prompt_logprobs_1)):
+                    fail_msg = (
+                        f"Prompt logprobs test:"
+                        f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
+                        f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
+
+                    if logprobs_elem_0 is None:
+                        # If the seq 0 token's logprobs are `None`,
+                        # the seq 1 token's logprobs must be `None`
+                        assert logprobs_elem_1 is None, fail_msg
+                    else:
+                        # If the seq 0 token's logprobs are not `None`,
+                        # the seq 1 token's logprobs must not be `None`
+                        assert logprobs_elem_1 is not None, fail_msg
+                        # Logprobs check: top-k token choices must be the same
+                        assert (set(logprobs_elem_0.keys()) == set(
+                            logprobs_elem_1.keys())), fail_msg
+            else:
+                # Both sequence logprobs lists must be `None`
+                fail_msg = (f"Prompt logprobs test:"
+                            f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
+                            f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
+
+                assert (prompt_logprobs_0 is None
+                        and prompt_logprobs_1 is None), fail_msg
+        else:
+            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
+                             f"{len(outputs_0)} elements were provided: "
+                             f"{outputs_0}")
 
         if logprobs_0 is None:
             logprobs_0 = [None] * len(output_ids_0)
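
Note that the prompt-logprobs check above compares only the sets of top-k token ids, not the logprob values themselves. A toy example of a pair that passes:

# Fabricated values illustrating the key-set comparison used above
logprobs_elem_0 = {42: -0.10, 7: -2.50}  # token_id -> logprob, model 0
logprobs_elem_1 = {7: -2.31, 42: -0.12}  # token_id -> logprob, model 1
# Values differ slightly, but the candidate sets match: {7, 42}
assert set(logprobs_elem_0.keys()) == set(logprobs_elem_1.keys())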
