Skip to content

Commit 01bc62c

Browse files
DarkLight1337 and garg-amit
authored and committed
[Bugfix] Fix missing task for speculative decoding (vllm-project#9524)
Signed-off-by: Amit Garg <[email protected]>
1 parent 198a021 commit 01bc62c

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

vllm/config.py

Lines changed: 14 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -33,8 +33,10 @@
3333
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
3434
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
3535

36-
Task = Literal["generate", "embedding"]
37-
TaskOption = Literal["auto", Task]
36+
TaskOption = Literal["auto", "generate", "embedding"]
37+
38+
# "draft" is only used internally for speculative decoding
39+
_Task = Literal["generate", "embedding", "draft"]
3840

3941

4042
class ModelConfig:
@@ -115,7 +117,7 @@ class ModelConfig:
115117

116118
def __init__(self,
117119
model: str,
118-
task: TaskOption,
120+
task: Union[TaskOption, _Task],
119121
tokenizer: str,
120122
tokenizer_mode: str,
121123
trust_remote_code: bool,
@@ -255,18 +257,21 @@ def _verify_tokenizer_mode(self) -> None:
255257

256258
def _resolve_task(
257259
self,
258-
task_option: TaskOption,
260+
task_option: Union[TaskOption, _Task],
259261
hf_config: PretrainedConfig,
260-
) -> Tuple[Set[Task], Task]:
262+
) -> Tuple[Set[_Task], _Task]:
263+
if task_option == "draft":
264+
return {"draft"}, "draft"
265+
261266
architectures = getattr(hf_config, "architectures", [])
262267

263-
task_support: Dict[Task, bool] = {
268+
task_support: Dict[_Task, bool] = {
264269
# NOTE: Listed from highest to lowest priority,
265270
# in case the model supports multiple of them
266271
"generate": ModelRegistry.is_text_generation_model(architectures),
267272
"embedding": ModelRegistry.is_embedding_model(architectures),
268273
}
269-
supported_tasks_lst: List[Task] = [
274+
supported_tasks_lst: List[_Task] = [
270275
task for task, is_supported in task_support.items() if is_supported
271276
]
272277
supported_tasks = set(supported_tasks_lst)
@@ -1002,7 +1007,7 @@ class SchedulerConfig:
10021007
"""
10031008

10041009
def __init__(self,
1005-
task: Task,
1010+
task: _Task,
10061011
max_num_batched_tokens: Optional[int],
10071012
max_num_seqs: int,
10081013
max_model_len: int,
@@ -1269,7 +1274,7 @@ def maybe_create_spec_config(
12691274
ngram_prompt_lookup_min = 0
12701275
draft_model_config = ModelConfig(
12711276
model=speculative_model,
1272-
task=target_model_config.task,
1277+
task="draft",
12731278
tokenizer=target_model_config.tokenizer,
12741279
tokenizer_mode=target_model_config.tokenizer_mode,
12751280
trust_remote_code=target_model_config.trust_remote_code,

0 commit comments

Comments (0)