2 files changed, +10 −2 lines changed

@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
   - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

@@ -28,12 +28,13 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)
 
 logger = init_logger(__name__)
 
@@ -886,6 +887,12 @@ def generate_beam_search(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if is_list_of(prompts, str, check="all"):
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
+        else:
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
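
For readers unfamiliar with these helpers, below is a minimal standalone sketch of the prompt-wrapping logic the hunk above introduces, assuming the TextPrompt/TokensPrompt TypedDicts from vllm.inputs and the is_list_of utility from vllm.utils behave as the imports suggest. The wrap_prompts function name and the example inputs are hypothetical and used only for illustration; they are not part of this change.

# Minimal sketch (not part of the diff): normalize raw prompts the same way
# the updated generate_beam_search does before calling beam_search.
from typing import List, Union

from vllm.inputs import TextPrompt, TokensPrompt
from vllm.utils import is_list_of


def wrap_prompts(
        prompts: Union[List[str], List[List[int]]]
) -> List[Union[TextPrompt, TokensPrompt]]:
    # Plain strings become TextPrompt dicts; otherwise the entries are assumed
    # to be pre-tokenized ids and become TokensPrompt dicts, mirroring the hunk.
    if is_list_of(prompts, str, check="all"):
        return [TextPrompt(prompt=p) for p in prompts]
    return [TokensPrompt(prompt_token_ids=tokens) for tokens in prompts]


# Hypothetical usage:
#   wrap_prompts(["Hello", "World"])      -> [{'prompt': 'Hello'}, {'prompt': 'World'}]
#   wrap_prompts([[101, 7592], [101, 995]]) -> [{'prompt_token_ids': [101, 7592]}, ...]

The effect is that the test helper can accept either plain strings or pre-tokenized prompts, with both forms converted to explicit prompt dicts before they are passed to beam_search.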