@@ -14,6 +14,7 @@
 from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                           extract_trace_context)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
@@ -86,7 +87,8 @@ def __init__(
         lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
-        prompt_token_ids: list[int],
+        prompt_token_ids: Optional[list[int]],
+        prompt_embeds: Optional[torch.Tensor],
         logprobs_processor: Optional[LogprobsProcessor],
         detokenizer: Optional[IncrementalDetokenizer],
         max_tokens_param: Optional[int],
@@ -104,7 +106,9 @@ def __init__(
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
-        self.prompt_len = len(prompt_token_ids)
+        self.prompt_embeds = prompt_embeds
+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param
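
The helper length_from_prompt_token_ids_or_embeds imported above is defined in vllm.utils, outside this diff. Based only on how it is called here, a minimal sketch of the behavior it presumably has is below; the body is an assumption, not the actual implementation.

# Sketch of the assumed behavior of vllm.utils.length_from_prompt_token_ids_or_embeds.
# Exactly one of the two arguments is expected to be non-None for a given request.
from typing import Optional

import torch


def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor],
) -> int:
    """Return the prompt length from whichever representation is present."""
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    assert prompt_embeds is not None
    # For an embedded prompt, the length is the size of the sequence dimension.
    return prompt_embeds.shape[0]
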
@@ -165,6 +169,7 @@ def from_new_request(
             output_kind=output_kind,
             prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
+            prompt_embeds=request.prompt_embeds,
             logprobs_processor=logprobs_processor,
             detokenizer=detokenizer,
             max_tokens_param=max_tokens_param,
@@ -223,6 +228,8 @@ def _new_request_output(
         first_output = outputs[0]
         if isinstance(first_output, PoolingOutput):
             assert len(outputs) == 1
+            # Prompt embeddings are currently not supported by pooling requests.
+            assert self.prompt_token_ids is not None
             return PoolingRequestOutput(
                 request_id=request_id,
                 outputs=first_output,
@@ -236,10 +243,15 @@ def _new_request_output(
         else:
             prompt_logprobs = self.logprobs_processor.prompt_logprobs
 
+        # If prompt embeds were used, put placeholder prompt token ids.
+        prompt_token_ids = self.prompt_token_ids
+        if prompt_token_ids is None and self.prompt_embeds is not None:
+            prompt_token_ids = [0] * len(self.prompt_embeds)
+
         return RequestOutput(
             request_id=request_id,
             prompt=self.prompt,
-            prompt_token_ids=self.prompt_token_ids,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
             outputs=cast(list[CompletionOutput], outputs),
             finished=finished,
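
As a standalone illustration of the placeholder branch added above (the tensor shape and values here are made up for the example, not taken from the diff):

# Illustrative only: mirrors the placeholder logic in _new_request_output.
import torch

prompt_token_ids = None                # the request was built from embeddings
prompt_embeds = torch.randn(5, 4096)   # 5 embedded prompt tokens, hidden size assumed

if prompt_token_ids is None and prompt_embeds is not None:
    # RequestOutput still expects a token id list, so substitute zeros with the
    # same length as the embedded prompt.
    prompt_token_ids = [0] * len(prompt_embeds)

assert prompt_token_ids == [0] * 5
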
@@ -469,6 +481,8 @@ def do_tracing(self, engine_core_output: EngineCoreOutput,
 
         arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
         trace_context = extract_trace_context(engine_core_output.trace_headers)
+        prompt_length = length_from_prompt_token_ids_or_embeds(
+            req_state.prompt_token_ids, req_state.prompt_embeds)
         with (self.tracer.start_as_current_span(
                 "llm_request",
                 kind=SpanKind.SERVER,
@@ -488,7 +502,7 @@ def do_tracing(self, engine_core_output: EngineCoreOutput,
             span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                                queued_time)
             span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
-                               len(req_state.prompt_token_ids))
+                               prompt_length)
             span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                                metrics.num_generation_tokens)
             span.set_attribute(
@@ -544,7 +558,8 @@ def _update_stats_from_finished(self, req_state: RequestState,
         assert req_state.stats is not None
         iteration_stats.update_from_finished_request(
             finish_reason=finish_reason,
-            num_prompt_tokens=len(req_state.prompt_token_ids),
+            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
+                req_state.prompt_token_ids, req_state.prompt_embeds),
             max_tokens_param=req_state.max_tokens_param,
             req_stats=req_state.stats)
         self.lora_states.finish_request(req_state)