@@ -33,8 +33,10 @@
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
-Task = Literal["generate", "embedding"]
-TaskOption = Literal["auto", Task]
+TaskOption = Literal["auto", "generate", "embedding"]
+
+# "draft" is only used internally for speculative decoding
+_Task = Literal["generate", "embedding", "draft"]
 
 
 class ModelConfig:
@@ -115,7 +117,7 @@ class ModelConfig:
 
     def __init__(self,
                  model: str,
-                 task: TaskOption,
+                 task: Union[TaskOption, _Task],
                  tokenizer: str,
                  tokenizer_mode: str,
                  trust_remote_code: bool,
@@ -255,18 +257,21 @@ def _verify_tokenizer_mode(self) -> None:
 
     def _resolve_task(
         self,
-        task_option: TaskOption,
+        task_option: Union[TaskOption, _Task],
         hf_config: PretrainedConfig,
-    ) -> Tuple[Set[Task], Task]:
+    ) -> Tuple[Set[_Task], _Task]:
+        if task_option == "draft":
+            return {"draft"}, "draft"
+
         architectures = getattr(hf_config, "architectures", [])
 
-        task_support: Dict[Task, bool] = {
+        task_support: Dict[_Task, bool] = {
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
             "embedding": ModelRegistry.is_embedding_model(architectures),
         }
-        supported_tasks_lst: List[Task] = [
+        supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
         ]
         supported_tasks = set(supported_tasks_lst)
@@ -1002,7 +1007,7 @@ class SchedulerConfig:
     """
 
     def __init__(self,
-                 task: Task,
+                 task: _Task,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
@@ -1269,7 +1274,7 @@ def maybe_create_spec_config(
             ngram_prompt_lookup_min = 0
         draft_model_config = ModelConfig(
             model=speculative_model,
-            task=target_model_config.task,
+            task="draft",
             tokenizer=target_model_config.tokenizer,
             tokenizer_mode=target_model_config.tokenizer_mode,
             trust_remote_code=target_model_config.trust_remote_code,
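
The diff splits the old Task alias into a public TaskOption (what users may request) and an internal _Task that additionally carries "draft", and _resolve_task short-circuits on "draft" so the draft model created for speculative decoding never goes through registry-based task resolution. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the architecture-name heuristics are hypothetical stand-ins for the real ModelRegistry checks, and resolve_task is an illustrative free function, not the actual ModelConfig method.

from typing import Dict, List, Literal, Set, Tuple, Union

# Public, user-facing option type.
TaskOption = Literal["auto", "generate", "embedding"]

# Internal task type; "draft" exists only for speculative decoding.
_Task = Literal["generate", "embedding", "draft"]


def resolve_task(task_option: Union[TaskOption, _Task],
                 architectures: List[str]) -> Tuple[Set[_Task], _Task]:
    # Mirrors the early return added in the diff: a draft model skips
    # registry-based resolution entirely.
    if task_option == "draft":
        return {"draft"}, "draft"

    # Stand-in support table (the real code queries ModelRegistry).
    # Dict order doubles as priority, highest first, as in the diff.
    task_support: Dict[_Task, bool] = {
        "generate": any(a.endswith("ForCausalLM") for a in architectures),
        "embedding": any("EmbeddingModel" in a for a in architectures),
    }
    supported_tasks: Set[_Task] = {
        task for task, is_supported in task_support.items() if is_supported
    }

    if task_option == "auto":
        # Pick the highest-priority supported task.
        selected_task = next(t for t in task_support if t in supported_tasks)
    else:
        if task_option not in supported_tasks:
            raise ValueError(f"Model does not support task {task_option!r}")
        selected_task = task_option
    return supported_tasks, selected_task


print(resolve_task("auto", ["LlamaForCausalLM"]))   # ({'generate'}, 'generate')
print(resolve_task("draft", ["LlamaForCausalLM"]))  # ({'draft'}, 'draft')

Because "draft" is resolved before any registry lookup, it never needs to appear in the public TaskOption, which keeps the user-visible choices limited to "auto", "generate", and "embedding".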