Commit 6b294a0

[ci] Fix sampler tests (vllm-project#11922)

youkaichao authored and rasmith committed
Signed-off-by: youkaichao <[email protected]>
1 parent 83358de, commit 6b294a0

2 files changed (+10, -2 lines)


.buildkite/test-pipeline.yaml (1 addition, 0 deletions)

@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
   - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

tests/conftest.py (9 additions, 2 deletions)

@@ -28,12 +28,13 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)
 
 logger = init_logger(__name__)
 
@@ -886,6 +887,12 @@ def generate_beam_search(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if is_list_of(prompts, str, check="all"):
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
+        else:
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
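
For context, the added branch normalizes generate_beam_search inputs before they reach beam_search: a list of strings is wrapped as TextPrompt dicts, while anything else is treated as pre-tokenized input and wrapped as TokensPrompt dicts. Below is a minimal standalone sketch of that dispatch under stated assumptions: TextPrompt, TokensPrompt, and is_list_of are simplified stand-ins for the real definitions in vllm.inputs and vllm.utils, and normalize_prompts is a hypothetical helper name, not part of the diff.

# Standalone sketch of the prompt normalization added above.
# TextPrompt/TokensPrompt stand in for the TypedDicts in vllm.inputs;
# is_list_of mirrors the check="all" behavior of vllm.utils.is_list_of.
from typing import List, TypedDict, Union

class TextPrompt(TypedDict):
    prompt: str

class TokensPrompt(TypedDict):
    prompt_token_ids: List[int]

def is_list_of(value, typ, *, check: str = "first") -> bool:
    # check="all" verifies every element; check="first" only the first.
    if not isinstance(value, list):
        return False
    if check == "first":
        return not value or isinstance(value[0], typ)
    return all(isinstance(v, typ) for v in value)

def normalize_prompts(
    prompts: Union[List[str], List[List[int]]],
) -> List[Union[TextPrompt, TokensPrompt]]:
    # Strings become TextPrompt; everything else is assumed to be token ids.
    if is_list_of(prompts, str, check="all"):
        return [TextPrompt(prompt=p) for p in prompts]
    return [TokensPrompt(prompt_token_ids=tok) for tok in prompts]

print(normalize_prompts(["Hello, world"]))   # [{'prompt': 'Hello, world'}]
print(normalize_prompts([[1, 2, 3]]))        # [{'prompt_token_ids': [1, 2, 3]}]

Normalizing here lets a single test helper accept either raw text or token-id prompts, which is why tests/conftest.py was also added to the sampler tests' source_file_dependencies in the pipeline change above.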
