1212from zmq .asyncio import Socket
1313
1414from vllm import PoolingParams
15- from vllm .beam_search import BeamSearchSequence , create_sort_beams_key_function
1615from vllm .config import DecodingConfig , EngineConfig , ModelConfig
16+ from vllm .core .scheduler import SchedulerOutputs
1717from vllm .engine .arg_utils import AsyncEngineArgs
1818# yapf conflicts with isort for this block
1919# yapf: disable
2626 RPCError , RPCProcessRequest ,
2727 RPCStartupRequest , RPCStartupResponse ,
2828 RPCUProfileRequest )
29+ from vllm .engine .protocol import EngineClient
2930# yapf: enable
3031from vllm .envs import VLLM_RPC_TIMEOUT
31- from vllm .inputs import PromptType , TokensPrompt
32+ from vllm .inputs import PromptType
3233from vllm .logger import init_logger
3334from vllm .lora .request import LoRARequest
34- from vllm .outputs import ( CompletionOutput , EmbeddingRequestOutput ,
35- RequestOutput )
35+ from vllm .model_executor . layers . sampler import SamplerOutput
36+ from vllm . outputs import EmbeddingRequestOutput , RequestOutput
3637from vllm .prompt_adapter .request import PromptAdapterRequest
37- from vllm .sampling_params import BeamSearchParams , SamplingParams
38+ from vllm .sampling_params import SamplingParams
3839from vllm .transformers_utils .tokenizer_group import init_tokenizer_from_configs
39- from vllm .utils import (collect_from_async_generator , deprecate_kwargs ,
40- random_uuid )
40+ from vllm .utils import deprecate_kwargs
4141
4242logger = init_logger (__name__ )
4343
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
5353 """
5454
5555
56- class MQLLMEngineClient :
56+ class MQLLMEngineClient ( EngineClient ) :
5757 """A client wrapper for MQLLMEngine that conforms to the
5858 EngineClient protocol.
5959
@@ -316,7 +316,7 @@ async def _check_success(error_message: str, socket: Socket):
316316 or response != VLLM_RPC_SUCCESS_STR ):
317317 raise ValueError (error_message )
318318
319- async def get_tokenizer (self , lora_request : LoRARequest ):
319+ async def get_tokenizer (self , lora_request : Optional [ LoRARequest ] = None ):
320320 return await self .tokenizer .get_lora_tokenizer_async (lora_request )
321321
322322 async def get_decoding_config (self ) -> DecodingConfig :
@@ -344,8 +344,14 @@ async def abort(self, request_id: str):
344344 await self ._send_one_way_rpc_request (
345345 request = RPCAbortRequest (request_id ), socket = self .input_socket )
346346
347- async def do_log_stats (self ):
348- """Ignore do_log_stats (handled on MQLLMEngine polling)"""
347+ async def do_log_stats (
348+ self ,
349+ scheduler_outputs : Optional [SchedulerOutputs ] = None ,
350+ model_output : Optional [List [SamplerOutput ]] = None ,
351+ ) -> None :
352+ """
353+ Ignore do_log_stats (handled on MQLLMEngine polling)
354+ """
349355 pass
350356
351357 async def check_health (self ):
@@ -444,104 +450,6 @@ def generate(
444450 lora_request , trace_headers ,
445451 prompt_adapter_request , priority )
446452
447- async def beam_search (
448- self ,
449- prompt : Union [PromptType , List [int ]],
450- request_id : str ,
451- params : BeamSearchParams ,
452- ) -> AsyncGenerator [RequestOutput , None ]:
453-
454- beam_width = params .beam_width
455- max_tokens = params .max_tokens
456- ignore_eos = params .ignore_eos
457- temperature = params .temperature
458- length_penalty = params .length_penalty
459-
460- tokenizer = await self .get_tokenizer (lora_request = None )
461- tokenizedPrompt = prompt if isinstance (
462- prompt , list ) else tokenizer .encode (prompt )
463- tokenizedLength = len (tokenizedPrompt )
464-
465- sort_beams_key = create_sort_beams_key_function (
466- tokenizer .eos_token_id , length_penalty )
467-
468- beam_search_params = SamplingParams (logprobs = 2 * beam_width ,
469- max_tokens = 1 ,
470- temperature = temperature )
471- all_beams = [BeamSearchSequence (tokens = tokenizedPrompt , cum_logprob = 0 )]
472- completed = []
473-
474- for _ in range (max_tokens ):
475- prompts_batch = [
476- TokensPrompt (prompt_token_ids = beam .tokens )
477- for beam in all_beams
478- ]
479-
480- tasks = []
481-
482- request_id = f"beam_search-{ random_uuid ()} "
483- for i , individual_prompt in enumerate (prompts_batch ):
484- request_id_item = f"{ request_id } -{ i } "
485- task = asyncio .create_task (
486- collect_from_async_generator (
487- self .generate (individual_prompt , beam_search_params ,
488- request_id_item )))
489- tasks .append (task )
490-
491- output = await asyncio .gather (* tasks )
492-
493- output = [x [0 ] for x in output ]
494-
495- logger .info (output )
496-
497- new_beams = []
498- for i , current_beam in enumerate (all_beams ):
499- result = output [i ]
500-
501- if result .outputs [0 ].logprobs is not None :
502- logprobs = result .outputs [0 ].logprobs [0 ]
503- for token_id , logprob_obj in logprobs .items ():
504- new_beam = BeamSearchSequence (
505- tokens = current_beam .tokens + [token_id ],
506- cum_logprob = current_beam .cum_logprob +
507- logprob_obj .logprob )
508-
509- if token_id == tokenizer .eos_token_id and \
510- not ignore_eos :
511- completed .append (new_beam )
512- else :
513- new_beams .append (new_beam )
514-
515- sorted_beams = sorted (new_beams , key = sort_beams_key , reverse = True )
516- all_beams = sorted_beams [:beam_width ]
517-
518- completed .extend (all_beams )
519- sorted_completed = sorted (completed , key = sort_beams_key , reverse = True )
520- best_beams = sorted_completed [:beam_width ]
521-
522- for beam in best_beams :
523- beam .text = tokenizer .decode (beam .tokens [tokenizedLength :])
524-
525- beam_search_output = RequestOutput (
526- request_id = request_id ,
527- prompt = prompt ,
528- outputs = [
529- CompletionOutput (
530- text = beam .text ,
531- cumulative_logprob = beam .cum_logprob ,
532- token_ids = beam .tokens ,
533- index = i ,
534- logprobs = beam .cum_logprob ,
535- ) for (i , beam ) in enumerate (best_beams )
536- ],
537- finished = True ,
538- prompt_token_ids = tokenizedPrompt ,
539- prompt_logprobs = None )
540-
541- logger .info (beam_search_output )
542-
543- yield beam_search_output
544-
545453 @overload # DEPRECATED
546454 def encode (
547455 self ,
0 commit comments