
Commit 33ab19d

LunrEclipse authored and Alvant committed
[Frontend] merge beam search implementations (vllm-project#9296)
Signed-off-by: Alvant <[email protected]>
1 parent a26a8ed commit 33ab19d
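
Before this change, AsyncLLMEngine and MQLLMEngineClient each carried a near-identical beam_search method; both copies are deleted in the diffs below. Each step runs every live beam through generate() for one token with logprobs=2 * beam_width, expands the beam with each candidate token, retires beams that emit EOS (unless ignore_eos is set), and keeps the beam_width best survivors under a length-penalized sort key. The commit removes both copies and makes both classes implement EngineClient, presumably so a single shared implementation can serve the two frontends. Below is a minimal, self-contained sketch of that per-step expansion; Beam, expand_step and sort_key are illustrative stand-ins, not vLLM APIs.

from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Beam:
    """Stand-in for vllm.beam_search.BeamSearchSequence (illustrative)."""
    tokens: List[int]          # prompt token ids + generated token ids
    cum_logprob: float = 0.0   # running sum of chosen-token logprobs


def expand_step(all_beams: List[Beam],
                completed: List[Beam],
                top_logprobs: List[Dict[int, float]],
                beam_width: int,
                eos_token_id: int,
                ignore_eos: bool,
                sort_key: Callable[[Beam], float]) -> List[Beam]:
    """One iteration of the loop both deleted methods ran per output token.

    top_logprobs[i] maps candidate token id -> logprob for beam i; the real
    code obtains it by calling generate() on each beam with max_tokens=1 and
    logprobs=2 * beam_width.
    """
    new_beams: List[Beam] = []
    for beam, candidates in zip(all_beams, top_logprobs):
        for token_id, logprob in candidates.items():
            child = Beam(tokens=beam.tokens + [token_id],
                         cum_logprob=beam.cum_logprob + logprob)
            # Beams that emit EOS are finished; the rest stay in the frontier.
            if token_id == eos_token_id and not ignore_eos:
                completed.append(child)
            else:
                new_beams.append(child)
    # Keep only the beam_width best candidates for the next step.
    new_beams.sort(key=sort_key, reverse=True)
    return new_beams[:beam_width]

The outer loop (not shown) repeats this max_tokens times, then merges completed with the surviving frontier and decodes the beam_width best sequences.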

File tree

5 files changed (+145 −234 lines changed)

vllm/engine/async_llm_engine.py

Lines changed: 6 additions & 104 deletions
@@ -7,33 +7,31 @@
 from weakref import ReferenceType
 
 import vllm.envs as envs
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.engine.protocol import EngineClient
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid, weak_bind)
+from vllm.utils import deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -583,7 +581,7 @@ async def build_guided_decoding_logits_processor_async(
     return sampling_params
 
 
-class AsyncLLMEngine:
+class AsyncLLMEngine(EngineClient):
     """An asynchronous wrapper for :class:`LLMEngine`.
 
     This class is used to wrap the :class:`LLMEngine` class to make it
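
The hunk above gives AsyncLLMEngine the EngineClient base class; MQLLMEngineClient receives the same treatment further down. With a common base, logic that was previously duplicated per client can live once behind that type, which is presumably where the merged beam_search ends up. A generic Python sketch of the pattern; EngineClientBase and the two subclasses are illustrative, not vLLM's actual class bodies.

from abc import ABC
from typing import Any, AsyncGenerator


class EngineClientBase(ABC):
    """Illustrative stand-in for vllm.engine.protocol.EngineClient."""

    async def beam_search(self, prompt: Any, request_id: str,
                          params: Any) -> AsyncGenerator[Any, None]:
        # A concrete method on the base is inherited by every client, so the
        # two per-client copies removed in this commit become unnecessary.
        yield {"request_id": request_id, "note": "shared implementation"}


class InProcessClient(EngineClientBase):      # plays the role of AsyncLLMEngine
    pass


class MultiprocessClient(EngineClientBase):   # plays the role of MQLLMEngineClient
    pass

Both subclasses then expose the same beam_search without restating it.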
@@ -1081,102 +1079,6 @@ async def generate(
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer()
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
-
     async def encode(
         self,
         prompt: PromptType,
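
The deleted loop ranks candidates with sort_beams_key, produced by create_sort_beams_key_function(tokenizer.eos_token_id, length_penalty) from vllm.beam_search; the helper itself never appears in this diff. A plausible sketch of such a key, assuming the usual normalization of cumulative logprob by generated length raised to length_penalty; the real helper may differ in detail.

from typing import Callable, List


def make_sort_beams_key(eos_token_id: int,
                        length_penalty: float) -> Callable:
    """Illustrative stand-in for create_sort_beams_key_function."""

    def key(beam) -> float:
        tokens: List[int] = beam.tokens
        # Do not count a trailing EOS toward the effective length.
        seq_len = len(tokens)
        if tokens and tokens[-1] == eos_token_id:
            seq_len -= 1
        # Higher is better: cumulative logprob normalized by length^penalty.
        return beam.cum_logprob / (max(seq_len, 1) ** length_penalty)

    return key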

vllm/engine/multiprocessing/client.py

Lines changed: 17 additions & 109 deletions
@@ -12,8 +12,8 @@
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -26,18 +26,18 @@
                                           RPCError, RPCProcessRequest,
                                           RPCStartupRequest, RPCStartupResponse,
                                           RPCUProfileRequest)
+from vllm.engine.protocol import EngineClient
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid)
+from vllm.utils import deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
     """
 
 
-class MQLLMEngineClient:
+class MQLLMEngineClient(EngineClient):
     """A client wrapper for MQLLMEngine that conforms to the
     EngineClient protocol.
 
@@ -316,7 +316,7 @@ async def _check_success(error_message: str, socket: Socket):
                 or response != VLLM_RPC_SUCCESS_STR):
             raise ValueError(error_message)
 
-    async def get_tokenizer(self, lora_request: LoRARequest):
+    async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
     async def get_decoding_config(self) -> DecodingConfig:
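
Giving lora_request a default of None matters for the merge: the AsyncLLMEngine copy above calls get_tokenizer() with no arguments, while the MQ client copy below passes lora_request=None explicitly, so shared code needs both call shapes to work. A small illustrative sketch; tokenize_prompt is not a vLLM function.

from typing import List, Optional

from vllm.lora.request import LoRARequest


async def tokenize_prompt(client, prompt: str,
                          lora_request: Optional[LoRARequest] = None
                          ) -> List[int]:
    # Works against either client now that both accept zero arguments.
    tokenizer = await client.get_tokenizer(lora_request)
    return tokenizer.encode(prompt)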
@@ -344,8 +344,14 @@ async def abort(self, request_id: str):
         await self._send_one_way_rpc_request(
             request=RPCAbortRequest(request_id), socket=self.input_socket)
 
-    async def do_log_stats(self):
-        """Ignore do_log_stats (handled on MQLLMEngine polling)"""
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[List[SamplerOutput]] = None,
+    ) -> None:
+        """
+        Ignore do_log_stats (handled on MQLLMEngine polling)
+        """
         pass
 
     async def check_health(self):
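
Widening do_log_stats to accept scheduler_outputs and model_output keeps the MQ client signature-compatible with callers written against the common interface, even though the client ignores both arguments (stats are emitted by the MQLLMEngine process itself). A toy illustration of why the extra parameters matter; this is not vLLM code.

import asyncio
from typing import Any, List, Optional


class StatsIgnoringClient:
    """Toy stand-in: accepts the full interface signature but does nothing."""

    async def do_log_stats(self,
                           scheduler_outputs: Optional[Any] = None,
                           model_output: Optional[List[Any]] = None) -> None:
        pass  # stats are produced elsewhere


async def main() -> None:
    client = StatsIgnoringClient()
    # A caller coded against the shared interface may forward both arguments;
    # with the old zero-argument signature this call would raise TypeError.
    await client.do_log_stats(scheduler_outputs=None, model_output=[])


if __name__ == "__main__":
    asyncio.run(main())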
@@ -444,104 +450,6 @@ def generate(
                                      lora_request, trace_headers,
                                      prompt_adapter_request, priority)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        logger.info(beam_search_output)
-
-        yield beam_search_output
-
     @overload  # DEPRECATED
     def encode(
         self,
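
With the second copy gone as well, frontend code would drive beam search through whichever client it holds, via the shared interface both classes now inherit. A hedged usage sketch; the call shape mirrors the deleted methods' signature, and the BeamSearchParams fields are exactly the ones those methods read.

from vllm.sampling_params import BeamSearchParams


async def run_beam_search(client, prompt: str, request_id: str) -> None:
    params = BeamSearchParams(beam_width=4,
                              max_tokens=64,
                              ignore_eos=False,
                              temperature=0.0,
                              length_penalty=1.0)
    # `client` may be an AsyncLLMEngine or an MQLLMEngineClient; both now
    # derive from EngineClient.
    async for output in client.beam_search(prompt, request_id, params):
        for i, completion in enumerate(output.outputs):
            print(i, completion.cumulative_logprob, completion.text)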

0 commit comments
