
Commit eb3e6a0

qthequartermasterman authored and DarkLight1337 committed
[CORE] Prompt Embeddings Support for v1 Engine (vllm-project#24278)
Signed-off-by: Andrew Sansom <[email protected]>
Signed-off-by: Andrew Sansom <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: charlifu <[email protected]>
1 parent 2f3e391 commit eb3e6a0

File tree

20 files changed: +304 -75 lines changed


tests/basic_correctness/test_basic_correctness.py

Lines changed: 0 additions & 10 deletions

@@ -76,11 +76,6 @@ def test_models(
     model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
-
-    if enable_prompt_embeds and envs.is_set(
-            "VLLM_USE_V1") and envs.VLLM_USE_V1:
-        pytest.skip("enable_prompt_embeds is not supported in v1.")
-
     if not envs.VLLM_USE_V1:
         if async_scheduling:
             pytest.skip("async_scheduling only supported in v1.")

@@ -164,11 +159,6 @@ def test_models_distributed(
     extra_env: dict[str, str],
     enable_prompt_embeds: bool,
 ) -> None:
-
-    if enable_prompt_embeds and envs.is_set(
-            "VLLM_USE_V1") and envs.VLLM_USE_V1:
-        pytest.skip("enable_prompt_embeds is not supported in v1.")
-
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")

tests/entrypoints/openai/test_completion_with_prompt_embeds.py

Lines changed: 0 additions & 1 deletion

@@ -36,7 +36,6 @@ def default_server_args() -> list[str]:
         "--enforce-eager",
         # Prompt Embeds server args
         "--enable-prompt-embeds",
-        "--no-enable-chunked-prefill",
     ]

tests/models/language/generation/test_common.py

Lines changed: 0 additions & 6 deletions

@@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

-    # Note: can be removed when
-    # https://github.com/vllm-project/vllm/pull/24278 finished
-    if current_platform.is_cpu() and use_prompt_embeds:
-        pytest.skip("Skipping use_prompt_embeds=True with "
-                    "V1-only CPU backend.")
-
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

vllm/engine/arg_utils.py

Lines changed: 18 additions & 6 deletions

@@ -1513,12 +1513,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False

-        # No text embedding inputs so far.
-        if self.enable_prompt_embeds:
-            _raise_or_fallback(feature_name="--enable-prompt-embeds",
-                               recommend_to_remove=False)
-            return False
-
         # No Mamba or Encoder-Decoder so far.
         if not model_config.is_v1_compatible:
             _raise_or_fallback(feature_name=model_config.architectures,

@@ -1651,6 +1645,13 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None:
                 "models in V0 and has been disabled.")
             self.enable_prefix_caching = False

+        if self.enable_prompt_embeds:
+            logger.warning(
+                "--enable-prompt-embeds and --enable-prefix-caching "
+                "are not supported together in V0. Prefix caching has "
+                "been disabled.")
+            self.enable_prefix_caching = False
+
         # Set max_num_seqs to 256 for VLLM_V0.
         if self.max_num_seqs is None:
             self.max_num_seqs = 256

@@ -1664,6 +1665,17 @@ def _set_default_args_v1(self, usage_context: UsageContext,
             # For pooling tasks the default is False
             if model_config.runner_type != "pooling":
                 self.enable_chunked_prefill = True
+
+        # TODO: When prefix caching supports prompt embeds inputs, this
+        # check can be removed.
+        if (self.enable_prompt_embeds
+                and self.enable_prefix_caching is not False):
+            logger.warning(
+                "--enable-prompt-embeds and --enable-prefix-caching "
+                "are not supported together in V1. Prefix caching has "
+                "been disabled.")
+            self.enable_prefix_caching = False
+
         if self.enable_prefix_caching is None:
             self.enable_prefix_caching = True
         else:
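
The V0 and V1 fallbacks above follow the same pattern: if prompt embeds are requested, prefix caching is switched off unless the user already disabled it explicitly. A standalone sketch of the V1 branch (the resolve_prefix_caching helper below is illustrative only, not part of vLLM):

import logging
from typing import Optional

logger = logging.getLogger("prefix_caching_sketch")


def resolve_prefix_caching(enable_prompt_embeds: bool,
                           enable_prefix_caching: Optional[bool]) -> bool:
    # Mirrors _set_default_args_v1: prompt embeds force prefix caching off,
    # unless it was already explicitly disabled by the user.
    if enable_prompt_embeds and enable_prefix_caching is not False:
        logger.warning(
            "--enable-prompt-embeds and --enable-prefix-caching "
            "are not supported together in V1. Prefix caching has "
            "been disabled.")
        enable_prefix_caching = False
    if enable_prefix_caching is None:
        enable_prefix_caching = True  # V1 default
    return enable_prefix_caching


assert resolve_prefix_caching(False, None) is True   # default stays on
assert resolve_prefix_caching(True, None) is False   # embeds disable it
assert resolve_prefix_caching(True, True) is False   # even if requested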

vllm/entrypoints/openai/protocol.py

Lines changed: 1 addition & 1 deletion

@@ -973,7 +973,6 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
-    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
     best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0

@@ -1009,6 +1008,7 @@
     # --8<-- [end:completion-sampling-params]

     # --8<-- [start:completion-extra-params]
+    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
     add_special_tokens: bool = Field(
         default=True,
         description=(
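
On the wire, prompt_embeds arrives as bytes (or a list of bytes) in the extra-params section of a completion request. The accompanying prompt-embeds tests exercise this with a base64 string wrapping a torch.save payload; the encoding below is a client-side sketch of that convention and should be treated as an assumption to verify against the server-side decoder:

import base64
import io

import torch


def encode_prompt_embeds(embeds: torch.Tensor) -> str:
    # Serialize a (num_tokens, hidden_size) tensor for the prompt_embeds
    # field of a /v1/completions request (assumed torch.save + base64 format).
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Example: 8 embedding rows with a hidden size of 4096, sent via the OpenAI
# client as extra_body={"prompt_embeds": payload}.
payload = encode_prompt_embeds(torch.randn(8, 4096))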

vllm/utils/__init__.py

Lines changed: 27 additions & 0 deletions

@@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
     pid = os.getpid()
     _add_prefix(sys.stdout, process_name, pid)
     _add_prefix(sys.stderr, process_name, pid)
+
+
+def length_from_prompt_token_ids_or_embeds(
+    prompt_token_ids: Optional[list[int]],
+    prompt_embeds: Optional[torch.Tensor],
+) -> int:
+    """Calculate the request length (in number of tokens) given either
+    prompt_token_ids or prompt_embeds.
+    """
+    prompt_token_len = None if prompt_token_ids is None else len(
+        prompt_token_ids)
+    prompt_embeds_len = \
+        None if prompt_embeds is None else len(prompt_embeds)
+
+    if prompt_token_len is None:
+        if prompt_embeds_len is None:
+            raise ValueError(
+                "Neither prompt_token_ids nor prompt_embeds were defined.")
+        return prompt_embeds_len
+    else:
+        if (prompt_embeds_len is not None
+                and prompt_embeds_len != prompt_token_len):
+            raise ValueError(
+                "Prompt token ids and prompt embeds had different lengths"
+                f" prompt_token_ids={prompt_token_len}"
+                f" prompt_embeds={prompt_embeds_len}")
+        return prompt_token_len
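
A quick usage sketch of the new helper (values are illustrative):

import torch

from vllm.utils import length_from_prompt_token_ids_or_embeds

# Token-id prompt: the length is the size of the id list.
assert length_from_prompt_token_ids_or_embeds([1, 2, 3], None) == 3

# Embeds-only prompt: the length is the number of rows in the tensor.
embeds = torch.zeros(5, 4096)
assert length_from_prompt_token_ids_or_embeds(None, embeds) == 5

# Mismatched inputs raise a ValueError.
try:
    length_from_prompt_token_ids_or_embeds([1, 2], embeds)
except ValueError:
    print("lengths must agree when both inputs are given")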

vllm/v1/core/sched/output.py

Lines changed: 18 additions & 6 deletions

@@ -11,6 +11,7 @@
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
+    import torch

     from vllm.distributed.kv_transfer.kv_connector.v1.base import (
         KVConnectorMetadata)

@@ -26,13 +27,14 @@
 class NewRequestData:

     req_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
+    prompt_embeds: Optional[torch.Tensor] = None

     @classmethod
     def from_request(

@@ -49,29 +51,39 @@ def from_request(
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
+            prompt_embeds=request.prompt_embeds,
         )

-    def __repr__(self):
+    def __repr__(self) -> str:
+        prompt_embeds_shape = (self.prompt_embeds.shape
+                               if self.prompt_embeds is not None else None)
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids={self.prompt_token_ids},"
                 f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
-                f"lora_request={self.lora_request}"
+                f"lora_request={self.lora_request},"
+                f"prompt_embeds_shape={prompt_embeds_shape}"
                 ")")

     # Version of __repr__ with the prompt data obfuscated
-    def anon_repr(self):
+    def anon_repr(self) -> str:
+        prompt_token_ids_len = len(
+            self.prompt_token_ids
+        ) if self.prompt_token_ids is not None else None
+        prompt_embeds_shape = (self.prompt_embeds.shape
+                               if self.prompt_embeds is not None else None)
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
-                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
+                f"prompt_token_ids_len={prompt_token_ids_len},"
                 f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
-                f"lora_request={self.lora_request}"
+                f"lora_request={self.lora_request},"
+                f"prompt_embeds_shape={prompt_embeds_shape}"
                 ")")

vllm/v1/engine/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -47,7 +47,7 @@ class EngineCoreRequest(
         gc=False):  # type: ignore[call-arg]

     request_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: Optional[list[MultiModalFeatureSpec]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]

@@ -56,6 +56,7 @@
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]
     data_parallel_rank: Optional[int]
+    prompt_embeds: Optional[torch.Tensor] = None

     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.

vllm/v1/engine/detokenizer.py

Lines changed: 21 additions & 10 deletions

@@ -13,6 +13,7 @@
 from vllm.logger import init_logger
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreRequest

 logger = init_logger(__name__)

@@ -179,11 +180,12 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast,
         self.tokenizer: Tokenizer = tokenizer._tokenizer

         # Find a safe place to start.
-        prompt_suffix = request.prompt_token_ids
+        prompt_token_ids = request.prompt_token_ids or []
+        prompt_suffix = prompt_token_ids
         prompt_len = len(prompt_suffix)
         if prompt_len > 4:
             for i in range(4, min(prompt_len + 1, 24)):
-                suffix = request.prompt_token_ids[-i:]
+                suffix = prompt_token_ids[-i:]
                 if '�' not in self.tokenizer.decode(suffix):
                     prompt_suffix = suffix
                     break

@@ -260,16 +262,25 @@ def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
         params = request.sampling_params
         assert params is not None

+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
+
         # Metadata for incremental detokenization.
-        self.tokens, self.prefix_offset, self.read_offset = (
-            convert_prompt_ids_to_tokens(
-                tokenizer=tokenizer,
-                prompt_ids=request.prompt_token_ids,
-                skip_special_tokens=params.skip_special_tokens,
-            ))
+        if request.prompt_token_ids is not None:
+            self.tokens, self.prefix_offset, self.read_offset = (
+                convert_prompt_ids_to_tokens(
+                    tokenizer=tokenizer,
+                    prompt_ids=request.prompt_token_ids,
+                    skip_special_tokens=params.skip_special_tokens,
+                ))
+        else:
+            # Prompt embedding requests cannot be detokenized, in general.
+            self.tokens = [""] * self.prompt_len
+            self.prefix_offset = 0
+            self.read_offset = 0

-        self.token_ids.extend(request.prompt_token_ids)
-        self.prompt_len = len(request.prompt_token_ids)
+        self.token_ids.extend(request.prompt_token_ids
+                              or [0] * self.prompt_len)

         self.skip_special_tokens = params.skip_special_tokens
         self.spaces_between_special_tokens = (
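
When a request carries only embeddings, there is no prompt text to detokenize, so the detokenizer fills its prompt-side state with neutral placeholders and only the generated tokens are decoded. A tiny standalone illustration of that placeholder state (not vLLM code; a 4-row prompt_embeds tensor is assumed):

prompt_len = 4                   # rows in the prompt_embeds tensor
tokens = [""] * prompt_len       # nothing to detokenize for the prompt
token_ids = [0] * prompt_len     # placeholder ids keep offsets consistent
prefix_offset = read_offset = 0

# Sampled token ids are appended after the placeholders, so incremental
# detokenization of the generated text works the same as for token prompts.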

vllm/v1/engine/output_processor.py

Lines changed: 20 additions & 5 deletions

@@ -14,6 +14,7 @@
 from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                           extract_trace_context)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor

@@ -86,7 +87,8 @@ def __init__(
         lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
-        prompt_token_ids: list[int],
+        prompt_token_ids: Optional[list[int]],
+        prompt_embeds: Optional[torch.Tensor],
         logprobs_processor: Optional[LogprobsProcessor],
         detokenizer: Optional[IncrementalDetokenizer],
         max_tokens_param: Optional[int],

@@ -104,7 +106,9 @@ def __init__(
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
-        self.prompt_len = len(prompt_token_ids)
+        self.prompt_embeds = prompt_embeds
+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param

@@ -165,6 +169,7 @@ def from_new_request(
             output_kind=output_kind,
             prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
+            prompt_embeds=request.prompt_embeds,
             logprobs_processor=logprobs_processor,
             detokenizer=detokenizer,
             max_tokens_param=max_tokens_param,

@@ -223,6 +228,8 @@ def _new_request_output(
         first_output = outputs[0]
         if isinstance(first_output, PoolingOutput):
             assert len(outputs) == 1
+            # Prompt embeddings are currently not supported by pooling requests.
+            assert self.prompt_token_ids is not None
             return PoolingRequestOutput(
                 request_id=request_id,
                 outputs=first_output,

@@ -236,10 +243,15 @@ def _new_request_output(
         else:
             prompt_logprobs = self.logprobs_processor.prompt_logprobs

+        # If prompt embeds were used, put placeholder prompt token ids
+        prompt_token_ids = self.prompt_token_ids
+        if prompt_token_ids is None and self.prompt_embeds is not None:
+            prompt_token_ids = [0] * len(self.prompt_embeds)
+
         return RequestOutput(
             request_id=request_id,
             prompt=self.prompt,
-            prompt_token_ids=self.prompt_token_ids,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
             outputs=cast(list[CompletionOutput], outputs),
             finished=finished,

@@ -469,6 +481,8 @@ def do_tracing(self, engine_core_output: EngineCoreOutput,

         arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
         trace_context = extract_trace_context(engine_core_output.trace_headers)
+        prompt_length = length_from_prompt_token_ids_or_embeds(
+            req_state.prompt_token_ids, req_state.prompt_embeds)
         with (self.tracer.start_as_current_span(
                 "llm_request",
                 kind=SpanKind.SERVER,

@@ -488,7 +502,7 @@ def do_tracing(self, engine_core_output: EngineCoreOutput,
             span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                                queued_time)
             span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
-                               len(req_state.prompt_token_ids))
+                               prompt_length)
             span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                                metrics.num_generation_tokens)
             span.set_attribute(

@@ -544,7 +558,8 @@ def _update_stats_from_finished(self, req_state: RequestState,
         assert req_state.stats is not None
         iteration_stats.update_from_finished_request(
             finish_reason=finish_reason,
-            num_prompt_tokens=len(req_state.prompt_token_ids),
+            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
+                req_state.prompt_token_ids, req_state.prompt_embeds),
             max_tokens_param=req_state.max_tokens_param,
             req_stats=req_state.stats)
         self.lora_states.finish_request(req_state)
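
End to end, an embeds-only request therefore comes back with placeholder prompt token ids of the right length. A minimal offline sketch, assuming the V1 engine accepts the same {"prompt_embeds": ...} prompt format as V0 and that a small model such as facebook/opt-125m is available locally:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "facebook/opt-125m"

# Compute input embeddings with the HF model's own embedding layer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name, enable_prompt_embeds=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=16))

out = outputs[0]
# The prompt cannot be recovered from embeddings, so prompt_token_ids holds
# zeros, one per embedding row, as set up in _new_request_output above.
print(out.prompt_token_ids)
print(out.outputs[0].text)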
